Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: stable
go-version-file: go.mod
- name: golangci-lint
uses: golangci/golangci-lint-action@v7
with:
Expand Down
10 changes: 5 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
module github.com/crewjam/saml

go 1.22
go 1.24.0

require (
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/beevik/etree v1.5.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/google/go-cmp v0.7.0
github.com/mattermost/xml-roundtrip-validator v0.1.0
github.com/russellhaering/goxmldsig v1.4.0
golang.org/x/crypto v0.33.0
github.com/russellhaering/goxmldsig v1.5.0
golang.org/x/crypto v0.48.0
gotest.tools v2.2.0+incompatible
)

require (
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/jonboulle/clockwork v0.5.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/stretchr/testify v1.10.0 // indirect
)
31 changes: 8 additions & 23 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,45 +1,30 @@
github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=
github.com/beevik/etree v1.5.0 h1:iaQZFSDS+3kYZiGoc9uKeOkUY3nYMXOKLl6KIJxiJWs=
github.com/beevik/etree v1.5.0/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ=
github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I=
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys=
github.com/russellhaering/goxmldsig v1.4.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw=
github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw=
github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
Expand Down
271 changes: 271 additions & 0 deletions samlsp/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ package samlsp

import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
Expand All @@ -17,6 +22,7 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
dsig "github.com/russellhaering/goxmldsig"
"gotest.tools/assert"
is "gotest.tools/assert/cmp"
Expand Down Expand Up @@ -520,3 +526,268 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
assert.Check(t, is.Equal("", resp.Header().Get("Location")))
assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
}

type mockSigner struct {
signer crypto.Signer
}

func (m *mockSigner) Public() crypto.PublicKey {
return m.signer.Public()
}

func (m *mockSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return m.signer.Sign(rand, digest, opts)
}

func newMockRSASigner(t *testing.T) crypto.Signer {
key := mustParsePrivateKey(golden.Get(t, "key.pem"))
return &mockSigner{signer: key.(crypto.Signer)}
}

func TestMiddleware_WithCryptoSignerE2E(t *testing.T) {
origTimeNow := saml.TimeNow
origClock := saml.Clock
origRandReader := saml.RandReader
t.Cleanup(func() {
saml.TimeNow = origTimeNow
saml.Clock = origClock
saml.RandReader = origRandReader
})

saml.TimeNow = func() time.Time {
rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015")
return rv
}
saml.Clock = dsig.NewFakeClockAt(saml.TimeNow())
saml.RandReader = &testRandomReader{}

cert := mustParseCertificate(golden.Get(t, "cert.pem"))
idpMetadata := golden.Get(t, "idp_metadata.xml")

var metadata saml.EntityDescriptor
if err := xml.Unmarshal(idpMetadata, &metadata); err != nil {
panic(err)
}

mockSigner := newMockRSASigner(t)

opts := Options{
URL: mustParseURL("https://15661444.ngrok.io/"),
Key: mockSigner,
Certificate: cert,
IDPMetadata: &metadata,
}

middleware, err := New(opts)
assert.Check(t, err)

sessionProvider := DefaultSessionProvider(opts)
sessionProvider.Name = "ttt"
sessionProvider.MaxAge = 7200 * time.Second

sessionCodec := sessionProvider.Codec.(JWTSessionCodec)
sessionCodec.MaxAge = 7200 * time.Second
sessionProvider.Codec = sessionCodec

middleware.Session = sessionProvider
middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata"
middleware.ServiceProvider.AcsURL.Path = "/saml2/acs"
middleware.ServiceProvider.SloURL.Path = "/saml2/slo"

t.Run("SessionEncodeDecode", func(t *testing.T) {
var tc JWTSessionClaims
if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil {
t.Fatal(err)
}

encoded, err := sessionProvider.Codec.Encode(tc)
assert.Check(t, err)
assert.Assert(t, encoded != "")

decoded, err := sessionProvider.Codec.Decode(encoded)
assert.Check(t, err)
decodedClaims := decoded.(JWTSessionClaims)
assert.Equal(t, tc.Subject, decodedClaims.Subject)
})

t.Run("TrackedRequestEncodeDecode", func(t *testing.T) {
codec := middleware.RequestTracker.(CookieRequestTracker).Codec
trackedReq := TrackedRequest{
Index: "test-index",
SAMLRequestID: "test-request-id",
URI: "/test-uri",
}

encoded, err := codec.Encode(trackedReq)
assert.Check(t, err)
assert.Assert(t, encoded != "")

decoded, err := codec.Decode(encoded)
assert.Check(t, err)
assert.Equal(t, trackedReq.Index, decoded.Index)
assert.Equal(t, trackedReq.SAMLRequestID, decoded.SAMLRequestID)
})

t.Run("RequireAccountFlow", func(t *testing.T) {
handler := middleware.RequireAccount(
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
panic("not reached")
}))

req, _ := http.NewRequest("GET", "/protected", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)

assert.Check(t, is.Equal(http.StatusFound, resp.Code))
assert.Assert(t, resp.Header().Get("Location") != "")
assert.Assert(t, resp.Header().Get("Set-Cookie") != "")
})

