Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
public record Request(ApiKeys apiKeys,
short apiVersion,
String clientIdHeader,
ApiMessage message) {

ApiMessage message,
short responseApiVersion) {
public Request(ApiKeys apiKeys,
short apiVersion,
String clientIdHeader,
ApiMessage message) {
this(apiKeys, apiVersion, clientIdHeader, message, apiVersion);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ public CorrelationManager() {
/**
* Allocate and return a correlation id for an outgoing request to the broker.
*
* @param apiKey The API key.
* @param apiVersion The API version.
* @param correlationId The request's correlation id.
* @param apiKey The API key.
* @param apiVersion The API version.
* @param correlationId The request's correlation id.
* @param responseFuture The future to complete with the response
* @param responseApiVersion
*/
public void putBrokerRequest(short apiKey,
short apiVersion,
int correlationId, CompletableFuture<SequencedResponse> responseFuture) {
Correlation existing = this.brokerRequests.put(correlationId, new Correlation(apiKey, apiVersion, responseFuture));
int correlationId, CompletableFuture<SequencedResponse> responseFuture, short responseApiVersion) {
Correlation existing = this.brokerRequests.put(correlationId, new Correlation(apiKey, apiVersion, responseFuture, responseApiVersion));
if (existing != null) {
LOGGER.error("Duplicate upstream correlation id {}", correlationId);
}
Expand Down Expand Up @@ -72,13 +73,15 @@ public void onChannelClose() {
*/
public record Correlation(short apiKey,
short apiVersion,
CompletableFuture<SequencedResponse> responseFuture) {
CompletableFuture<SequencedResponse> responseFuture,
short responseApiVersion) {

@Override
public String toString() {
return "Correlation(" +
"apiKey=" + ApiKeys.forId(apiKey) +
", apiVersion=" + apiVersion +
", responseApiVersion=" + responseApiVersion +
')';
}

Expand All @@ -91,12 +94,12 @@ public boolean equals(Object o) {
return false;
}
Correlation that = (Correlation) o;
return apiKey == that.apiKey && apiVersion == that.apiVersion;
return apiKey == that.apiKey && apiVersion == that.apiVersion && responseApiVersion == that.responseApiVersion;
}

@Override
public int hashCode() {
return Objects.hash(apiKey, apiVersion);
return Objects.hash(apiKey, apiVersion, responseApiVersion);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private static DecodedRequestFrame<?> toApiRequest(Request request) {
var header = new RequestHeaderData().setRequestApiKey(messageType.apiKey()).setRequestApiVersion(request.apiVersion());
header.setClientId(request.clientIdHeader());
header.setCorrelationId(correlationId.incrementAndGet());
return new DecodedRequestFrame<>(header.requestApiVersion(), header.correlationId(), header, request.message());
return new DecodedRequestFrame<>(header.requestApiVersion(), header.correlationId(), header, request.message(), request.responseApiVersion());
}

// TODO return a Response class with jsonObject() and frame() methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,24 @@ public class DecodedRequestFrame<B extends ApiMessage>
implements RequestFrame {

private final CompletableFuture<SequencedResponse> responseFuture = new CompletableFuture<>();
private final short responseApiVersion;

/**
* Create a decoded request frame
*
* @param apiVersion apiVersion
* @param correlationId correlationId
* @param header header
* @param body body
* @param responseApiVersion
*/
public DecodedRequestFrame(short apiVersion,
int correlationId,
RequestHeaderData header,
B body) {
B body,
short responseApiVersion) {
super(apiVersion, correlationId, header, body);
this.responseApiVersion = responseApiVersion;
}

@Override
Expand All @@ -54,4 +59,9 @@ public CompletableFuture<SequencedResponse> getResponseFuture() {
public boolean hasResponse() {
return !(body instanceof ProduceRequest pr && pr.acks() == 0);
}

@Override
public short responseApiVersion() {
return responseApiVersion;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected Frame decodeHeaderAndBody(ChannelHandlerContext ctx, ByteBuf in, final
log().trace("{}: body {}", ctx, body);
}

frame = new DecodedRequestFrame<>(apiVersion, correlationId, header, body);
frame = new DecodedRequestFrame<>(apiVersion, correlationId, header, body, apiVersion);
if (log().isTraceEnabled()) {
log().trace("{}: frame {}", ctx, frame);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected Logger log() {
protected void encode(ChannelHandlerContext ctx, RequestFrame frame, ByteBuf out) throws Exception {
super.encode(ctx, frame, out);
if (frame.hasResponse()) {
correlationManager.putBrokerRequest(frame.apiKey().id, frame.apiVersion(), frame.correlationId(), frame.getResponseFuture());
correlationManager.putBrokerRequest(frame.apiKey().id, frame.apiVersion(), frame.correlationId(), frame.getResponseFuture(), frame.responseApiVersion());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,28 @@ protected Frame decodeHeaderAndBody(ChannelHandlerContext ctx, ByteBuf in, int l
else if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{}: Recovered correlation {} for upstream correlation id {}", ctx, correlation, correlationId);
}

final DecodedResponseFrame<?> frame;
ApiKeys apiKey = ApiKeys.forId(correlation.apiKey());
short apiVersion = correlation.apiVersion();
var accessor = new ByteBufAccessorImpl(in);
short headerVersion = apiKey.responseHeaderVersion(apiVersion);
log().trace("{}: Header version: {}", ctx, headerVersion);
ResponseHeaderData header = readHeader(headerVersion, accessor);
log().trace("{}: Header: {}", ctx, header);
ApiMessage body = BodyDecoder.decodeResponse(apiKey, apiVersion, accessor);
log().trace("{}: Body: {}", ctx, body);
frame = new DecodedResponseFrame<>(apiVersion, correlationId, header, body);
correlation.responseFuture().complete(new SequencedResponse(frame, i++));
return frame;
try {
final DecodedResponseFrame<?> frame;
ApiKeys apiKey = ApiKeys.forId(correlation.apiKey());
short apiVersion = correlation.responseApiVersion();
var accessor = new ByteBufAccessorImpl(in);
short headerVersion = apiKey.responseHeaderVersion(apiVersion);
log().trace("{}: Header version: {}", ctx, headerVersion);
ResponseHeaderData header = readHeader(headerVersion, accessor);
log().trace("{}: Header: {}", ctx, header);
ApiMessage body = BodyDecoder.decodeResponse(apiKey, apiVersion, accessor);
log().trace("{}: Body: {}", ctx, body);
frame = new DecodedResponseFrame<>(apiVersion, correlationId, header, body);
if (in.readableBytes() != 0) {
throw new RuntimeException("Unread bytes remaining in frame, potentially response api version differs from expectation");
}
correlation.responseFuture().complete(new SequencedResponse(frame, i++));
return frame;
}
catch (Exception e) {
correlation.responseFuture().completeExceptionally(e);
throw e;
}
}

private ResponseHeaderData readHeader(short headerVersion, Readable accessor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,8 @@ default boolean hasResponse() {
*/
short apiVersion();

default short responseApiVersion() {
return apiVersion();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.netty.channel.ChannelInboundHandlerAdapter;

import io.kroxylicious.test.Request;
import io.kroxylicious.test.ResponsePayload;
import io.kroxylicious.test.codec.DecodedRequestFrame;

/**
Expand All @@ -51,11 +52,12 @@ private record ConditionalMockResponse(Matcher<Request> matcher, Action action,

/**
* Create mockhandler with initial message to serve
* @param message message to respond with, nullable
* @param payload payload to respond with, nullable
*/
public MockHandler(ApiMessage message) {
if (message != null) {
setMockResponseForApiKey(ApiKeys.forId(message.apiKey()), message);
public MockHandler(ResponsePayload payload) {
if (payload != null && payload.message() != null) {
ApiMessage message = payload.message();
setMockResponseForApiKey(ApiKeys.forId(message.apiKey()), message, payload.responseApiVersion());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public int start(int port, ResponsePayload response, SslContext serverSslContext
final EventGroupConfig eventGroupConfig = EventGroupConfig.create();
bossGroup = eventGroupConfig.newBossGroup();
workerGroup = eventGroupConfig.newWorkerGroup();
serverHandler = new MockHandler(response == null ? null : response.message());
serverHandler = new MockHandler(response);
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(eventGroupConfig.serverChannelClass())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,27 +73,7 @@ public class BodyDecoder {
return switch (apiKey) {
<#list messageSpecs as messageSpec>
<#if messageSpec.type?lower_case == 'response'>
<#if messageSpec.name == 'ApiVersionsResponse'>
case ${retrieveApiKey(messageSpec)} -> {
// KIP-511 when the client receives an unsupported version for the ApiVersionResponse, it fails back to version 0
// Use the same algorithm as https://github.com/apache/kafka/blob/a41c10fd49841381b5207c184a385622094ed440/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsResponse.java#L90-L106
int prev = accessor.readerIndex();
try {
yield new ${messageSpec.name}Data(accessor, apiVersion);
}
catch (RuntimeException e) {
accessor.readerIndex(prev);
if (apiVersion != 0) {
yield new ${messageSpec.name}Data(accessor, (short) 0);
}
else {
throw e;
}
}
}
<#else>
case ${retrieveApiKey(messageSpec)} -> new ${messageSpec.name}Data(accessor, apiVersion);
</#if>
case ${retrieveApiKey(messageSpec)} -> new ${messageSpec.name}Data(accessor, apiVersion);
</#if>
</#list>
default -> throw new IllegalArgumentException("Unsupported RPC " + apiKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ void testPendingFuturesCompletedExceptionallyOnChannelClose() {
CorrelationManager correlationManager = new CorrelationManager();
CompletableFuture<SequencedResponse> responseFuture = new CompletableFuture<>();
CompletableFuture<SequencedResponse> responseFuture2 = new CompletableFuture<>();
correlationManager.putBrokerRequest((short) 1, (short) 1, 1, responseFuture);
correlationManager.putBrokerRequest((short) 1, (short) 1, 2, responseFuture2);
correlationManager.putBrokerRequest((short) 1, (short) 1, 1, responseFuture, (short) 1);
correlationManager.putBrokerRequest((short) 1, (short) 1, 2, responseFuture2, (short) 1);

// when
correlationManager.onChannelClose();
Expand All @@ -37,8 +37,8 @@ void testNullFutureToleratedOnChannelClose() {
// given
CorrelationManager correlationManager = new CorrelationManager();
CompletableFuture<SequencedResponse> responseFuture = new CompletableFuture<>();
correlationManager.putBrokerRequest((short) 1, (short) 1, 1, responseFuture);
correlationManager.putBrokerRequest((short) 1, (short) 1, 2, null);
correlationManager.putBrokerRequest((short) 1, (short) 1, 1, responseFuture, (short) 1);
correlationManager.putBrokerRequest((short) 1, (short) 1, 2, null, (short) 1);

// when
correlationManager.onChannelClose();
Expand All @@ -53,7 +53,7 @@ void testCorrelationRetrievableOnceOnly() {
int correlationId = 1;
CorrelationManager correlationManager = new CorrelationManager();
CompletableFuture<SequencedResponse> responseFuture = new CompletableFuture<>();
correlationManager.putBrokerRequest((short) 1, (short) 1, correlationId, responseFuture);
correlationManager.putBrokerRequest((short) 1, (short) 1, correlationId, responseFuture, (short) 1);
CorrelationManager.Correlation brokerCorrelation = correlationManager.getBrokerCorrelation(correlationId);
assertThat(brokerCorrelation).isNotNull();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ void testClientCanSendOpaqueFrame() {

// brokers can respond with a v0 response if they do not support the ApiVersions request version, see KIP-511
@Test
void testClientCanTolerateV0ApiVersionsResponseToHigherRequestVersion() {
void testClientCanHandleResponseApiVersionDifferentFromRequestApiVersion() {
ApiVersionsResponseData message = new ApiVersionsResponseData();
message.setErrorCode(Errors.UNSUPPORTED_VERSION.code());
ResponsePayload v0Payload = new ResponsePayload(ApiKeys.API_VERSIONS, (short) 0, message);
try (var mockServer = MockServer.startOnRandomPort(v0Payload);
var kafkaClient = new KafkaClient("localhost", mockServer.port())) {
CompletableFuture<Response> future = kafkaClient.get(new Request(ApiKeys.API_VERSIONS, (short) 0, "client", new ApiVersionsRequestData()));
CompletableFuture<Response> future = kafkaClient.get(new Request(ApiKeys.API_VERSIONS, (short) 3, "client", new ApiVersionsRequestData(), (short) 0));
assertThat(future).succeedsWithin(10, TimeUnit.SECONDS).satisfies(response -> {
assertThat(response.payload().message()).isInstanceOfSatisfying(ApiVersionsResponseData.class, apiVersionsRequestData -> {
assertThat(apiVersionsRequestData.errorCode()).isEqualTo(Errors.UNSUPPORTED_VERSION.code());
Expand All @@ -103,6 +103,34 @@ void testClientCanTolerateV0ApiVersionsResponseToHigherRequestVersion() {
}
}

@Test
void unreadBytesAfterFrameDecodeThrowsException() {
ApiVersionsResponseData message = new ApiVersionsResponseData();
message.setErrorCode(Errors.UNSUPPORTED_VERSION.code());
message.setThrottleTimeMs(22);
ResponsePayload v0Payload = new ResponsePayload(ApiKeys.API_VERSIONS, (short) 1, message);
try (var mockServer = MockServer.startOnRandomPort(v0Payload);
var kafkaClient = new KafkaClient("localhost", mockServer.port())) {
CompletableFuture<Response> future = kafkaClient.get(new Request(ApiKeys.API_VERSIONS, (short) 3, "client", new ApiVersionsRequestData(), (short) 0));
// fails to decode v0 response because client encodes v1, which is backwards compatible but emits additional bytes. We want to be as sure as we can be that it was a v0 response.
assertThat(future).failsWithin(10, TimeUnit.SECONDS).withThrowableThat()
.withMessageContaining("Unread bytes remaining in frame, potentially response api version differs from expectation");
}
}

@Test
void unexpectedResponseFormatTriggersFailure() {
ApiVersionsResponseData message = new ApiVersionsResponseData();
message.setErrorCode(Errors.UNSUPPORTED_VERSION.code());
ResponsePayload v0Payload = new ResponsePayload(ApiKeys.API_VERSIONS, (short) 0, message);
try (var mockServer = MockServer.startOnRandomPort(v0Payload);
var kafkaClient = new KafkaClient("localhost", mockServer.port())) {
CompletableFuture<Response> future = kafkaClient.get(new Request(ApiKeys.API_VERSIONS, (short) 3, "client", new ApiVersionsRequestData()));
// fails to decode v0 response because client expects v3
assertThat(future).failsWithin(10, TimeUnit.SECONDS).withThrowableThat().withMessageContaining("non-nullable field apiKeys was serialized as null");
}
}

@Test
void shouldWorkWithTls() throws Exception {
shouldWorkWithTls(SslContextBuilder.forClient()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void shouldHandleVersionZeroErrorResponseWhenKroxyliciousIsAheadOfBroker() {
var client = tester.simpleTestClient()) {
short brokerMaxVersion = (short) (ApiKeys.API_VERSIONS.latestVersion() - 1);
givenMockRespondsWithDowngradedV0ApiVersionsResponse(tester, ApiKeys.API_VERSIONS, ApiKeys.API_VERSIONS.oldestVersion(), brokerMaxVersion);
Response response = whenGetApiVersionsFromKroxylicious(client);
Response response = whenGetApiVersionsFromKroxylicious(client, (short) 3, (short) 0);

assertKroxyliciousResponseOffersApiVersionsForApiKey(response, ApiKeys.API_VERSIONS, ApiKeys.API_VERSIONS.oldestVersion(),
brokerMaxVersion, Errors.UNSUPPORTED_VERSION.code());
Expand Down Expand Up @@ -179,17 +179,20 @@ private static void givenMockRespondsWithDowngradedV0ApiVersionsResponse(MockSer
version.setApiKey(keys.id).setMinVersion(minVersion).setMaxVersion(maxVersion);
mockResponse.apiKeys().add(version);
mockResponse.setErrorCode(Errors.UNSUPPORTED_VERSION.code());
tester.addMockResponseForApiKey(new ResponsePayload(ApiKeys.API_VERSIONS, (short) 3, mockResponse));
tester.addMockResponseForApiKey(new ResponsePayload(ApiKeys.API_VERSIONS, (short) 0, mockResponse));
}

private static Response whenGetApiVersionsFromKroxylicious(KafkaClient client) {
return client.getSync(new Request(ApiKeys.API_VERSIONS, (short) 3, "client", new ApiVersionsRequestData()));
return whenGetApiVersionsFromKroxylicious(client, (short) 3, (short) 3);
}

private static Response whenGetApiVersionsFromKroxylicious(KafkaClient client, short requestApiVersion, short responseApiVersion) {
return client.getSync(new Request(ApiKeys.API_VERSIONS, requestApiVersion, "client", new ApiVersionsRequestData(), responseApiVersion));
}

private static void assertKroxyliciousResponseOffersApiVersionsForApiKey(Response response, ApiKeys apiKeys, short minVersion, short maxVersion, short expected) {
ResponsePayload payload = response.payload();
assertEquals(ApiKeys.API_VERSIONS, payload.apiKeys());
assertEquals((short) 3, payload.apiVersion());
ApiVersionsResponseData message = (ApiVersionsResponseData) payload.message();
assertThat(message.errorCode()).isEqualTo(expected);
assertThat(message.apiKeys())
Expand Down
Loading
Loading