Skip to content

Commit 2235ef3

Browse files
authored
feat(OAuthBearerValidatior) OauthBearerValidationFilter notify the context on auth outcome (kroxylicious#2526)
1 parent 9761058 commit 2235ef3

File tree

6 files changed

+224
-70
lines changed

6 files changed

+224
-70
lines changed

kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,15 @@
1414
import org.assertj.core.api.AbstractStringAssert;
1515
import org.assertj.core.api.Assertions;
1616
import org.assertj.core.api.InstanceOfAssertFactories;
17+
import org.assertj.core.api.ThrowingConsumer;
1718

19+
@SuppressWarnings("UnusedReturnValue")
1820
public class HeaderAssert extends AbstractAssert<HeaderAssert, Header> {
21+
22+
private static final String VALUE_SUFFIX = "value";
23+
private static final String KEY_SUFFIX = "key";
24+
public static final String DESCRIBED_AS_PATTERN = "%s %s";
25+
1926
protected HeaderAssert(Header header) {
2027
super(header, HeaderAssert.class);
2128
describedAs(header == null ? "null header" : "header");
@@ -25,16 +32,18 @@ public static HeaderAssert assertThat(Header actual) {
2532
return new HeaderAssert(actual);
2633
}
2734

35+
@SuppressWarnings("java:S1452")
2836
private AbstractStringAssert<?> key() {
2937
var existingDescription = descriptionText();
3038
return Assertions.assertThat(actual.key())
31-
.describedAs(existingDescription + " key");
39+
.describedAs(DESCRIBED_AS_PATTERN, existingDescription, KEY_SUFFIX);
3240
}
3341

42+
@SuppressWarnings("java:S1452")
3443
public AbstractByteArrayAssert<?> value() {
3544
var existingDescription = descriptionText();
3645
return Assertions.assertThat(actual.value())
37-
.describedAs(existingDescription + " value");
46+
.describedAs(DESCRIBED_AS_PATTERN, existingDescription, VALUE_SUFFIX);
3847
}
3948

4049
public HeaderAssert hasKeyEqualTo(String expected) {
@@ -47,18 +56,13 @@ public HeaderAssert hasValueEqualTo(String expected) {
4756
isNotNull().value().isNull();
4857
}
4958
else {
50-
String existingDescription = descriptionText();
51-
isNotNull().value()
52-
.asInstanceOf(InstanceOfAssertFactories.BYTE_ARRAY)
53-
.asString(StandardCharsets.UTF_8)
54-
.as(existingDescription + " value")
55-
.isEqualTo(expected);
59+
hasStringValueSatisfying(val -> Assertions.assertThat(val).isEqualTo(expected));
5660
}
5761
return this;
5862
}
5963

6064
public HeaderAssert hasValueEqualTo(byte[] expected) {
61-
isNotNull().value().isEqualTo(expected);
65+
hasByteValueSatisfying(val -> Assertions.assertThat(val).isEqualTo(expected));
6266
return this;
6367
}
6468

@@ -67,4 +71,25 @@ public HeaderAssert hasNullValue() {
6771
return this;
6872
}
6973

74+
public HeaderAssert hasStringValueSatisfying(ThrowingConsumer<String> assertion) {
75+
String existingDescription = descriptionText();
76+
isNotNull().value()
77+
.asInstanceOf(InstanceOfAssertFactories.BYTE_ARRAY)
78+
.asString(StandardCharsets.UTF_8)
79+
.as(DESCRIBED_AS_PATTERN, existingDescription, VALUE_SUFFIX)
80+
.satisfies(assertion::accept);
81+
82+
return this;
83+
}
84+
85+
public HeaderAssert hasByteValueSatisfying(ThrowingConsumer<byte[]> assertion) {
86+
String existingDescription = descriptionText();
87+
isNotNull().value()
88+
.asInstanceOf(InstanceOfAssertFactories.BYTE_ARRAY)
89+
.as(DESCRIBED_AS_PATTERN, existingDescription, VALUE_SUFFIX)
90+
.satisfies(assertion::accept);
91+
92+
return this;
93+
}
94+
7095
}

kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,31 @@ void testHeaderHasValueEqualTo() {
5454
assertThrowsIfHeaderNull(nullAssert -> nullAssert.hasValueEqualTo("any"));
5555
}
5656

57+
@Test
58+
void testHeaderHasByteValue() {
59+
byte[] expectedBytes = "abc".getBytes(StandardCharsets.UTF_8);
60+
RecordHeader nonNullValue = new RecordHeader("foo", expectedBytes);
61+
HeaderAssert nonNullValueAssert = KafkaAssertions.assertThat(nonNullValue);
62+
63+
nonNullValueAssert.hasValueEqualTo(expectedBytes);
64+
throwsAssertionErrorContaining(() -> nonNullValueAssert.hasByteValueSatisfying(val -> org.assertj.core.api.Assertions.assertThat(val).isEmpty()),
65+
"[header value]");
66+
nonNullValueAssert.hasByteValueSatisfying(val -> org.assertj.core.api.Assertions.assertThat(val).isEqualTo(expectedBytes));
67+
}
68+
69+
@Test
70+
void testHeaderHasStringValue() {
71+
String expectedStr = "abc";
72+
byte[] expectedBytes = expectedStr.getBytes(StandardCharsets.UTF_8);
73+
RecordHeader nonNullValue = new RecordHeader("foo", expectedStr.getBytes(StandardCharsets.UTF_8));
74+
HeaderAssert nonNullValueAssert = KafkaAssertions.assertThat(nonNullValue);
75+
76+
nonNullValueAssert.hasValueEqualTo(expectedBytes);
77+
throwsAssertionErrorContaining(() -> nonNullValueAssert.hasStringValueSatisfying(val -> org.assertj.core.api.Assertions.assertThat(val).isEmpty()),
78+
"[header value]");
79+
nonNullValueAssert.hasStringValueSatisfying(val -> org.assertj.core.api.Assertions.assertThat(val).isEqualTo(expectedStr));
80+
}
81+
5782
void assertThrowsIfHeaderNull(ThrowingConsumer<HeaderAssert> action) {
5883
HeaderAssert headerAssert = KafkaAssertions.assertThat((RecordHeader) null);
5984
throwsAssertionErrorContaining(() -> action.accept(headerAssert), "[null header]");

kroxylicious-filters/kroxylicious-oauthbearer-validation/src/main/java/io/kroxylicious/proxy/filter/oauthbearer/OauthBearerValidationFilter.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import java.security.NoSuchAlgorithmException;
1212
import java.time.Duration;
1313
import java.util.HexFormat;
14+
import java.util.Objects;
1415
import java.util.concurrent.CompletableFuture;
16+
import java.util.concurrent.CompletionException;
1517
import java.util.concurrent.CompletionStage;
1618
import java.util.concurrent.ScheduledExecutorService;
1719
import java.util.concurrent.TimeUnit;
@@ -63,13 +65,14 @@ public class OauthBearerValidationFilter
6365
SaslAuthenticateResponseFilter {
6466

6567
private static final Logger LOGGER = LoggerFactory.getLogger(OauthBearerValidationFilter.class);
66-
68+
private static final SaslAuthenticationException INVALID_SASL_STATE_EXCEPTION = new SaslAuthenticationException("invalid sasl state");
6769
private final ScheduledExecutorService executorService;
6870
private final BackoffStrategy strategy;
6971
private final LoadingCache<String, AtomicInteger> rateLimiter;
7072
private final OAuthBearerValidatorCallbackHandler oauthHandler;
7173
private @Nullable SaslServer saslServer;
7274
private boolean validateAuthentication = true;
75+
private @Nullable String authorizationId;
7376

7477
public OauthBearerValidationFilter(ScheduledExecutorService executorService, SharedOauthBearerValidationContext sharedContext) {
7578
this.executorService = executorService;
@@ -102,6 +105,7 @@ public CompletionStage<RequestFilterResult> onSaslHandshakeRequest(short apiVers
102105
}
103106
catch (SaslException e) {
104107
LOGGER.debug("SASL error : {}", e.getMessage(), e);
108+
notifyThrowable(context, e);
105109
return context.requestFilterResultBuilder()
106110
.shortCircuitResponse(new SaslHandshakeResponseData().setErrorCode(UNKNOWN_SERVER_ERROR.code()))
107111
.withCloseConnection()
@@ -125,23 +129,31 @@ public CompletionStage<RequestFilterResult> onSaslAuthenticateRequest(short apiV
125129
.setErrorMessage("Unexpected SASL request")
126130
.setAuthBytes(request.authBytes());
127131
LOGGER.debug("SASL invalid state");
132+
notifyThrowable(context, INVALID_SASL_STATE_EXCEPTION);
128133
return context.requestFilterResultBuilder().shortCircuitResponse(failedResponse).withCloseConnection().completed();
129134
}
130135
this.saslServer = null;
131136

132137
return authenticate(server, request.authBytes())
133138
.thenCompose(bytes -> context.forwardRequest(header, request))
134139
.exceptionallyCompose(e -> {
135-
if (e.getCause() instanceof SaslAuthenticationException) {
140+
if (e.getCause() instanceof SaslAuthenticationException cause) {
136141
SaslAuthenticateResponseData failedResponse = new SaslAuthenticateResponseData()
137142
.setErrorCode(SASL_AUTHENTICATION_FAILED.code())
138143
.setErrorMessage(e.getMessage())
139144
.setAuthBytes(request.authBytes());
140145
LOGGER.debug("SASL Authentication failed : {}", e.getMessage(), e);
146+
notifyThrowable(context, cause);
141147
return context.requestFilterResultBuilder().shortCircuitResponse(failedResponse).withCloseConnection().completed();
142148
}
143149
else {
144150
LOGGER.debug("SASL error : {}", e.getMessage(), e);
151+
if (e instanceof CompletionException) {
152+
notifyThrowable(context, e.getCause());
153+
}
154+
else {
155+
notifyThrowable(context, e);
156+
}
145157
return context.requestFilterResultBuilder()
146158
.shortCircuitResponse(
147159
new SaslAuthenticateResponseData()
@@ -154,11 +166,21 @@ public CompletionStage<RequestFilterResult> onSaslAuthenticateRequest(short apiV
154166
}
155167
}
156168

169+
private void notifyThrowable(FilterContext context, Throwable e) {
170+
if (e instanceof Exception ex) {
171+
context.clientSaslAuthenticationFailure(OAUTHBEARER_MECHANISM, authorizationId, ex);
172+
}
173+
else {
174+
context.clientSaslAuthenticationFailure(OAUTHBEARER_MECHANISM, authorizationId, new RuntimeException(e));
175+
}
176+
}
177+
157178
@Override
158179
public CompletionStage<ResponseFilterResult> onSaslAuthenticateResponse(short apiVersion, ResponseHeaderData header,
159180
SaslAuthenticateResponseData response, FilterContext context) {
160181
if (response.errorCode() == NONE.code()) {
161182
this.validateAuthentication = false;
183+
context.clientSaslAuthenticationSuccess(OAUTHBEARER_MECHANISM, Objects.requireNonNull(authorizationId));
162184
}
163185
return context.forwardResponse(header, response);
164186
}
@@ -197,6 +219,7 @@ private byte[] doAuthenticate(SaslServer server, byte[] authBytes) throws SaslEx
197219
// at this step bytes would be a jsonResponseError from SASL server
198220
throw new SaslAuthenticationException("SASL failed : " + new String(bytes, StandardCharsets.UTF_8));
199221
}
222+
this.authorizationId = server.getAuthorizationID();
200223
return bytes;
201224
}
202225
finally {

0 commit comments

Comments
 (0)