t.Run("Metadata", func(t *testing.T) {
req, _ := http.NewRequest("GET", "/saml2/metadata", nil)
resp := httptest.NewRecorder()
middleware.ServeHTTP(resp, req)

assert.Check(t, is.Equal(http.StatusOK, resp.Code))
assert.Check(t, is.Equal("application/samlmetadata+xml",
resp.Header().Get("Content-type")))
golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml")
})
}

func TestJWTSessionCodec_CryptoSignerEncodeDecode(t *testing.T) {
tests := []struct {
name string
method jwt.SigningMethod
genKey func(t *testing.T) crypto.Signer
subject string
}{
{
name: "ECDSA-P256",
method: jwt.SigningMethodES256,
genKey: func(t *testing.T) crypto.Signer {
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.Check(t, err)
return k
},
subject: "test-ecdsa-p256",
},
{
name: "ECDSA-P384",
method: jwt.SigningMethodES384,
genKey: func(t *testing.T) crypto.Signer {
k, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
assert.Check(t, err)
return k
},
subject: "test-ecdsa-p384",
},
{
name: "ECDSA-P521",
method: jwt.SigningMethodES512,
genKey: func(t *testing.T) crypto.Signer {
k, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
assert.Check(t, err)
return k
},
subject: "test-ecdsa-p521",
},
{
name: "RSA-PSS",
method: jwt.SigningMethodPS256,
genKey: func(t *testing.T) crypto.Signer {
k, err := rsa.GenerateKey(rand.Reader, 2048)
assert.Check(t, err)
return k
},
subject: "test-rsa-pss",
},
{
name: "EdDSA",
method: jwt.SigningMethodEdDSA,
genKey: func(t *testing.T) crypto.Signer {
_, k, err := ed25519.GenerateKey(rand.Reader)
assert.Check(t, err)
return k
},
subject: "test-eddsa",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
now := time.Now()
origTimeNow := saml.TimeNow
t.Cleanup(func() { saml.TimeNow = origTimeNow })
saml.TimeNow = func() time.Time { return now }

signer := &mockSigner{signer: tt.genKey(t)}

audience := "https://example.com/"
codec := JWTSessionCodec{
SigningMethod: tt.method,
Audience: audience,
Issuer: audience,
MaxAge: time.Hour,
Key: signer,
}

tc := JWTSessionClaims{
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{audience},
Issuer: audience,
Subject: tt.subject,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
NotBefore: jwt.NewNumericDate(now),
},
SAMLSession: true,
}

encoded, err := codec.Encode(tc)
assert.Check(t, err)
assert.Assert(t, encoded != "")

decoded, err := codec.Decode(encoded)
assert.Check(t, err)
decodedClaims := decoded.(JWTSessionClaims)
assert.Equal(t, tt.subject, decodedClaims.Subject)
})
}
}

func TestJWTSessionCodec_UnsupportedAlgorithmReturnsError(t *testing.T) {
now := time.Now()
origTimeNow := saml.TimeNow
t.Cleanup(func() { saml.TimeNow = origTimeNow })
saml.TimeNow = func() time.Time { return now }

rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
assert.Check(t, err)

signer := &mockSigner{signer: rsaKey}

audience := "https://example.com/"
codec := JWTSessionCodec{
SigningMethod: jwt.SigningMethodNone,
Audience: audience,
Issuer: audience,
MaxAge: time.Hour,
Key: signer,
}

tc := JWTSessionClaims{
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{audience},
Issuer: audience,
Subject: "test",
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
NotBefore: jwt.NewNumericDate(now),
},
SAMLSession: true,
}

_, err = codec.Encode(tc)
assert.Check(t, is.ErrorContains(err, "unsupported algorithm for crypto.Signer"))
}
Loading
Loading