@@ -2,6 +2,11 @@ package samlsp
22
33import (
44 "bytes"
5+ "crypto"
6+ "crypto/ecdsa"
7+ "crypto/ed25519"
8+ "crypto/elliptic"
9+ "crypto/rand"
510 "crypto/rsa"
611 "crypto/sha256"
712 "crypto/x509"
@@ -17,6 +22,7 @@ import (
1722 "testing"
1823 "time"
1924
25+ "github.com/golang-jwt/jwt/v5"
2026 dsig "github.com/russellhaering/goxmldsig"
2127 "gotest.tools/assert"
2228 is "gotest.tools/assert/cmp"
@@ -520,3 +526,268 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
520526 assert .Check (t , is .Equal ("" , resp .Header ().Get ("Location" )))
521527 assert .Check (t , is .Equal ("" , resp .Header ().Get ("Set-Cookie" )))
522528}
529+
530+ type mockSigner struct {
531+ signer crypto.Signer
532+ }
533+
534+ func (m * mockSigner ) Public () crypto.PublicKey {
535+ return m .signer .Public ()
536+ }
537+
538+ func (m * mockSigner ) Sign (rand io.Reader , digest []byte , opts crypto.SignerOpts ) ([]byte , error ) {
539+ return m .signer .Sign (rand , digest , opts )
540+ }
541+
542+ func newMockRSASigner (t * testing.T ) crypto.Signer {
543+ key := mustParsePrivateKey (golden .Get (t , "key.pem" ))
544+ return & mockSigner {signer : key .(crypto.Signer )}
545+ }
546+
547+ func TestMiddleware_WithCryptoSignerE2E (t * testing.T ) {
548+ origTimeNow := saml .TimeNow
549+ origClock := saml .Clock
550+ origRandReader := saml .RandReader
551+ t .Cleanup (func () {
552+ saml .TimeNow = origTimeNow
553+ saml .Clock = origClock
554+ saml .RandReader = origRandReader
555+ })
556+
557+ saml .TimeNow = func () time.Time {
558+ rv , _ := time .Parse ("Mon Jan 2 15:04:05.999999999 MST 2006" , "Mon Dec 1 01:57:09.123456789 UTC 2015" )
559+ return rv
560+ }
561+ saml .Clock = dsig .NewFakeClockAt (saml .TimeNow ())
562+ saml .RandReader = & testRandomReader {}
563+
564+ cert := mustParseCertificate (golden .Get (t , "cert.pem" ))
565+ idpMetadata := golden .Get (t , "idp_metadata.xml" )
566+
567+ var metadata saml.EntityDescriptor
568+ if err := xml .Unmarshal (idpMetadata , & metadata ); err != nil {
569+ panic (err )
570+ }
571+
572+ mockSigner := newMockRSASigner (t )
573+
574+ opts := Options {
575+ URL : mustParseURL ("https://15661444.ngrok.io/" ),
576+ Key : mockSigner ,
577+ Certificate : cert ,
578+ IDPMetadata : & metadata ,
579+ }
580+
581+ middleware , err := New (opts )
582+ assert .Check (t , err )
583+
584+ sessionProvider := DefaultSessionProvider (opts )
585+ sessionProvider .Name = "ttt"
586+ sessionProvider .MaxAge = 7200 * time .Second
587+
588+ sessionCodec := sessionProvider .Codec .(JWTSessionCodec )
589+ sessionCodec .MaxAge = 7200 * time .Second
590+ sessionProvider .Codec = sessionCodec
591+
592+ middleware .Session = sessionProvider
593+ middleware .ServiceProvider .MetadataURL .Path = "/saml2/metadata"
594+ middleware .ServiceProvider .AcsURL .Path = "/saml2/acs"
595+ middleware .ServiceProvider .SloURL .Path = "/saml2/slo"
596+
597+ t .Run ("SessionEncodeDecode" , func (t * testing.T ) {
598+ var tc JWTSessionClaims
599+ if err := json .Unmarshal (golden .Get (t , "token.json" ), & tc ); err != nil {
600+ t .Fatal (err )
601+ }
602+
603+ encoded , err := sessionProvider .Codec .Encode (tc )
604+ assert .Check (t , err )
605+ assert .Assert (t , encoded != "" )
606+
607+ decoded , err := sessionProvider .Codec .Decode (encoded )
608+ assert .Check (t , err )
609+ decodedClaims := decoded .(JWTSessionClaims )
610+ assert .Equal (t , tc .Subject , decodedClaims .Subject )
611+ })
612+
613+ t .Run ("TrackedRequestEncodeDecode" , func (t * testing.T ) {
614+ codec := middleware .RequestTracker .(CookieRequestTracker ).Codec
615+ trackedReq := TrackedRequest {
616+ Index : "test-index" ,
617+ SAMLRequestID : "test-request-id" ,
618+ URI : "/test-uri" ,
619+ }
620+
621+ encoded , err := codec .Encode (trackedReq )
622+ assert .Check (t , err )
623+ assert .Assert (t , encoded != "" )
624+
625+ decoded , err := codec .Decode (encoded )
626+ assert .Check (t , err )
627+ assert .Equal (t , trackedReq .Index , decoded .Index )
628+ assert .Equal (t , trackedReq .SAMLRequestID , decoded .SAMLRequestID )
629+ })
630+
631+ t .Run ("RequireAccountFlow" , func (t * testing.T ) {
632+ handler := middleware .RequireAccount (
633+ http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {
634+ panic ("not reached" )
635+ }))
636+
637+ req , _ := http .NewRequest ("GET" , "/protected" , nil )
638+ resp := httptest .NewRecorder ()
639+ handler .ServeHTTP (resp , req )
640+
641+ assert .Check (t , is .Equal (http .StatusFound , resp .Code ))
642+ assert .Assert (t , resp .Header ().Get ("Location" ) != "" )
643+ assert .Assert (t , resp .Header ().Get ("Set-Cookie" ) != "" )
644+ })
645+
646+ t .Run ("Metadata" , func (t * testing.T ) {
647+ req , _ := http .NewRequest ("GET" , "/saml2/metadata" , nil )
648+ resp := httptest .NewRecorder ()
649+ middleware .ServeHTTP (resp , req )
650+
651+ assert .Check (t , is .Equal (http .StatusOK , resp .Code ))
652+ assert .Check (t , is .Equal ("application/samlmetadata+xml" ,
653+ resp .Header ().Get ("Content-type" )))
654+ golden .Assert (t , resp .Body .String (), "expected_middleware_metadata.xml" )
655+ })
656+ }
657+
658+ func TestJWTSessionCodec_CryptoSignerEncodeDecode (t * testing.T ) {
659+ tests := []struct {
660+ name string
661+ method jwt.SigningMethod
662+ genKey func (t * testing.T ) crypto.Signer
663+ subject string
664+ }{
665+ {
666+ name : "ECDSA-P256" ,
667+ method : jwt .SigningMethodES256 ,
668+ genKey : func (t * testing.T ) crypto.Signer {
669+ k , err := ecdsa .GenerateKey (elliptic .P256 (), rand .Reader )
670+ assert .Check (t , err )
671+ return k
672+ },
673+ subject : "test-ecdsa-p256" ,
674+ },
675+ {
676+ name : "ECDSA-P384" ,
677+ method : jwt .SigningMethodES384 ,
678+ genKey : func (t * testing.T ) crypto.Signer {
679+ k , err := ecdsa .GenerateKey (elliptic .P384 (), rand .Reader )
680+ assert .Check (t , err )
681+ return k
682+ },
683+ subject : "test-ecdsa-p384" ,
684+ },
685+ {
686+ name : "ECDSA-P521" ,
687+ method : jwt .SigningMethodES512 ,
688+ genKey : func (t * testing.T ) crypto.Signer {
689+ k , err := ecdsa .GenerateKey (elliptic .P521 (), rand .Reader )
690+ assert .Check (t , err )
691+ return k
692+ },
693+ subject : "test-ecdsa-p521" ,
694+ },
695+ {
696+ name : "RSA-PSS" ,
697+ method : jwt .SigningMethodPS256 ,
698+ genKey : func (t * testing.T ) crypto.Signer {
699+ k , err := rsa .GenerateKey (rand .Reader , 2048 )
700+ assert .Check (t , err )
701+ return k
702+ },
703+ subject : "test-rsa-pss" ,
704+ },
705+ {
706+ name : "EdDSA" ,
707+ method : jwt .SigningMethodEdDSA ,
708+ genKey : func (t * testing.T ) crypto.Signer {
709+ _ , k , err := ed25519 .GenerateKey (rand .Reader )
710+ assert .Check (t , err )
711+ return k
712+ },
713+ subject : "test-eddsa" ,
714+ },
715+ }
716+
717+ for _ , tt := range tests {
718+ t .Run (tt .name , func (t * testing.T ) {
719+ now := time .Now ()
720+ origTimeNow := saml .TimeNow
721+ t .Cleanup (func () { saml .TimeNow = origTimeNow })
722+ saml .TimeNow = func () time.Time { return now }
723+
724+ signer := & mockSigner {signer : tt .genKey (t )}
725+
726+ audience := "https://example.com/"
727+ codec := JWTSessionCodec {
728+ SigningMethod : tt .method ,
729+ Audience : audience ,
730+ Issuer : audience ,
731+ MaxAge : time .Hour ,
732+ Key : signer ,
733+ }
734+
735+ tc := JWTSessionClaims {
736+ RegisteredClaims : jwt.RegisteredClaims {
737+ Audience : jwt.ClaimStrings {audience },
738+ Issuer : audience ,
739+ Subject : tt .subject ,
740+ IssuedAt : jwt .NewNumericDate (now ),
741+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
742+ NotBefore : jwt .NewNumericDate (now ),
743+ },
744+ SAMLSession : true ,
745+ }
746+
747+ encoded , err := codec .Encode (tc )
748+ assert .Check (t , err )
749+ assert .Assert (t , encoded != "" )
750+
751+ decoded , err := codec .Decode (encoded )
752+ assert .Check (t , err )
753+ decodedClaims := decoded .(JWTSessionClaims )
754+ assert .Equal (t , tt .subject , decodedClaims .Subject )
755+ })
756+ }
757+ }
758+
759+ func TestJWTSessionCodec_UnsupportedAlgorithmReturnsError (t * testing.T ) {
760+ now := time .Now ()
761+ origTimeNow := saml .TimeNow
762+ t .Cleanup (func () { saml .TimeNow = origTimeNow })
763+ saml .TimeNow = func () time.Time { return now }
764+
765+ rsaKey , err := rsa .GenerateKey (rand .Reader , 2048 )
766+ assert .Check (t , err )
767+
768+ signer := & mockSigner {signer : rsaKey }
769+
770+ audience := "https://example.com/"
771+ codec := JWTSessionCodec {
772+ SigningMethod : jwt .SigningMethodNone ,
773+ Audience : audience ,
774+ Issuer : audience ,
775+ MaxAge : time .Hour ,
776+ Key : signer ,
777+ }
778+
779+ tc := JWTSessionClaims {
780+ RegisteredClaims : jwt.RegisteredClaims {
781+ Audience : jwt.ClaimStrings {audience },
782+ Issuer : audience ,
783+ Subject : "test" ,
784+ IssuedAt : jwt .NewNumericDate (now ),
785+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
786+ NotBefore : jwt .NewNumericDate (now ),
787+ },
788+ SAMLSession : true ,
789+ }
790+
791+ _ , err = codec .Encode (tc )
792+ assert .Check (t , is .ErrorContains (err , "unsupported algorithm for crypto.Signer" ))
793+ }
0 commit comments