1111import java .security .NoSuchAlgorithmException ;
1212import java .time .Duration ;
1313import java .util .HexFormat ;
14+ import java .util .Objects ;
1415import java .util .concurrent .CompletableFuture ;
16+ import java .util .concurrent .CompletionException ;
1517import java .util .concurrent .CompletionStage ;
1618import java .util .concurrent .ScheduledExecutorService ;
1719import java .util .concurrent .TimeUnit ;
@@ -63,13 +65,14 @@ public class OauthBearerValidationFilter
6365 SaslAuthenticateResponseFilter {
6466
6567 private static final Logger LOGGER = LoggerFactory .getLogger (OauthBearerValidationFilter .class );
66-
68+ private static final SaslAuthenticationException INVALID_SASL_STATE_EXCEPTION = new SaslAuthenticationException ( "invalid sasl state" );
6769 private final ScheduledExecutorService executorService ;
6870 private final BackoffStrategy strategy ;
6971 private final LoadingCache <String , AtomicInteger > rateLimiter ;
7072 private final OAuthBearerValidatorCallbackHandler oauthHandler ;
7173 private @ Nullable SaslServer saslServer ;
7274 private boolean validateAuthentication = true ;
75+ private @ Nullable String authorizationId ;
7376
7477 public OauthBearerValidationFilter (ScheduledExecutorService executorService , SharedOauthBearerValidationContext sharedContext ) {
7578 this .executorService = executorService ;
@@ -102,6 +105,7 @@ public CompletionStage<RequestFilterResult> onSaslHandshakeRequest(short apiVers
102105 }
103106 catch (SaslException e ) {
104107 LOGGER .debug ("SASL error : {}" , e .getMessage (), e );
108+ notifyThrowable (context , e );
105109 return context .requestFilterResultBuilder ()
106110 .shortCircuitResponse (new SaslHandshakeResponseData ().setErrorCode (UNKNOWN_SERVER_ERROR .code ()))
107111 .withCloseConnection ()
@@ -125,23 +129,31 @@ public CompletionStage<RequestFilterResult> onSaslAuthenticateRequest(short apiV
125129 .setErrorMessage ("Unexpected SASL request" )
126130 .setAuthBytes (request .authBytes ());
127131 LOGGER .debug ("SASL invalid state" );
132+ notifyThrowable (context , INVALID_SASL_STATE_EXCEPTION );
128133 return context .requestFilterResultBuilder ().shortCircuitResponse (failedResponse ).withCloseConnection ().completed ();
129134 }
130135 this .saslServer = null ;
131136
132137 return authenticate (server , request .authBytes ())
133138 .thenCompose (bytes -> context .forwardRequest (header , request ))
134139 .exceptionallyCompose (e -> {
135- if (e .getCause () instanceof SaslAuthenticationException ) {
140+ if (e .getCause () instanceof SaslAuthenticationException cause ) {
136141 SaslAuthenticateResponseData failedResponse = new SaslAuthenticateResponseData ()
137142 .setErrorCode (SASL_AUTHENTICATION_FAILED .code ())
138143 .setErrorMessage (e .getMessage ())
139144 .setAuthBytes (request .authBytes ());
140145 LOGGER .debug ("SASL Authentication failed : {}" , e .getMessage (), e );
146+ notifyThrowable (context , cause );
141147 return context .requestFilterResultBuilder ().shortCircuitResponse (failedResponse ).withCloseConnection ().completed ();
142148 }
143149 else {
144150 LOGGER .debug ("SASL error : {}" , e .getMessage (), e );
151+ if (e instanceof CompletionException ) {
152+ notifyThrowable (context , e .getCause ());
153+ }
154+ else {
155+ notifyThrowable (context , e );
156+ }
145157 return context .requestFilterResultBuilder ()
146158 .shortCircuitResponse (
147159 new SaslAuthenticateResponseData ()
@@ -154,11 +166,21 @@ public CompletionStage<RequestFilterResult> onSaslAuthenticateRequest(short apiV
154166 }
155167 }
156168
169+ private void notifyThrowable (FilterContext context , Throwable e ) {
170+ if (e instanceof Exception ex ) {
171+ context .clientSaslAuthenticationFailure (OAUTHBEARER_MECHANISM , authorizationId , ex );
172+ }
173+ else {
174+ context .clientSaslAuthenticationFailure (OAUTHBEARER_MECHANISM , authorizationId , new RuntimeException (e ));
175+ }
176+ }
177+
157178 @ Override
158179 public CompletionStage <ResponseFilterResult > onSaslAuthenticateResponse (short apiVersion , ResponseHeaderData header ,
159180 SaslAuthenticateResponseData response , FilterContext context ) {
160181 if (response .errorCode () == NONE .code ()) {
161182 this .validateAuthentication = false ;
183+ context .clientSaslAuthenticationSuccess (OAUTHBEARER_MECHANISM , Objects .requireNonNull (authorizationId ));
162184 }
163185 return context .forwardResponse (header , response );
164186 }
@@ -197,6 +219,7 @@ private byte[] doAuthenticate(SaslServer server, byte[] authBytes) throws SaslEx
197219 // at this step bytes would be a jsonResponseError from SASL server
198220 throw new SaslAuthenticationException ("SASL failed : " + new String (bytes , StandardCharsets .UTF_8 ));
199221 }
222+ this .authorizationId = server .getAuthorizationID ();
200223 return bytes ;
201224 }
202225 finally {
0 commit comments