Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type AuthnRequest struct {
NameIDPolicy *NameIDPolicy `xml:"urn:oasis:names:tc:SAML:2.0:protocol NameIDPolicy"`
Conditions *Conditions
RequestedAuthnContext *RequestedAuthnContext
// Scoping *Scoping // TODO
Scoping *Scoping

ForceAuthn *bool `xml:",attr"`
IsPassive *bool `xml:",attr"`
Expand Down Expand Up @@ -209,9 +209,9 @@ func (r *AuthnRequest) Element() *etree.Element {
if r.RequestedAuthnContext != nil {
el.AddChild(r.RequestedAuthnContext.Element())
}
// if r.Scoping != nil {
// el.AddChild(r.Scoping.Element())
// }
if r.Scoping != nil {
el.AddChild(r.Scoping.Element())
}
if r.ForceAuthn != nil {
el.CreateAttr("ForceAuthn", strconv.FormatBool(*r.ForceAuthn))
}
Expand Down Expand Up @@ -321,6 +321,41 @@ func (a *NameIDPolicy) Element() *etree.Element {
return el
}

// Scoping represents the SAML object of the same name.
//
// See http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf § 3.4.1.2
type Scoping struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Scoping"`
ProxyCount *int `xml:",attr"`
IDPList []string `xml:"urn:oasis:names:tc:SAML:2.0:protocol IDPList"` // Only supports IDEntry, TODO support GetComplete{uri}
RequesterIDs []string `xml:"urn:oasis:names:tc:SAML:2.0:protocol RequesterID"`
}

// Element returns an etree.Element representing the object in XML form.
func (a *Scoping) Element() *etree.Element {
el := etree.NewElement("samlp:Scoping")
if a.ProxyCount != nil {
el.CreateAttr("ProxyCount", strconv.Itoa(*a.ProxyCount))
}
if len(a.IDPList) > 0 {
idpList := etree.NewElement("samlp:IDPList")
for _, idp := range a.IDPList {
idpEntry := etree.NewElement("samlp:IDPEntry")
idpEntry.CreateAttr("ProviderID", idp)
idpList.AddChild(idpEntry)
}
el.AddChild(idpList)
}
if len(a.RequesterIDs) > 0 {
for _, requesterID := range a.RequesterIDs {
requesterIDEntry := etree.NewElement("samlp:RequesterIDEntry")
requesterIDEntry.CreateAttr("ProviderID", requesterID)
el.AddChild(requesterIDEntry)
}
}
return el
}

// ArtifactResolve represents the SAML object of the same name.
type ArtifactResolve struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol ArtifactResolve"`
Expand Down
17 changes: 17 additions & 0 deletions schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,20 @@ func TestLogoutRequestMarshalWithoutNotOnOrAfter(t *testing.T) {
assert.Check(t, err)
assert.Check(t, is.DeepEqual(expected, actual))
}

func TestScopingElement(t *testing.T) {
proxyCount := 2
expected := AuthnRequest{Scoping: &Scoping{
XMLName: xml.Name{Space: "urn:oasis:names:tc:SAML:2.0:protocol", Local: "Scoping"},
IDPList: []string{"idp1", "idp2"},
ProxyCount: &proxyCount,
RequesterIDs: []string{"http://uri"},
}}

doc := etree.NewDocument()
doc.SetRoot(expected.Element())
x, err := doc.WriteToBytes()
assert.Check(t, err)
assert.Check(t, is.Equal(`<samlp:AuthnRequest xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="" Version="" IssueInstant="0001-01-01T00:00:00Z"><samlp:Scoping ProxyCount="2"><samlp:IDPList><samlp:IDPEntry ProviderID="idp1"/><samlp:IDPEntry ProviderID="idp2"/></samlp:IDPList><samlp:RequesterIDEntry ProviderID="http://uri"/></samlp:Scoping></samlp:AuthnRequest>`,
string(x)))
}
10 changes: 10 additions & 0 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ type ServiceProvider struct {
// authentication requests
AuthnNameIDFormat NameIDFormat

// IDPList is a list of identity providers that are allowed to authenticate users. Send as part of AuthnRequest.Scoping.
// If empty, any delegate IDP can be used.
IDPList []string

// MetadataValidDuration is a duration used to calculate validUntil
// attribute in the metadata endpoint
MetadataValidDuration time.Duration
Expand Down Expand Up @@ -536,6 +540,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string, binding stri
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: firstSet(sp.EntityID, sp.MetadataURL.String()),
},

NameIDPolicy: &NameIDPolicy{
AllowCreate: &allowCreate,
// TODO(ross): figure out exactly policy we need
Expand All @@ -546,6 +551,11 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string, binding stri
ForceAuthn: sp.ForceAuthn,
RequestedAuthnContext: sp.RequestedAuthnContext,
}
if len(sp.IDPList) > 0 {
req.Scoping = &Scoping{
IDPList: sp.IDPList,
}
}
// We don't need to sign the XML document if the IDP uses HTTP-Redirect binding
if len(sp.SignatureMethod) > 0 && binding == HTTPPostBinding {
if err := sp.SignAuthnRequest(&req); err != nil {
Expand Down