Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions activemq-jaas/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@
<artifactId>jasypt</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
<optional>true</optional>
</dependency>
</dependencies>

<profiles>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.jaas;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.Principal;
import java.text.ParseException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.FailedLoginException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.JWKSourceBuilder;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A JAAS LoginModule that authenticates users via OAuth2 JWT access tokens.
* <p>
* The client passes the JWT access token as the password in the connection info.
* The module validates the token signature using the JWKS endpoint and verifies
* standard claims (issuer, audience, expiration).
* <p>
* Configuration options (in login.config):
* <ul>
* <li>{@code oauth2.jwksUrl} (required) - URL to the JWKS endpoint for signature verification</li>
* <li>{@code oauth2.issuer} (required) - Expected token issuer (iss claim)</li>
* <li>{@code oauth2.audience} (optional) - Expected token audience (aud claim)</li>
* <li>{@code oauth2.usernameClaim} (optional, default: "sub") - JWT claim to use as username</li>
* <li>{@code oauth2.groupsClaim} (optional, default: "groups") - JWT claim containing group memberships</li>
* <li>{@code debug} (optional) - Enable debug logging</li>
* </ul>
* <p>
* Example login.config:
* <pre>
* activemq-oauth2 {
* org.apache.activemq.jaas.OAuth2LoginModule required
* oauth2.jwksUrl="https://idp.example.com/.well-known/jwks.json"
* oauth2.issuer="https://idp.example.com"
* oauth2.audience="activemq"
* oauth2.usernameClaim="preferred_username"
* oauth2.groupsClaim="roles";
* };
* </pre>
*/
public class OAuth2LoginModule implements LoginModule {

private static final Logger LOG = LoggerFactory.getLogger(OAuth2LoginModule.class);

static final String JWKS_URL_OPTION = "oauth2.jwksUrl";
static final String ISSUER_OPTION = "oauth2.issuer";
static final String AUDIENCE_OPTION = "oauth2.audience";
static final String USERNAME_CLAIM_OPTION = "oauth2.usernameClaim";
static final String GROUPS_CLAIM_OPTION = "oauth2.groupsClaim";

private static final String DEFAULT_USERNAME_CLAIM = "sub";
private static final String DEFAULT_GROUPS_CLAIM = "groups";

private Subject subject;
private CallbackHandler callbackHandler;
private boolean debug;

private String jwksUrl;
private String issuer;
private String audience;
private String usernameClaim;
private String groupsClaim;

private String user;
private final Set<Principal> principals = new HashSet<>();
private boolean succeeded;
private boolean commitSucceeded;

private ConfigurableJWTProcessor<SecurityContext> jwtProcessor;

@Override
public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) {
this.subject = subject;
this.callbackHandler = callbackHandler;
this.succeeded = false;
this.debug = Boolean.parseBoolean((String) options.get("debug"));

this.jwksUrl = (String) options.get(JWKS_URL_OPTION);
this.issuer = (String) options.get(ISSUER_OPTION);
this.audience = (String) options.get(AUDIENCE_OPTION);

String userClaim = (String) options.get(USERNAME_CLAIM_OPTION);
this.usernameClaim = (userClaim != null && !userClaim.isEmpty()) ? userClaim : DEFAULT_USERNAME_CLAIM;

String grpClaim = (String) options.get(GROUPS_CLAIM_OPTION);
this.groupsClaim = (grpClaim != null && !grpClaim.isEmpty()) ? grpClaim : DEFAULT_GROUPS_CLAIM;

if (debug) {
LOG.debug("OAuth2LoginModule initialized with jwksUrl={}, issuer={}, audience={}, usernameClaim={}, groupsClaim={}",
jwksUrl, issuer, audience, usernameClaim, groupsClaim);
}
}

@Override
public boolean login() throws LoginException {
if (jwksUrl == null || jwksUrl.isEmpty()) {
throw new LoginException("OAuth2 JWKS URL (oauth2.jwksUrl) is required");
}
if (issuer == null || issuer.isEmpty()) {
throw new LoginException("OAuth2 issuer (oauth2.issuer) is required");
}

String token = getToken();
if (token == null || token.isEmpty()) {
throw new FailedLoginException("No JWT token provided");
}

try {
JWTClaimsSet claims = validateToken(token);
user = claims.getStringClaim(usernameClaim);
if (user == null || user.isEmpty()) {
throw new FailedLoginException("JWT token does not contain the username claim: " + usernameClaim);
}

principals.add(new UserPrincipal(user));

List<String> groups = getGroupsFromClaims(claims);
if (groups != null) {
for (String group : groups) {
principals.add(new GroupPrincipal(group));
}
}

succeeded = true;
if (debug) {
LOG.debug("OAuth2 login succeeded for user={} with groups={}", user, groups);
}
} catch (FailedLoginException e) {
throw e;
} catch (Exception e) {
LoginException le = new FailedLoginException("JWT token validation failed: " + e.getMessage());
le.initCause(e);
throw le;
}

return succeeded;
}

