Skip to content

Commit d1a9255

Browse files
committed
KAFKA-19561: Set OP_WRITE interest after SASL reauthentication to resume pending writes (apache#20258)
https://issues.apache.org/jira/browse/KAFKA-19561 Addresses a race condition during SASL reauthentication where the server-side `KafkaChannel.send()` queues a response, but OP_WRITE is removed before the channel becomes writable — resulting in stuck responses and client timeouts. Reviewers: Rajini Sivaram <[email protected]>
1 parent b7b2676 commit d1a9255

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,4 +681,14 @@ private void swapAuthenticatorsAndBeginReauthentication(ReauthenticationContext
681681
public ChannelMetadataRegistry channelMetadataRegistry() {
682682
return metadataRegistry;
683683
}
684+
685+
686+
/**
687+
* Maybe add write interest after re-authentication. This is to ensure that any pending write operation
688+
* is resumed.
689+
*/
690+
public void maybeAddWriteInterestAfterReauth() {
691+
if (send != null)
692+
this.transportLayer.addInterestOps(SelectionKey.OP_WRITE);
693+
}
684694
}

clients/src/main/java/org/apache/kafka/common/network/Selector.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ void pollSelectionKeys(Set<SelectionKey> selectionKeys,
566566
boolean isReauthentication = channel.successfulAuthentications() > 1;
567567
if (isReauthentication) {
568568
sensors.successfulReauthentication.record(1.0, readyTimeMs);
569+
channel.maybeAddWriteInterestAfterReauth();
569570
if (channel.reauthenticationLatencyMs() == null)
570571
log.warn(
571572
"Should never happen: re-authentication latency for a re-authenticated channel was null; continuing...");

clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.apache.kafka.common.network.ChannelState;
4545
import org.apache.kafka.common.network.ConnectionMode;
4646
import org.apache.kafka.common.network.ListenerName;
47+
import org.apache.kafka.common.network.NetworkReceive;
4748
import org.apache.kafka.common.network.NetworkSend;
4849
import org.apache.kafka.common.network.NetworkTestUtils;
4950
import org.apache.kafka.common.network.NioEchoServer;
@@ -119,6 +120,7 @@
119120
import java.util.Map;
120121
import java.util.Random;
121122
import java.util.Set;
123+
import java.util.concurrent.Semaphore;
122124
import java.util.concurrent.atomic.AtomicInteger;
123125
import java.util.function.Function;
124126
import java.util.stream.Collectors;
@@ -1856,6 +1858,69 @@ public void testSslClientAuthRequiredOverriddenForSaslSslListener() throws Excep
18561858
verifySslClientAuthForSaslSslListener(false, SslClientAuth.REQUIRED);
18571859
}
18581860

1861+
@Test
1862+
public void testServerSidePendingSendDuringReauthentication() throws Exception {
1863+
SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
1864+
TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN"));
1865+
jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), new HashMap<>());
1866+
jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD);
1867+
String callbackPrefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("PLAIN");
1868+
saslServerConfigs.put(callbackPrefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_CONFIG,
1869+
TestServerCallbackHandler.class.getName());
1870+
server = createEchoServer(securityProtocol);
1871+
1872+
String node = "node1";
1873+
try {
1874+
createClientConnection(securityProtocol, node);
1875+
NetworkTestUtils.waitForChannelReady(selector, node);
1876+
server.verifyAuthenticationMetrics(1, 0);
1877+
1878+
/*
1879+
* Now start the reauthentication on the connection. First, we have to sleep long enough so
1880+
* that the next write will cause re-authentication
1881+
*/
1882+
delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1));
1883+
server.verifyReauthenticationMetrics(0, 0);
1884+
1885+
// block reauthentication to complete
1886+
TestServerCallbackHandler.sem.acquire();
1887+
1888+
String prefix = TestUtils.randomString(100);
1889+
// send a client request to start a reauthentication.
1890+
selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap((prefix + "-0").getBytes(StandardCharsets.UTF_8)))));
1891+
// wait till reauthentication is blocked
1892+
TestUtils.waitForCondition(() -> {
1893+
selector.poll(10L);
1894+
return TestServerCallbackHandler.sem.hasQueuedThreads();
1895+
}, 5000, "Reauthentication is not blocked");
1896+
1897+
// Set the client's channel `send` to null to allow setting a new send on the server's selector.
1898+
// Without this, NioEchoServer will throw an error while processing the client request,
1899+
// since we're manually setting a server side send to simulate the issue.
1900+
TestUtils.setFieldValue(selector.channel(node), "send", null);
1901+
1902+
// extract the channel id from the server's selector and directly set a send on it.
1903+
String channelId = server.selector().channels().get(0).id();
1904+
String payload = prefix + "-1";
1905+
server.selector().send(new NetworkSend(channelId, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(payload.getBytes(StandardCharsets.UTF_8)))));
1906+
// allow reauthentication to complete
1907+
TestServerCallbackHandler.sem.release();
1908+
1909+
TestUtils.waitForCondition(() -> {
1910+
selector.poll(10L);
1911+
for (NetworkReceive receive : selector.completedReceives()) {
1912+
assertEquals(payload, new String(Utils.toArray(receive.payload()), StandardCharsets.UTF_8));
1913+
return true;
1914+
}
1915+
return false;
1916+
}, 5000, "Failed Receive the server send after reauthentication");
1917+
1918+
server.verifyReauthenticationMetrics(1, 0);
1919+
} finally {
1920+
closeClientConnectionIfNecessary();
1921+
}
1922+
}
1923+
18591924
private void verifySslClientAuthForSaslSslListener(boolean useListenerPrefix,
18601925
SslClientAuth configuredClientAuth) throws Exception {
18611926

@@ -2312,6 +2377,7 @@ public static class TestServerCallbackHandler extends PlainServerCallbackHandler
23122377
static final String USERNAME = "TestServerCallbackHandler-user";
23132378
static final String PASSWORD = "TestServerCallbackHandler-password";
23142379
private volatile boolean configured;
2380+
public static Semaphore sem = new Semaphore(1);
23152381

23162382
@Override
23172383
public void configure(Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
@@ -2325,7 +2391,14 @@ public void configure(Map<String, ?> configs, String mechanism, List<AppConfigur
23252391
protected boolean authenticate(String username, char[] password) {
23262392
if (!configured)
23272393
throw new IllegalStateException("Server callback handler not configured");
2328-
return USERNAME.equals(username) && new String(password).equals(PASSWORD);
2394+
try {
2395+
sem.acquire();
2396+
return USERNAME.equals(username) && new String(password).equals(PASSWORD);
2397+
} catch (InterruptedException e) {
2398+
throw new RuntimeException(e);
2399+
} finally {
2400+
sem.release();
2401+
}
23292402
}
23302403
}
23312404

0 commit comments

Comments
 (0)