Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/128805.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128805
summary: Add "extension" attribute validation to IdP SPs
area: IdentityProvider
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
},
"roles": {
"type": "keyword"
},
"extensions": {
"type": "keyword"
}
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ public void testCustomAttributesInIdpInitiatedSso() throws Exception {
Map.entry("principal", "https://idp.test.es.elasticsearch.org/attribute/principal"),
Map.entry("name", "https://idp.test.es.elasticsearch.org/attribute/name"),
Map.entry("email", "https://idp.test.es.elasticsearch.org/attribute/email"),
Map.entry("roles", "https://idp.test.es.elasticsearch.org/attribute/roles")
Map.entry("roles", "https://idp.test.es.elasticsearch.org/attribute/roles"),
Map.entry("extensions", List.of("department", "region"))
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.common.Strings.collectionToCommaDelimitedString;
import static org.opensaml.saml.saml2.core.NameIDType.TRANSIENT;

/**
Expand Down Expand Up @@ -214,28 +215,42 @@ private AttributeStatement buildAttributes(
final SamlServiceProvider serviceProvider = user.getServiceProvider();
final AttributeStatement statement = samlFactory.object(AttributeStatement.class, AttributeStatement.DEFAULT_ELEMENT_NAME);
final List<Attribute> attributes = new ArrayList<>();
final Attribute roles = buildAttribute(serviceProvider.getAttributeNames().roles, "roles", user.getRoles());
final SamlServiceProvider.AttributeNames attributeNames = serviceProvider.getAttributeNames();
final Attribute roles = buildAttribute(attributeNames.roles, "roles", user.getRoles());
if (roles != null) {
attributes.add(roles);
}
final Attribute principal = buildAttribute(serviceProvider.getAttributeNames().principal, "principal", user.getPrincipal());
final Attribute principal = buildAttribute(attributeNames.principal, "principal", user.getPrincipal());
if (principal != null) {
attributes.add(principal);
}
final Attribute email = buildAttribute(serviceProvider.getAttributeNames().email, "email", user.getEmail());
final Attribute email = buildAttribute(attributeNames.email, "email", user.getEmail());
if (email != null) {
attributes.add(email);
}
final Attribute name = buildAttribute(serviceProvider.getAttributeNames().name, "name", user.getName());
final Attribute name = buildAttribute(attributeNames.name, "name", user.getName());
if (name != null) {
attributes.add(name);
}
// Add custom attributes if provided
if (customAttributes != null && customAttributes.getAttributes().isEmpty() == false) {
for (Map.Entry<String, List<String>> entry : customAttributes.getAttributes().entrySet()) {
Attribute attribute = buildAttribute(entry.getKey(), null, entry.getValue());
if (attribute != null) {
attributes.add(attribute);
final String attributeName = entry.getKey();
if (attributeNames.isAllowedExtension(attributeName)) {
Attribute attribute = buildAttribute(attributeName, null, entry.getValue());
if (attribute != null) {
attributes.add(attribute);
}
} else {
throw new IllegalArgumentException(
"the custom attribute ["
+ attributeName
+ "] is not permitted for the Service Provider ["
+ serviceProvider.getName()
+ "], allowed attribute names are ["
+ collectionToCommaDelimitedString(attributeNames.allowedExtensions)
+ "]"
);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ class AttributeNames {
public final String name;
public final String email;
public final String roles;
public final Set<String> allowedExtensions;

public AttributeNames(String principal, String name, String email, String roles) {
public AttributeNames(String principal, String name, String email, String roles, Set<String> allowedExtensions) {
this.principal = principal;
this.name = name;
this.email = email;
this.roles = roles;
this.allowedExtensions = allowedExtensions == null ? Set.of() : Set.copyOf(allowedExtensions);
}

public boolean isAllowedExtension(String attributeName) {
return this.allowedExtensions.contains(attributeName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
import java.util.TreeSet;
import java.util.function.BiConsumer;

import static org.elasticsearch.TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19;

/**
* This class models the storage of a {@link SamlServiceProvider} as an Elasticsearch document.
*/
Expand Down Expand Up @@ -87,6 +89,12 @@ public static class AttributeNames {
@Nullable
public String roles;

/**
* Extensions are attributes that are provided at runtime (by the trusted client that initiates
* the SAML SSO. They are sourced from the rest request rather than the user object itself.
*/
public Set<String> extensions = Set.of();

public void setPrincipal(String principal) {
this.principal = principal;
}
Expand All @@ -103,6 +111,10 @@ public void setRoles(String roles) {
this.roles = roles;
}

public void setExtensions(Collection<String> names) {
this.extensions = names == null ? Set.of() : Set.copyOf(names);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand All @@ -111,7 +123,8 @@ public boolean equals(Object o) {
return Objects.equals(principal, that.principal)
&& Objects.equals(email, that.email)
&& Objects.equals(name, that.name)
&& Objects.equals(roles, that.roles);
&& Objects.equals(roles, that.roles)
&& Objects.equals(extensions, that.extensions);
}

@Override
Expand Down Expand Up @@ -263,6 +276,10 @@ public SamlServiceProviderDocument(StreamInput in) throws IOException {
attributeNames.name = in.readOptionalString();
attributeNames.roles = in.readOptionalString();

if (in.getTransportVersion().onOrAfter(IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19)) {
attributeNames.extensions = in.readCollectionAsImmutableSet(StreamInput::readString);
}

certificates.serviceProviderSigning = in.readStringCollectionAsList();
certificates.identityProviderSigning = in.readStringCollectionAsList();
certificates.identityProviderMetadataSigning = in.readStringCollectionAsList();
Expand All @@ -288,6 +305,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(attributeNames.name);
out.writeOptionalString(attributeNames.roles);

if (out.getTransportVersion().onOrAfter(IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19)) {
out.writeStringCollection(attributeNames.extensions);
}

out.writeStringCollection(certificates.serviceProviderSigning);
out.writeStringCollection(certificates.identityProviderSigning);
out.writeStringCollection(certificates.identityProviderMetadataSigning);
Expand Down Expand Up @@ -426,6 +447,7 @@ public int hashCode() {
ATTRIBUTES_PARSER.declareStringOrNull(AttributeNames::setEmail, Fields.Attributes.EMAIL);
ATTRIBUTES_PARSER.declareStringOrNull(AttributeNames::setName, Fields.Attributes.NAME);
ATTRIBUTES_PARSER.declareStringOrNull(AttributeNames::setRoles, Fields.Attributes.ROLES);
ATTRIBUTES_PARSER.declareStringArray(AttributeNames::setExtensions, Fields.Attributes.EXTENSIONS);

DOC_PARSER.declareObject(NULL_CONSUMER, (p, doc) -> CERTIFICATES_PARSER.parse(p, doc.certificates, null), Fields.CERTIFICATES);
CERTIFICATES_PARSER.declareStringArray(Certificates::setServiceProviderSigning, Fields.Certificates.SP_SIGNING);
Expand Down Expand Up @@ -516,6 +538,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(Fields.Attributes.EMAIL.getPreferredName(), attributeNames.email);
builder.field(Fields.Attributes.NAME.getPreferredName(), attributeNames.name);
builder.field(Fields.Attributes.ROLES.getPreferredName(), attributeNames.roles);
if (attributeNames.extensions != null && attributeNames.extensions.isEmpty() == false) {
builder.field(Fields.Attributes.EXTENSIONS.getPreferredName(), attributeNames.extensions);
}
builder.endObject();

builder.startObject(Fields.CERTIFICATES.getPreferredName());
Expand Down Expand Up @@ -553,6 +578,7 @@ interface Attributes {
ParseField EMAIL = new ParseField("email");
ParseField NAME = new ParseField("name");
ParseField ROLES = new ParseField("roles");
ParseField EXTENSIONS = new ParseField("extensions");
}

interface Certificates {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ SamlServiceProvider buildServiceProvider(SamlServiceProviderDocument document) {
document.attributeNames.principal,
document.attributeNames.name,
document.attributeNames.email,
document.attributeNames.roles
document.attributeNames.roles,
document.attributeNames.extensions
);
final Set<X509Credential> credentials = document.certificates.getServiceProviderX509SigningCertificates()
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,26 @@ public class SamlServiceProviderIndex implements Closeable {
static final String TEMPLATE_VERSION_STRING_DEPRECATED = "idp.template.version_deprecated";
static final String FINAL_TEMPLATE_VERSION_STRING_DEPRECATED = "8.14.0";

static final int CURRENT_TEMPLATE_VERSION = 1;
/**
* The object in the index mapping metadata that contains a version field
*/
private static final String INDEX_META_FIELD = "_meta";
/**
* The field in the {@link #INDEX_META_FIELD} metadata that contains the version number
*/
private static final String TEMPLATE_VERSION_META_FIELD = "idp-template-version";

/**
* The first version of this template (since it was moved to use {@link org.elasticsearch.xpack.core.template.IndexTemplateRegistry}
*/
private static final int VERSION_ORIGINAL = 1;
/**
* The version that added the {@code attributes.extensions} field to the SAML SP document
*/
private static final int VERSION_EXTENSION_ATTRIBUTES = 2;
static final int CURRENT_TEMPLATE_VERSION = VERSION_EXTENSION_ATTRIBUTES;

private volatile boolean indexUpToDate = false;

public static final class DocumentVersion {
public final String id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ private TransportSamlInitiateSingleSignOnAction setupTransportAction(boolean wit
"https://saml.elasticsearch.org/attributes/principal",
"https://saml.elasticsearch.org/attributes/name",
"https://saml.elasticsearch.org/attributes/email",
"https://saml.elasticsearch.org/attributes/roles"
"https://saml.elasticsearch.org/attributes/roles",
Set.of()
),
null,
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.idp.saml.authn;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.idp.saml.idp.SamlIdentityProvider;
import org.elasticsearch.xpack.idp.saml.sp.SamlServiceProvider;
import org.elasticsearch.xpack.idp.saml.sp.ServiceProviderDefaults;
Expand Down Expand Up @@ -103,7 +104,46 @@ public void testSignedResponseWithCustomAttributes() throws Exception {
assertTrue("Custom attribute 'customAttr2' not found in SAML response", foundCustomAttr2);
}

private Response buildResponse(SamlInitiateSingleSignOnAttributes customAttributes) throws Exception {
public void testRejectInvalidCustomAttributes() throws Exception {
final var customAttributes = new SamlInitiateSingleSignOnAttributes(
Map.of("https://idp.example.org/attribute/department", Collections.singletonList("engineering"))
);

// Build response with custom attributes
final IllegalArgumentException ex = expectThrows(
IllegalArgumentException.class,
() -> buildResponse(
new SamlServiceProvider.AttributeNames(
"https://idp.example.org/attribute/principal",
null,
null,
null,
Set.of("https://idp.example.org/attribute/avatar")
),
customAttributes
)
);
assertThat(ex.getMessage(), containsString("custom attribute [https://idp.example.org/attribute/department]"));
assertThat(ex.getMessage(), containsString("allowed attribute names are [https://idp.example.org/attribute/avatar]"));
}

private Response buildResponse(@Nullable SamlInitiateSingleSignOnAttributes customAttributes) throws Exception {
return buildResponse(
new SamlServiceProvider.AttributeNames(
"principal",
null,
null,
null,
customAttributes == null ? Set.of() : customAttributes.getAttributes().keySet()
),
customAttributes
);
}

private Response buildResponse(
final SamlServiceProvider.AttributeNames attributes,
@Nullable SamlInitiateSingleSignOnAttributes customAttributes
) throws Exception {
final Clock clock = Clock.systemUTC();

final SamlServiceProvider sp = mock(SamlServiceProvider.class);
Expand All @@ -112,7 +152,7 @@ private Response buildResponse(SamlInitiateSingleSignOnAttributes customAttribut
when(sp.getEntityId()).thenReturn(baseServiceUrl);
when(sp.getAssertionConsumerService()).thenReturn(URI.create(acs).toURL());
when(sp.getAuthnExpiry()).thenReturn(Duration.ofMinutes(10));
when(sp.getAttributeNames()).thenReturn(new SamlServiceProvider.AttributeNames("principal", null, null, null));
when(sp.getAttributeNames()).thenReturn(attributes);

final UserServiceAuthentication user = mock(UserServiceAuthentication.class);
when(user.getPrincipal()).thenReturn(randomAlphaOfLengthBetween(4, 12));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Set;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.emptyIterable;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
Expand Down Expand Up @@ -90,6 +91,25 @@ public void testStreamRoundTripWithAllFields() throws Exception {
assertThat(assertSerializationRoundTrip(doc2), equalTo(doc1));
}

public void testSerializationBeforeExtensionAttributes() throws Exception {
final SamlServiceProviderDocument original = createFullDocument();
final TransportVersion version = TransportVersionUtils.randomVersionBetween(
random(),
TransportVersions.V_7_7_0,
TransportVersionUtils.getPreviousVersion(TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19)
);
final SamlServiceProviderDocument copy = copyWriteable(
original,
new NamedWriteableRegistry(List.of()),
SamlServiceProviderDocument::new,
version
);
assertThat(copy.attributeNames.extensions, empty());

copy.attributeNames.setExtensions(original.attributeNames.extensions);
assertThat(copy, equalTo(original));
}

private SamlServiceProviderDocument createFullDocument() throws GeneralSecurityException, IOException {
final List<X509Credential> credentials = readCredentials();
final List<X509Certificate> certificates = credentials.stream().map(X509Credential::getEntityCertificate).toList();
Expand Down Expand Up @@ -121,6 +141,7 @@ private SamlServiceProviderDocument createFullDocument() throws GeneralSecurityE
doc1.attributeNames.setEmail("urn:" + randomAlphaOfLengthBetween(4, 8) + "." + randomAlphaOfLengthBetween(4, 8));
doc1.attributeNames.setName("urn:" + randomAlphaOfLengthBetween(4, 8) + "." + randomAlphaOfLengthBetween(4, 8));
doc1.attributeNames.setRoles("urn:" + randomAlphaOfLengthBetween(4, 8) + "." + randomAlphaOfLengthBetween(4, 8));
doc1.attributeNames.setExtensions(List.of("urn:" + randomAlphaOfLengthBetween(4, 8) + "." + randomAlphaOfLengthBetween(4, 8)));
return doc1;
}

Expand Down Expand Up @@ -162,7 +183,7 @@ private SamlServiceProviderDocument assertXContentRoundTrip(SamlServiceProviderD
private SamlServiceProviderDocument assertSerializationRoundTrip(SamlServiceProviderDocument doc) throws IOException {
final TransportVersion version = TransportVersionUtils.randomVersionBetween(
random(),
TransportVersions.V_7_7_0,
TransportVersions.IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19,
TransportVersion.current()
);
final SamlServiceProviderDocument read = copyWriteable(
Expand Down
Loading