@Override
public boolean commit() throws LoginException {
if (!succeeded) {
clear();
if (debug) {
LOG.debug("commit, result: false");
}
return false;
}

subject.getPrincipals().addAll(principals);
commitSucceeded = true;

if (debug) {
LOG.debug("commit, result: true");
}
return true;
}

@Override
public boolean abort() throws LoginException {
if (debug) {
LOG.debug("abort");
}
if (!succeeded) {
return false;
} else if (commitSucceeded) {
logout();
} else {
clear();
succeeded = false;
}
return true;
}

@Override
public boolean logout() throws LoginException {
subject.getPrincipals().removeAll(principals);
clear();
if (debug) {
LOG.debug("logout");
}
succeeded = false;
commitSucceeded = false;
return true;
}

private String getToken() throws LoginException {
// Try OAuth2TokenCallback first, then fall back to PasswordCallback
try {
OAuth2TokenCallback tokenCallback = new OAuth2TokenCallback();
callbackHandler.handle(new Callback[]{tokenCallback});
if (tokenCallback.getToken() != null) {
return tokenCallback.getToken();
}
} catch (UnsupportedCallbackException e) {
// OAuth2TokenCallback not supported, fall back to PasswordCallback
if (debug) {
LOG.debug("OAuth2TokenCallback not supported, falling back to PasswordCallback");
}
} catch (IOException e) {
throw new LoginException("Error retrieving OAuth2 token: " + e.getMessage());
}

// Fall back to PasswordCallback (token passed as password)
try {
PasswordCallback passwordCallback = new PasswordCallback("Token: ", false);
callbackHandler.handle(new Callback[]{passwordCallback});
char[] tokenChars = passwordCallback.getPassword();
if (tokenChars != null) {
return new String(tokenChars);
}
} catch (IOException | UnsupportedCallbackException e) {
throw new LoginException("Error retrieving token from PasswordCallback: " + e.getMessage());
}

return null;
}

JWTClaimsSet validateToken(String token) throws LoginException {
try {
ConfigurableJWTProcessor<SecurityContext> processor = getJWTProcessor();
return processor.process(token, null);
} catch (ParseException e) {
throw new FailedLoginException("Invalid JWT format: " + e.getMessage());
} catch (BadJOSEException e) {
throw new FailedLoginException("JWT validation failed: " + e.getMessage());
} catch (JOSEException e) {
throw new FailedLoginException("JWT processing error: " + e.getMessage());
}
}

private ConfigurableJWTProcessor<SecurityContext> getJWTProcessor() throws LoginException {
if (jwtProcessor != null) {
return jwtProcessor;
}

try {
URL jwksEndpoint = new URL(jwksUrl);
JWKSource<SecurityContext> keySource = JWKSourceBuilder
.create(jwksEndpoint)
.retrying(true)
.build();

JWSKeySelector<SecurityContext> keySelector = new JWSVerificationKeySelector<>(
JWSAlgorithm.Family.RSA, keySource);

ConfigurableJWTProcessor<SecurityContext> processor = new DefaultJWTProcessor<>();
processor.setJWSKeySelector(keySelector);

// Build the claims verifier with issuer and optional audience
JWTClaimsSet.Builder exactMatchBuilder = new JWTClaimsSet.Builder()
.issuer(issuer);

Set<String> requiredClaims = new HashSet<>();
requiredClaims.add("sub");
requiredClaims.add("iss");
requiredClaims.add("exp");

if (audience != null && !audience.isEmpty()) {
exactMatchBuilder.audience(audience);
requiredClaims.add("aud");
}

processor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<>(
exactMatchBuilder.build(),
requiredClaims));

jwtProcessor = processor;
return jwtProcessor;
} catch (MalformedURLException e) {
throw new LoginException("Invalid JWKS URL: " + jwksUrl);
}
}

@SuppressWarnings("unchecked")
private List<String> getGroupsFromClaims(JWTClaimsSet claims) {
try {
Object groupsValue = claims.getClaim(groupsClaim);
if (groupsValue instanceof List) {
return (List<String>) groupsValue;
} else if (groupsValue instanceof String) {
return List.of(((String) groupsValue).split(","));
}
} catch (Exception e) {
if (debug) {
LOG.debug("Could not extract groups from claim '{}': {}", groupsClaim, e.getMessage());
}
}
return null;
}

private void clear() {
user = null;
principals.clear();
}

// Visible for testing
void setJwtProcessor(ConfigurableJWTProcessor<SecurityContext> jwtProcessor) {
this.jwtProcessor = jwtProcessor;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.jaas;

import javax.security.auth.callback.Callback;

/**
* A JAAS Callback for passing an OAuth2 JWT token.
*/
public class OAuth2TokenCallback implements Callback {

private String token;

public String getToken() {
return token;
}

public void setToken(String token) {
this.token = token;
}
}
Loading
Loading