diff --git a/docs/changelog/128176.yaml b/docs/changelog/128176.yaml new file mode 100644 index 0000000000000..2cf76c4513772 --- /dev/null +++ b/docs/changelog/128176.yaml @@ -0,0 +1,5 @@ +pr: 128176 +summary: Implement SAML custom attributes support for Identity Provider +area: Authentication +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 57ba6de0b973c..440aa8cacc903 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -273,6 +273,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED = def(9_084_0_00); public static final TransportVersion ESQL_LIMIT_ROW_SIZE = def(9_085_0_00); public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY = def(9_086_0_00); + public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES = def(9_087_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/identity-provider/qa/idp-rest-tests/src/javaRestTest/java/org/elasticsearch/xpack/idp/IdentityProviderAuthenticationIT.java b/x-pack/plugin/identity-provider/qa/idp-rest-tests/src/javaRestTest/java/org/elasticsearch/xpack/idp/IdentityProviderAuthenticationIT.java index c065e8d7e1d12..7f3b39fb75d50 100644 --- a/x-pack/plugin/identity-provider/qa/idp-rest-tests/src/javaRestTest/java/org/elasticsearch/xpack/idp/IdentityProviderAuthenticationIT.java +++ b/x-pack/plugin/identity-provider/qa/idp-rest-tests/src/javaRestTest/java/org/elasticsearch/xpack/idp/IdentityProviderAuthenticationIT.java @@ -18,23 +18,38 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ObjectPath; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.security.action.saml.SamlPrepareAuthenticationResponse; import org.elasticsearch.xpack.idp.saml.sp.SamlServiceProviderIndex; import org.junit.Before; +import org.w3c.dom.Document; +import org.w3c.dom.Element; +import org.w3c.dom.NodeList; +import org.xml.sax.InputSource; import java.io.IOException; +import java.io.StringReader; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Set; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.xpath.XPath; +import javax.xml.xpath.XPathConstants; +import javax.xml.xpath.XPathFactory; + import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; public class IdentityProviderAuthenticationIT extends IdpRestTestCase { @@ -74,6 +89,81 @@ public void testRegistrationAndIdpInitiatedSso() throws Exception { authenticateWithSamlResponse(samlResponse, null); } + public void testCustomAttributesInIdpInitiatedSso() throws Exception { + final Map request = Map.ofEntries( + Map.entry("name", "Test SP With Custom Attributes"), + Map.entry("acs", SP_ACS), + Map.entry("privileges", Map.ofEntries(Map.entry("resource", SP_ENTITY_ID), Map.entry("roles", List.of("sso:(\\w+)")))), + Map.entry( + "attributes", + Map.ofEntries( + Map.entry("principal", "https://idp.test.es.elasticsearch.org/attribute/principal"), + Map.entry("name", "https://idp.test.es.elasticsearch.org/attribute/name"), + Map.entry("email", "https://idp.test.es.elasticsearch.org/attribute/email"), + Map.entry("roles", "https://idp.test.es.elasticsearch.org/attribute/roles") + ) + ) + ); + final SamlServiceProviderIndex.DocumentVersion docVersion = createServiceProvider(SP_ENTITY_ID, request); + checkIndexDoc(docVersion); + ensureGreen(SamlServiceProviderIndex.INDEX_NAME); + + // Create custom attributes + Map> attributesMap = Map.of("department", List.of("engineering", "product"), "region", List.of("APJ")); + + // Generate SAML response with custom attributes + final String samlResponse = generateSamlResponseWithAttributes(SP_ENTITY_ID, SP_ACS, null, attributesMap); + + // Parse XML directly from samlResponse (it's not base64 encoded at this point) + DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); + factory.setNamespaceAware(true); // Required for XPath + DocumentBuilder builder = factory.newDocumentBuilder(); + Document document = builder.parse(new InputSource(new StringReader(samlResponse))); + + // Create XPath evaluator + XPathFactory xPathFactory = XPathFactory.newInstance(); + XPath xpath = xPathFactory.newXPath(); + + // Validate SAML Response structure + Element responseElement = (Element) xpath.evaluate("//*[local-name()='Response']", document, XPathConstants.NODE); + assertThat("SAML Response element should exist", responseElement, notNullValue()); + + Element assertionElement = (Element) xpath.evaluate("//*[local-name()='Assertion']", document, XPathConstants.NODE); + assertThat("SAML Assertion element should exist", assertionElement, notNullValue()); + + // Validate department attribute + NodeList departmentAttributes = (NodeList) xpath.evaluate( + "//*[local-name()='Attribute' and @Name='department']/*[local-name()='AttributeValue']", + document, + XPathConstants.NODESET + ); + + assertThat("Should have two values for department attribute", departmentAttributes.getLength(), is(2)); + + // Verify department values + List departmentValues = new ArrayList<>(); + for (int i = 0; i < departmentAttributes.getLength(); i++) { + departmentValues.add(departmentAttributes.item(i).getTextContent()); + } + assertThat( + "Department attribute should contain 'engineering' and 'product'", + departmentValues, + containsInAnyOrder("engineering", "product") + ); + + // Validate region attribute + NodeList regionAttributes = (NodeList) xpath.evaluate( + "//*[local-name()='Attribute' and @Name='region']/*[local-name()='AttributeValue']", + document, + XPathConstants.NODESET + ); + + assertThat("Should have one value for region attribute", regionAttributes.getLength(), is(1)); + assertThat("Region attribute should contain 'APJ'", regionAttributes.item(0).getTextContent(), equalTo("APJ")); + + authenticateWithSamlResponse(samlResponse, null); + } + public void testRegistrationAndSpInitiatedSso() throws Exception { final Map request = Map.ofEntries( Map.entry("name", "Test SP"), @@ -125,17 +215,37 @@ private SamlPrepareAuthenticationResponse generateSamlAuthnRequest(String realmN } } - private String generateSamlResponse(String entityId, String acs, @Nullable Map authnState) throws Exception { + private String generateSamlResponse(String entityId, String acs, @Nullable Map authnState) throws IOException { + return generateSamlResponseWithAttributes(entityId, acs, authnState, null); + } + + private String generateSamlResponseWithAttributes( + String entityId, + String acs, + @Nullable Map authnState, + @Nullable Map> attributes + ) throws IOException { final Request request = new Request("POST", "/_idp/saml/init"); - if (authnState != null && authnState.isEmpty() == false) { - request.setJsonEntity(Strings.format(""" - {"entity_id":"%s", "acs":"%s","authn_state":%s} - """, entityId, acs, Strings.toString(JsonXContent.contentBuilder().map(authnState)))); - } else { - request.setJsonEntity(Strings.format(""" - {"entity_id":"%s", "acs":"%s"} - """, entityId, acs)); + + XContentBuilder builder = JsonXContent.contentBuilder(); + builder.startObject(); + builder.field("entity_id", entityId); + builder.field("acs", acs); + + if (authnState != null) { + builder.field("authn_state"); + builder.map(authnState); + } + + if (attributes != null) { + builder.field("attributes"); + builder.map(attributes); } + + builder.endObject(); + String jsonEntity = Strings.toString(builder); + + request.setJsonEntity(jsonEntity); request.setOptions( RequestOptions.DEFAULT.toBuilder() .addHeader("es-secondary-authorization", basicAuthHeaderValue("idp_user", new SecureString("idp-password".toCharArray()))) diff --git a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequest.java b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequest.java index 6070b247093e1..b93616f54fb3a 100644 --- a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequest.java +++ b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequest.java @@ -6,12 +6,14 @@ */ package org.elasticsearch.xpack.idp.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.LegacyActionRequest; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.idp.saml.support.SamlAuthenticationState; +import org.elasticsearch.xpack.idp.saml.support.SamlInitiateSingleSignOnAttributes; import java.io.IOException; @@ -22,12 +24,16 @@ public class SamlInitiateSingleSignOnRequest extends LegacyActionRequest { private String spEntityId; private String assertionConsumerService; private SamlAuthenticationState samlAuthenticationState; + private SamlInitiateSingleSignOnAttributes attributes; public SamlInitiateSingleSignOnRequest(StreamInput in) throws IOException { super(in); spEntityId = in.readString(); assertionConsumerService = in.readString(); samlAuthenticationState = in.readOptionalWriteable(SamlAuthenticationState::new); + if (in.getTransportVersion().onOrAfter(TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES)) { + attributes = in.readOptionalWriteable(SamlInitiateSingleSignOnAttributes::new); + } } public SamlInitiateSingleSignOnRequest() {} @@ -41,6 +47,17 @@ public ActionRequestValidationException validate() { if (Strings.isNullOrEmpty(assertionConsumerService)) { validationException = addValidationError("acs is missing", validationException); } + + // Validate attributes if present + if (attributes != null) { + ActionRequestValidationException attributesValidationException = attributes.validate(); + if (attributesValidationException != null) { + for (String error : attributesValidationException.validationErrors()) { + validationException = addValidationError(error, validationException); + } + } + } + return validationException; } @@ -68,17 +85,38 @@ public void setSamlAuthenticationState(SamlAuthenticationState samlAuthenticatio this.samlAuthenticationState = samlAuthenticationState; } + public SamlInitiateSingleSignOnAttributes getAttributes() { + return attributes; + } + + public void setAttributes(SamlInitiateSingleSignOnAttributes attributes) { + this.attributes = attributes; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(spEntityId); out.writeString(assertionConsumerService); out.writeOptionalWriteable(samlAuthenticationState); + if (out.getTransportVersion().onOrAfter(TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES)) { + out.writeOptionalWriteable(attributes); + } } @Override public String toString() { - return getClass().getSimpleName() + "{spEntityId='" + spEntityId + "', acs='" + assertionConsumerService + "'}"; + return getClass().getSimpleName() + + "{" + + "spEntityId='" + + spEntityId + + "', " + + "acs='" + + assertionConsumerService + + "', " + + "attributes='" + + attributes + + "'}"; } } diff --git a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/TransportSamlInitiateSingleSignOnAction.java b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/TransportSamlInitiateSingleSignOnAction.java index 07d71481326d4..ea75e68506773 100644 --- a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/TransportSamlInitiateSingleSignOnAction.java +++ b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/action/TransportSamlInitiateSingleSignOnAction.java @@ -139,7 +139,7 @@ protected void doExecute( identityProvider ); try { - final Response response = builder.build(user, authenticationState); + final Response response = builder.build(user, authenticationState, request.getAttributes()); listener.onResponse( new SamlInitiateSingleSignOnResponse( user.getServiceProvider().getEntityId(), diff --git a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilder.java b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilder.java index 5d8cbf2607338..301ada3561299 100644 --- a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilder.java +++ b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilder.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.idp.saml.support.SamlAuthenticationState; import org.elasticsearch.xpack.idp.saml.support.SamlFactory; import org.elasticsearch.xpack.idp.saml.support.SamlInit; +import org.elasticsearch.xpack.idp.saml.support.SamlInitiateSingleSignOnAttributes; import org.elasticsearch.xpack.idp.saml.support.SamlObjectSigner; import org.opensaml.core.xml.schema.XSString; import org.opensaml.saml.saml2.core.Assertion; @@ -44,6 +45,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; import static org.opensaml.saml.saml2.core.NameIDType.TRANSIENT; @@ -66,7 +68,30 @@ public SuccessfulAuthenticationResponseMessageBuilder(SamlFactory samlFactory, C this.idp = idp; } + /** + * Builds and signs a SAML Response Message with a single assertion for the provided user + * + * @param user The user who is authenticated (actually a combination of user+sp) + * @param authnState The authentication state as presented in the SAML request (or {@code null}) + * @return A SAML Response + */ public Response build(UserServiceAuthentication user, @Nullable SamlAuthenticationState authnState) { + return build(user, authnState, null); + } + + /** + * Builds and signs a SAML Response Message with a single assertion for the provided user + * + * @param user The user who is authenticated (actually a combination of user+sp) + * @param authnState The authentication state as presented in the SAML request (or {@code null}) + * @param customAttributes Optional custom attributes to include in the response (or {@code null}) + * @return A SAML Response + */ + public Response build( + UserServiceAuthentication user, + @Nullable SamlAuthenticationState authnState, + @Nullable SamlInitiateSingleSignOnAttributes customAttributes + ) { logger.debug("Building success response for [{}] from [{}]", user, authnState); final Instant now = clock.instant(); final SamlServiceProvider serviceProvider = user.getServiceProvider(); @@ -87,10 +112,13 @@ public Response build(UserServiceAuthentication user, @Nullable SamlAuthenticati assertion.setIssueInstant(now); assertion.setConditions(buildConditions(now, serviceProvider)); assertion.setSubject(buildSubject(now, user, authnState)); - assertion.getAuthnStatements().add(buildAuthnStatement(now, user)); - final AttributeStatement attributes = buildAttributes(user); - if (attributes != null) { - assertion.getAttributeStatements().add(attributes); + + final AuthnStatement authnStatement = buildAuthnStatement(now, user); + assertion.getAuthnStatements().add(authnStatement); + + final AttributeStatement attributeStatement = buildAttributes(user, customAttributes); + if (attributeStatement != null) { + assertion.getAttributeStatements().add(attributeStatement); } response.getAssertions().add(assertion); return sign(response); @@ -179,7 +207,10 @@ private static String resolveAuthnClass(Set authentication } } - private AttributeStatement buildAttributes(UserServiceAuthentication user) { + private AttributeStatement buildAttributes( + UserServiceAuthentication user, + @Nullable SamlInitiateSingleSignOnAttributes customAttributes + ) { final SamlServiceProvider serviceProvider = user.getServiceProvider(); final AttributeStatement statement = samlFactory.object(AttributeStatement.class, AttributeStatement.DEFAULT_ELEMENT_NAME); final List attributes = new ArrayList<>(); @@ -199,6 +230,16 @@ private AttributeStatement buildAttributes(UserServiceAuthentication user) { if (name != null) { attributes.add(name); } + // Add custom attributes if provided + if (customAttributes != null && customAttributes.getAttributes().isEmpty() == false) { + for (Map.Entry> entry : customAttributes.getAttributes().entrySet()) { + Attribute attribute = buildAttribute(entry.getKey(), null, entry.getValue()); + if (attribute != null) { + attributes.add(attribute); + } + } + } + if (attributes.isEmpty()) { return null; } @@ -206,20 +247,22 @@ private AttributeStatement buildAttributes(UserServiceAuthentication user) { return statement; } - private Attribute buildAttribute(String formalName, String friendlyName, String value) { + private Attribute buildAttribute(String formalName, @Nullable String friendlyName, String value) { if (Strings.isNullOrEmpty(value)) { return null; } return buildAttribute(formalName, friendlyName, List.of(value)); } - private Attribute buildAttribute(String formalName, String friendlyName, Collection values) { + private Attribute buildAttribute(String formalName, @Nullable String friendlyName, Collection values) { if (values.isEmpty() || Strings.isNullOrEmpty(formalName)) { return null; } final Attribute attribute = samlFactory.object(Attribute.class, Attribute.DEFAULT_ELEMENT_NAME); attribute.setName(formalName); - attribute.setFriendlyName(friendlyName); + if (Strings.isNullOrEmpty(friendlyName) == false) { + attribute.setFriendlyName(friendlyName); + } attribute.setNameFormat(Attribute.URI_REFERENCE); for (String val : values) { final XSString string = samlFactory.object(XSString.class, AttributeValue.DEFAULT_ELEMENT_NAME, XSString.TYPE_NAME); diff --git a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/rest/action/RestSamlInitiateSingleSignOnAction.java b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/rest/action/RestSamlInitiateSingleSignOnAction.java index 3e4d57860fdae..5170254411dd9 100644 --- a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/rest/action/RestSamlInitiateSingleSignOnAction.java +++ b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/rest/action/RestSamlInitiateSingleSignOnAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.idp.action.SamlInitiateSingleSignOnRequest; import org.elasticsearch.xpack.idp.action.SamlInitiateSingleSignOnResponse; import org.elasticsearch.xpack.idp.saml.support.SamlAuthenticationState; +import org.elasticsearch.xpack.idp.saml.support.SamlInitiateSingleSignOnAttributes; import java.io.IOException; import java.util.Collections; @@ -41,6 +42,11 @@ public class RestSamlInitiateSingleSignOnAction extends IdpBaseRestHandler { (p, c) -> SamlAuthenticationState.fromXContent(p), new ParseField("authn_state") ); + PARSER.declareObject( + SamlInitiateSingleSignOnRequest::setAttributes, + (p, c) -> SamlInitiateSingleSignOnAttributes.fromXContent(p), + new ParseField("attributes") + ); } public RestSamlInitiateSingleSignOnAction(XPackLicenseState licenseState) { diff --git a/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/support/SamlInitiateSingleSignOnAttributes.java b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/support/SamlInitiateSingleSignOnAttributes.java new file mode 100644 index 0000000000000..e6d533a7bc224 --- /dev/null +++ b/x-pack/plugin/identity-provider/src/main/java/org/elasticsearch/xpack/idp/saml/support/SamlInitiateSingleSignOnAttributes.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.idp.saml.support; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ValidateActions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Represents a collection of SAML attributes to be included in the SAML response. + * Each attribute has a key and a list of values. + */ +public class SamlInitiateSingleSignOnAttributes implements Writeable, ToXContentObject { + private final Map> attributes; + + public SamlInitiateSingleSignOnAttributes(Map> attributes) { + this.attributes = attributes; + } + + /** + * @return A map of SAML attribute key to list of values + */ + public Map> getAttributes() { + return Collections.unmodifiableMap(attributes); + } + + /** + * Creates a SamlInitiateSingleSignOnAttributes object by parsing the provided JSON content. + * Expects a JSON structure like: { "attr1": ["val1", "val2"], "attr2": ["val3"] } + * + * @param parser The XContentParser positioned at the start of the object + * @return A new SamlInitiateSingleSignOnAttributes instance + */ + public static SamlInitiateSingleSignOnAttributes fromXContent(XContentParser parser) throws IOException { + final Map> attributes = parser.map(HashMap::new, p -> XContentParserUtils.parseList(p, XContentParser::text)); + return new SamlInitiateSingleSignOnAttributes(attributes); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.map(attributes); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(attributes, StreamOutput::writeStringCollection); + } + + public SamlInitiateSingleSignOnAttributes(StreamInput in) throws IOException { + this.attributes = in.readImmutableMap(StreamInput::readStringCollectionAsImmutableList); + } + + /** + * Validates the attributes for correctness. + * An attribute with an empty key is considered invalid. + */ + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (attributes.isEmpty() == false) { + for (String key : attributes.keySet()) { + if (Strings.isNullOrEmpty(key)) { + validationException = ValidateActions.addValidationError("attribute key cannot be null or empty", validationException); + } + } + } + return validationException; + } + + @Override + public String toString() { + return getClass().getSimpleName() + "{attributes=" + attributes + "}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SamlInitiateSingleSignOnAttributes that = (SamlInitiateSingleSignOnAttributes) o; + return Objects.equals(attributes, that.attributes); + } + + @Override + public int hashCode() { + return Objects.hash(attributes); + } +} diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequestTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequestTests.java index f930faffd40c7..4a0ec674ab4c8 100644 --- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequestTests.java +++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/action/SamlInitiateSingleSignOnRequestTests.java @@ -6,9 +6,20 @@ */ package org.elasticsearch.xpack.idp.action; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TransportVersionUtils; +import org.elasticsearch.xpack.idp.saml.support.SamlInitiateSingleSignOnAttributes; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; @@ -16,19 +27,76 @@ public class SamlInitiateSingleSignOnRequestTests extends ESTestCase { - public void testSerialization() throws Exception { + public void testSerializationCurrentVersion() throws Exception { final SamlInitiateSingleSignOnRequest request = new SamlInitiateSingleSignOnRequest(); request.setSpEntityId("https://kibana_url"); request.setAssertionConsumerService("https://kibana_url/acs"); + request.setAttributes( + new SamlInitiateSingleSignOnAttributes( + Map.ofEntries( + Map.entry("http://idp.elastic.co/attribute/custom1", List.of("foo")), + Map.entry("http://idp.elastic.co/attribute/custom2", List.of("bar", "baz")) + ) + ) + ); assertThat("An invalid request is not guaranteed to serialize correctly", request.validate(), nullValue()); final BytesStreamOutput out = new BytesStreamOutput(); + if (randomBoolean()) { + out.setTransportVersion( + TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES, + TransportVersion.current() + ) + ); + } request.writeTo(out); - final SamlInitiateSingleSignOnRequest request1 = new SamlInitiateSingleSignOnRequest(out.bytes().streamInput()); - assertThat(request1.getSpEntityId(), equalTo(request.getSpEntityId())); - assertThat(request1.getAssertionConsumerService(), equalTo(request.getAssertionConsumerService())); - final ActionRequestValidationException validationException = request1.validate(); - assertNull(validationException); + try (StreamInput in = out.bytes().streamInput()) { + in.setTransportVersion(out.getTransportVersion()); + final SamlInitiateSingleSignOnRequest request1 = new SamlInitiateSingleSignOnRequest(in); + assertThat(request1.getSpEntityId(), equalTo(request.getSpEntityId())); + assertThat(request1.getAssertionConsumerService(), equalTo(request.getAssertionConsumerService())); + assertThat(request1.getAttributes(), equalTo(request.getAttributes())); + final ActionRequestValidationException validationException = request1.validate(); + assertNull(validationException); + } + } + + public void testSerializationOldTransportVersion() throws Exception { + final SamlInitiateSingleSignOnRequest request = new SamlInitiateSingleSignOnRequest(); + request.setSpEntityId("https://kibana_url"); + request.setAssertionConsumerService("https://kibana_url/acs"); + if (randomBoolean()) { + request.setAttributes( + new SamlInitiateSingleSignOnAttributes( + Map.ofEntries( + Map.entry("http://idp.elastic.co/attribute/custom1", List.of("foo")), + Map.entry("http://idp.elastic.co/attribute/custom2", List.of("bar", "baz")) + ) + ) + ); + } + assertThat("An invalid request is not guaranteed to serialize correctly", request.validate(), nullValue()); + final BytesStreamOutput out = new BytesStreamOutput(); + out.setTransportVersion( + TransportVersionUtils.randomVersionBetween( + random(), + TransportVersions.MINIMUM_COMPATIBLE, + TransportVersionUtils.getPreviousVersion(TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES) + ) + ); + request.writeTo(out); + + try (StreamInput in = out.bytes().streamInput()) { + in.setTransportVersion(out.getTransportVersion()); + final SamlInitiateSingleSignOnRequest request1 = new SamlInitiateSingleSignOnRequest(in); + assertThat(request1.getSpEntityId(), equalTo(request.getSpEntityId())); + assertThat(request1.getAssertionConsumerService(), equalTo(request.getAssertionConsumerService())); + assertThat(request1.getAttributes(), nullValue()); + final ActionRequestValidationException validationException = request1.validate(); + assertNull(validationException); + } } public void testValidation() { @@ -39,4 +107,35 @@ public void testValidation() { assertThat(validationException.validationErrors().get(0), containsString("entity_id is missing")); assertThat(validationException.validationErrors().get(1), containsString("acs is missing")); } + + public void testBlankAttributeKeysValidation() { + // Create request with valid required fields + final SamlInitiateSingleSignOnRequest request = new SamlInitiateSingleSignOnRequest(); + request.setSpEntityId("https://kibana_url"); + request.setAssertionConsumerService("https://kibana_url/acs"); + + // Test with valid attribute keys + Map> attributeMap = new HashMap<>(); + attributeMap.put("key1", Collections.singletonList("value1")); + attributeMap.put("key2", Arrays.asList("value2A", "value2B")); + SamlInitiateSingleSignOnAttributes attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + request.setAttributes(attributes); + + // Should pass validation + ActionRequestValidationException validationException = request.validate(); + assertNull("Request with valid attribute keys should pass validation", validationException); + + // Test with empty attribute key - should be invalid + attributeMap = new HashMap<>(); + attributeMap.put("", Collections.singletonList("value1")); + attributeMap.put("unique_key", Collections.singletonList("value2")); + attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + request.setAttributes(attributes); + + // Should fail validation with appropriate error message + validationException = request.validate(); + assertNotNull("Request with empty attribute key should fail validation", validationException); + assertThat(validationException.validationErrors().size(), equalTo(1)); + assertThat(validationException.validationErrors().get(0), containsString("attribute key cannot be null or empty")); + } } diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilderTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilderTests.java index bd433e828436b..0e58935c07e31 100644 --- a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilderTests.java +++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/authn/SuccessfulAuthenticationResponseMessageBuilderTests.java @@ -12,14 +12,22 @@ import org.elasticsearch.xpack.idp.saml.sp.ServiceProviderDefaults; import org.elasticsearch.xpack.idp.saml.support.SamlFactory; import org.elasticsearch.xpack.idp.saml.support.SamlInit; +import org.elasticsearch.xpack.idp.saml.support.SamlInitiateSingleSignOnAttributes; import org.elasticsearch.xpack.idp.saml.support.XmlValidator; import org.elasticsearch.xpack.idp.saml.test.IdpSamlTestCase; import org.junit.Before; +import org.opensaml.saml.saml2.core.Attribute; +import org.opensaml.saml.saml2.core.AttributeStatement; import org.opensaml.saml.saml2.core.Response; -import java.net.URL; +import java.net.URI; import java.time.Clock; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import static org.hamcrest.Matchers.containsString; @@ -46,20 +54,63 @@ public void setupSaml() throws Exception { } public void testSignedResponseIsValidAgainstXmlSchema() throws Exception { - final Response response = buildResponse(); + final Response response = buildResponse(null); final String xml = super.toString(response); assertThat(xml, containsString("SignedInfo>")); validator.validate(xml); } - private Response buildResponse() throws Exception { + public void testSignedResponseWithCustomAttributes() throws Exception { + // Create custom attributes + Map> attributeMap = new HashMap<>(); + attributeMap.put("customAttr1", Collections.singletonList("value1")); + + List multipleValues = new ArrayList<>(); + multipleValues.add("value2A"); + multipleValues.add("value2B"); + attributeMap.put("customAttr2", multipleValues); + SamlInitiateSingleSignOnAttributes attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + + // Build response with custom attributes + final Response response = buildResponse(attributes); + final String xml = super.toString(response); + + // Validate that response is correctly signed + assertThat(xml, containsString("SignedInfo>")); + validator.validate(xml); + + // Verify custom attributes are included + boolean foundCustomAttr1 = false; + boolean foundCustomAttr2 = false; + + for (AttributeStatement statement : response.getAssertions().get(0).getAttributeStatements()) { + for (Attribute attribute : statement.getAttributes()) { + String name = attribute.getName(); + if (name.equals("customAttr1")) { + foundCustomAttr1 = true; + assertEquals(1, attribute.getAttributeValues().size()); + assertThat(attribute.getAttributeValues().get(0).getDOM().getTextContent(), containsString("value1")); + } else if (name.equals("customAttr2")) { + foundCustomAttr2 = true; + assertEquals(2, attribute.getAttributeValues().size()); + assertThat(attribute.getAttributeValues().get(0).getDOM().getTextContent(), containsString("value2A")); + assertThat(attribute.getAttributeValues().get(1).getDOM().getTextContent(), containsString("value2B")); + } + } + } + + assertTrue("Custom attribute 'customAttr1' not found in SAML response", foundCustomAttr1); + assertTrue("Custom attribute 'customAttr2' not found in SAML response", foundCustomAttr2); + } + + private Response buildResponse(SamlInitiateSingleSignOnAttributes customAttributes) throws Exception { final Clock clock = Clock.systemUTC(); final SamlServiceProvider sp = mock(SamlServiceProvider.class); final String baseServiceUrl = "https://" + randomAlphaOfLength(32) + ".us-east-1.aws.found.io/"; final String acs = baseServiceUrl + "api/security/saml/callback"; when(sp.getEntityId()).thenReturn(baseServiceUrl); - when(sp.getAssertionConsumerService()).thenReturn(new URL(acs)); + when(sp.getAssertionConsumerService()).thenReturn(URI.create(acs).toURL()); when(sp.getAuthnExpiry()).thenReturn(Duration.ofMinutes(10)); when(sp.getAttributeNames()).thenReturn(new SamlServiceProvider.AttributeNames("principal", null, null, null)); @@ -75,7 +126,7 @@ private Response buildResponse() throws Exception { clock, idp ); - return builder.build(user, null); + return builder.build(user, null, customAttributes); } } diff --git a/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlInitiateSingleSignOnAttributesTests.java b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlInitiateSingleSignOnAttributesTests.java new file mode 100644 index 0000000000000..5bf38ad8e8260 --- /dev/null +++ b/x-pack/plugin/identity-provider/src/test/java/org/elasticsearch/xpack/idp/saml/support/SamlInitiateSingleSignOnAttributesTests.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.idp.saml.support; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.InputStreamStreamInput; +import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.hamcrest.Matchers; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class SamlInitiateSingleSignOnAttributesTests extends ESTestCase { + + public void testConstructors() throws Exception { + final SamlInitiateSingleSignOnAttributes attributes1 = new SamlInitiateSingleSignOnAttributes(Collections.emptyMap()); + assertThat(attributes1.getAttributes(), Matchers.anEmptyMap()); + + // Test with non-empty map + Map> attributeMap = new HashMap<>(); + attributeMap.put("key1", Collections.singletonList("value1")); + final SamlInitiateSingleSignOnAttributes attributes3 = new SamlInitiateSingleSignOnAttributes(attributeMap); + assertThat(attributes3.getAttributes().size(), equalTo(1)); + } + + public void testEmptyAttributes() throws Exception { + final SamlInitiateSingleSignOnAttributes attributes = new SamlInitiateSingleSignOnAttributes(Collections.emptyMap()); + + // Test toXContent + XContentBuilder builder = XContentFactory.jsonBuilder(); + attributes.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = BytesReference.bytes(builder).utf8ToString(); + + final SamlInitiateSingleSignOnAttributes parsedAttributes = parseFromJson(json); + assertThat(parsedAttributes.getAttributes(), Matchers.anEmptyMap()); + + // Test serialization + SamlInitiateSingleSignOnAttributes serialized = copySerialize(attributes); + assertThat(serialized.getAttributes(), Matchers.anEmptyMap()); + } + + public void testWithAttributes() throws Exception { + Map> attributeMap = new HashMap<>(); + attributeMap.put("key1", Arrays.asList("value1", "value2")); + attributeMap.put("key2", Collections.singletonList("value3")); + final SamlInitiateSingleSignOnAttributes attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + + // Test getAttributes + Map> returnedAttributes = attributes.getAttributes(); + assertThat(returnedAttributes.size(), equalTo(2)); + assertThat(returnedAttributes.get("key1").size(), equalTo(2)); + assertThat(returnedAttributes.get("key1").get(0), equalTo("value1")); + assertThat(returnedAttributes.get("key1").get(1), equalTo("value2")); + assertThat(returnedAttributes.get("key2").size(), equalTo(1)); + assertThat(returnedAttributes.get("key2").get(0), equalTo("value3")); + + // Test immutability of returned attributes + expectThrows(UnsupportedOperationException.class, () -> returnedAttributes.put("newKey", Collections.singletonList("value"))); + expectThrows(UnsupportedOperationException.class, () -> returnedAttributes.get("key1").add("value3")); + + // Test validate + ActionRequestValidationException validationException = attributes.validate(); + assertNull(validationException); + + // Test toXContent + XContentBuilder builder = XContentFactory.jsonBuilder(); + attributes.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = BytesReference.bytes(builder).utf8ToString(); + + // Test parsing from JSON + final SamlInitiateSingleSignOnAttributes parsedAttributes = parseFromJson(json); + assertThat(parsedAttributes.getAttributes().size(), equalTo(2)); + assertThat(parsedAttributes.getAttributes().get("key1").size(), equalTo(2)); + assertThat(parsedAttributes.getAttributes().get("key1").get(0), equalTo("value1")); + assertThat(parsedAttributes.getAttributes().get("key1").get(1), equalTo("value2")); + assertThat(parsedAttributes.getAttributes().get("key2").size(), equalTo(1)); + assertThat(parsedAttributes.getAttributes().get("key2").get(0), equalTo("value3")); + + // Test serialization + SamlInitiateSingleSignOnAttributes serialized = copySerialize(attributes); + assertThat(serialized.getAttributes().size(), equalTo(2)); + assertThat(serialized.getAttributes().get("key1").size(), equalTo(2)); + assertThat(serialized.getAttributes().get("key1").get(0), equalTo("value1")); + assertThat(serialized.getAttributes().get("key1").get(1), equalTo("value2")); + assertThat(serialized.getAttributes().get("key2").size(), equalTo(1)); + assertThat(serialized.getAttributes().get("key2").get(0), equalTo("value3")); + } + + public void testToString() { + Map> attributeMap = new HashMap<>(); + attributeMap.put("key1", Arrays.asList("value1", "value2")); + final SamlInitiateSingleSignOnAttributes attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + + String toString = attributes.toString(); + assertThat(toString, containsString("SamlInitiateSingleSignOnAttributes")); + assertThat(toString, containsString("key1")); + assertThat(toString, containsString("value1")); + assertThat(toString, containsString("value2")); + + // Test empty attributes + final SamlInitiateSingleSignOnAttributes emptyAttributes = new SamlInitiateSingleSignOnAttributes(Collections.emptyMap()); + toString = emptyAttributes.toString(); + assertThat(toString, containsString("SamlInitiateSingleSignOnAttributes")); + assertThat(toString, containsString("attributes={}")); + } + + public void testValidation() throws Exception { + // Test validation with empty key + Map> attributeMap = new HashMap<>(); + attributeMap.put("", Arrays.asList("value1", "value2")); + SamlInitiateSingleSignOnAttributes attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + + ActionRequestValidationException validationException = attributes.validate(); + assertNotNull(validationException); + assertThat(validationException.getMessage(), containsString("attribute key cannot be null or empty")); + + // Test validation with null key + attributeMap = new HashMap<>(); + attributeMap.put(null, Collections.singletonList("value")); + attributes = new SamlInitiateSingleSignOnAttributes(attributeMap); + + validationException = attributes.validate(); + assertNotNull(validationException); + assertThat(validationException.getMessage(), containsString("attribute key cannot be null or empty")); + } + + public void testEqualsAndHashCode() { + Map> attributeMap1 = new HashMap<>(); + attributeMap1.put("key1", Arrays.asList("value1", "value2")); + attributeMap1.put("key2", Collections.singletonList("value3")); + + SamlInitiateSingleSignOnAttributes attributes1 = new SamlInitiateSingleSignOnAttributes(attributeMap1); + + Map> attributeMap2 = new HashMap<>(); + attributeMap2.put("key1", Arrays.asList("value1", "value2")); + attributeMap2.put("key2", Collections.singletonList("value3")); + + SamlInitiateSingleSignOnAttributes attributes2 = new SamlInitiateSingleSignOnAttributes(attributeMap2); + + // Test equals + assertEquals(attributes1, attributes2); + assertEquals(attributes2, attributes1); + + // Test hashCode + assertThat(attributes1.hashCode(), equalTo(attributes2.hashCode())); + + // Test with different values + Map> attributeMap3 = new HashMap<>(); + attributeMap3.put("key1", Arrays.asList("different", "value2")); + attributeMap3.put("key2", Collections.singletonList("value3")); + + SamlInitiateSingleSignOnAttributes attributes3 = new SamlInitiateSingleSignOnAttributes(attributeMap3); + + assertNotEquals(attributes1, attributes3); + + // Test with missing key + Map> attributeMap4 = new HashMap<>(); + attributeMap4.put("key1", Arrays.asList("value1", "value2")); + + SamlInitiateSingleSignOnAttributes attributes4 = new SamlInitiateSingleSignOnAttributes(attributeMap4); + + assertNotEquals(attributes1, attributes4); + } + + private SamlInitiateSingleSignOnAttributes parseFromJson(String json) throws IOException { + try ( + InputStream stream = new ByteArrayInputStream(json.getBytes("UTF-8")); + XContentParser parser = JsonXContent.jsonXContent.createParser(null, null, stream) + ) { + parser.nextToken(); // Start object + return SamlInitiateSingleSignOnAttributes.fromXContent(parser); + } + } + + private SamlInitiateSingleSignOnAttributes copySerialize(SamlInitiateSingleSignOnAttributes original) throws IOException { + ByteArrayOutputStream outBuffer = new ByteArrayOutputStream(); + OutputStreamStreamOutput out = new OutputStreamStreamOutput(outBuffer); + original.writeTo(out); + out.flush(); + + ByteArrayInputStream inBuffer = new ByteArrayInputStream(outBuffer.toByteArray()); + InputStreamStreamInput in = new InputStreamStreamInput(inBuffer); + return new SamlInitiateSingleSignOnAttributes(in); + } +}