@@ -2,6 +2,11 @@ package samlsp
22
33import (
44 "bytes"
5+ "crypto"
6+ "crypto/ecdsa"
7+ "crypto/ed25519"
8+ "crypto/elliptic"
9+ "crypto/rand"
510 "crypto/rsa"
611 "crypto/sha256"
712 "crypto/x509"
@@ -17,6 +22,7 @@ import (
1722 "testing"
1823 "time"
1924
25+ "github.com/golang-jwt/jwt/v5"
2026 dsig "github.com/russellhaering/goxmldsig"
2127 "gotest.tools/assert"
2228 is "gotest.tools/assert/cmp"
@@ -520,3 +526,370 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
520526 assert .Check (t , is .Equal ("" , resp .Header ().Get ("Location" )))
521527 assert .Check (t , is .Equal ("" , resp .Header ().Get ("Set-Cookie" )))
522528}
529+
530+ type mockSigner struct {
531+ signer crypto.Signer
532+ }
533+
534+ func (m * mockSigner ) Public () crypto.PublicKey {
535+ return m .signer .Public ()
536+ }
537+
538+ func (m * mockSigner ) Sign (rand io.Reader , digest []byte , opts crypto.SignerOpts ) ([]byte , error ) {
539+ return m .signer .Sign (rand , digest , opts )
540+ }
541+
542+ func newMockRSASigner (t * testing.T ) crypto.Signer {
543+ key := mustParsePrivateKey (golden .Get (t , "key.pem" ))
544+ return & mockSigner {signer : key .(crypto.Signer )}
545+ }
546+
547+ func TestMiddleware_WithCryptoSignerE2E (t * testing.T ) {
548+ saml .TimeNow = func () time.Time {
549+ rv , _ := time .Parse ("Mon Jan 2 15:04:05.999999999 MST 2006" , "Mon Dec 1 01:57:09.123456789 UTC 2015" )
550+ return rv
551+ }
552+ saml .Clock = dsig .NewFakeClockAt (saml .TimeNow ())
553+ saml .RandReader = & testRandomReader {}
554+
555+ cert := mustParseCertificate (golden .Get (t , "cert.pem" ))
556+ idpMetadata := golden .Get (t , "idp_metadata.xml" )
557+
558+ var metadata saml.EntityDescriptor
559+ if err := xml .Unmarshal (idpMetadata , & metadata ); err != nil {
560+ panic (err )
561+ }
562+
563+ mockSigner := newMockRSASigner (t )
564+
565+ opts := Options {
566+ URL : mustParseURL ("https://15661444.ngrok.io/" ),
567+ Key : mockSigner ,
568+ Certificate : cert ,
569+ IDPMetadata : & metadata ,
570+ }
571+
572+ middleware , err := New (opts )
573+ assert .Check (t , err )
574+
575+ sessionProvider := DefaultSessionProvider (opts )
576+ sessionProvider .Name = "ttt"
577+ sessionProvider .MaxAge = 7200 * time .Second
578+
579+ sessionCodec := sessionProvider .Codec .(JWTSessionCodec )
580+ sessionCodec .MaxAge = 7200 * time .Second
581+ sessionProvider .Codec = sessionCodec
582+
583+ middleware .Session = sessionProvider
584+ middleware .ServiceProvider .MetadataURL .Path = "/saml2/metadata"
585+ middleware .ServiceProvider .AcsURL .Path = "/saml2/acs"
586+ middleware .ServiceProvider .SloURL .Path = "/saml2/slo"
587+
588+ t .Run ("SessionEncodeDecode" , func (t * testing.T ) {
589+ var tc JWTSessionClaims
590+ if err := json .Unmarshal (golden .Get (t , "token.json" ), & tc ); err != nil {
591+ t .Fatal (err )
592+ }
593+
594+ encoded , err := sessionProvider .Codec .Encode (tc )
595+ assert .Check (t , err )
596+ assert .Assert (t , encoded != "" )
597+
598+ decoded , err := sessionProvider .Codec .Decode (encoded )
599+ assert .Check (t , err )
600+ decodedClaims := decoded .(JWTSessionClaims )
601+ assert .Equal (t , tc .Subject , decodedClaims .Subject )
602+ })
603+
604+ t .Run ("TrackedRequestEncodeDecode" , func (t * testing.T ) {
605+ codec := middleware .RequestTracker .(CookieRequestTracker ).Codec
606+ trackedReq := TrackedRequest {
607+ Index : "test-index" ,
608+ SAMLRequestID : "test-request-id" ,
609+ URI : "/test-uri" ,
610+ }
611+
612+ encoded , err := codec .Encode (trackedReq )
613+ assert .Check (t , err )
614+ assert .Assert (t , encoded != "" )
615+
616+ decoded , err := codec .Decode (encoded )
617+ assert .Check (t , err )
618+ assert .Equal (t , trackedReq .Index , decoded .Index )
619+ assert .Equal (t , trackedReq .SAMLRequestID , decoded .SAMLRequestID )
620+ })
621+
622+ t .Run ("RequireAccountFlow" , func (t * testing.T ) {
623+ handler := middleware .RequireAccount (
624+ http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {
625+ panic ("not reached" )
626+ }))
627+
628+ req , _ := http .NewRequest ("GET" , "/protected" , nil )
629+ resp := httptest .NewRecorder ()
630+ handler .ServeHTTP (resp , req )
631+
632+ assert .Check (t , is .Equal (http .StatusFound , resp .Code ))
633+ assert .Assert (t , resp .Header ().Get ("Location" ) != "" )
634+ assert .Assert (t , resp .Header ().Get ("Set-Cookie" ) != "" )
635+ })
636+
637+ t .Run ("Metadata" , func (t * testing.T ) {
638+ req , _ := http .NewRequest ("GET" , "/saml2/metadata" , nil )
639+ resp := httptest .NewRecorder ()
640+ middleware .ServeHTTP (resp , req )
641+
642+ assert .Check (t , is .Equal (http .StatusOK , resp .Code ))
643+ assert .Check (t , is .Equal ("application/samlmetadata+xml" ,
644+ resp .Header ().Get ("Content-type" )))
645+ golden .Assert (t , resp .Body .String (), "expected_middleware_metadata.xml" )
646+ })
647+ }
648+
649+ func TestJWTSessionCodec_ECDSACryptoSignerEncodeDecode (t * testing.T ) {
650+ now := time .Now ()
651+ saml .TimeNow = func () time.Time {
652+ return now
653+ }
654+
655+ ecKey , err := ecdsa .GenerateKey (elliptic .P256 (), rand .Reader )
656+ assert .Check (t , err )
657+
658+ signer := & mockSigner {signer : ecKey }
659+
660+ audience := "https://example.com/"
661+ codec := JWTSessionCodec {
662+ SigningMethod : jwt .SigningMethodES256 ,
663+ Audience : audience ,
664+ Issuer : audience ,
665+ MaxAge : time .Hour ,
666+ Key : signer ,
667+ }
668+
669+ tc := JWTSessionClaims {
670+ RegisteredClaims : jwt.RegisteredClaims {
671+ Audience : jwt.ClaimStrings {audience },
672+ Issuer : audience ,
673+ Subject : "test-ecdsa-subject" ,
674+ IssuedAt : jwt .NewNumericDate (now ),
675+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
676+ NotBefore : jwt .NewNumericDate (now ),
677+ },
678+ SAMLSession : true ,
679+ Attributes : map [string ][]string {"email" : {"test@example.com" }},
680+ }
681+
682+ encoded , err := codec .Encode (tc )
683+ assert .Check (t , err )
684+ assert .Assert (t , encoded != "" )
685+
686+ decoded , err := codec .Decode (encoded )
687+ assert .Check (t , err )
688+ decodedClaims := decoded .(JWTSessionClaims )
689+ assert .Equal (t , tc .Subject , decodedClaims .Subject )
690+ assert .Equal (t , "test@example.com" , decodedClaims .Attributes .Get ("email" ))
691+ }
692+
693+ func TestJWTSessionCodec_ECDSACryptoSignerP384 (t * testing.T ) {
694+ now := time .Now ()
695+ saml .TimeNow = func () time.Time {
696+ return now
697+ }
698+
699+ ecKey , err := ecdsa .GenerateKey (elliptic .P384 (), rand .Reader )
700+ assert .Check (t , err )
701+
702+ signer := & mockSigner {signer : ecKey }
703+
704+ audience := "https://example.com/"
705+ codec := JWTSessionCodec {
706+ SigningMethod : jwt .SigningMethodES384 ,
707+ Audience : audience ,
708+ Issuer : audience ,
709+ MaxAge : time .Hour ,
710+ Key : signer ,
711+ }
712+
713+ tc := JWTSessionClaims {
714+ RegisteredClaims : jwt.RegisteredClaims {
715+ Audience : jwt.ClaimStrings {audience },
716+ Issuer : audience ,
717+ Subject : "test-p384" ,
718+ IssuedAt : jwt .NewNumericDate (now ),
719+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
720+ NotBefore : jwt .NewNumericDate (now ),
721+ },
722+ SAMLSession : true ,
723+ }
724+
725+ encoded , err := codec .Encode (tc )
726+ assert .Check (t , err )
727+ assert .Assert (t , encoded != "" )
728+
729+ decoded , err := codec .Decode (encoded )
730+ assert .Check (t , err )
731+ decodedClaims := decoded .(JWTSessionClaims )
732+ assert .Equal (t , "test-p384" , decodedClaims .Subject )
733+ }
734+
735+ func TestJWTSessionCodec_RSAPSSCryptoSigner (t * testing.T ) {
736+ now := time .Now ()
737+ saml .TimeNow = func () time.Time {
738+ return now
739+ }
740+
741+ rsaKey , err := rsa .GenerateKey (rand .Reader , 2048 )
742+ assert .Check (t , err )
743+
744+ signer := & mockSigner {signer : rsaKey }
745+
746+ audience := "https://example.com/"
747+ codec := JWTSessionCodec {
748+ SigningMethod : jwt .SigningMethodPS256 ,
749+ Audience : audience ,
750+ Issuer : audience ,
751+ MaxAge : time .Hour ,
752+ Key : signer ,
753+ }
754+
755+ tc := JWTSessionClaims {
756+ RegisteredClaims : jwt.RegisteredClaims {
757+ Audience : jwt.ClaimStrings {audience },
758+ Issuer : audience ,
759+ Subject : "test-pss" ,
760+ IssuedAt : jwt .NewNumericDate (now ),
761+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
762+ NotBefore : jwt .NewNumericDate (now ),
763+ },
764+ SAMLSession : true ,
765+ }
766+
767+ encoded , err := codec .Encode (tc )
768+ assert .Check (t , err )
769+ assert .Assert (t , encoded != "" )
770+
771+ decoded , err := codec .Decode (encoded )
772+ assert .Check (t , err )
773+ decodedClaims := decoded .(JWTSessionClaims )
774+ assert .Equal (t , "test-pss" , decodedClaims .Subject )
775+ }
776+
777+ func TestJWTSessionCodec_ECDSACryptoSignerP521 (t * testing.T ) {
778+ now := time .Now ()
779+ saml .TimeNow = func () time.Time {
780+ return now
781+ }
782+
783+ ecKey , err := ecdsa .GenerateKey (elliptic .P521 (), rand .Reader )
784+ assert .Check (t , err )
785+
786+ signer := & mockSigner {signer : ecKey }
787+
788+ audience := "https://example.com/"
789+ codec := JWTSessionCodec {
790+ SigningMethod : jwt .SigningMethodES512 ,
791+ Audience : audience ,
792+ Issuer : audience ,
793+ MaxAge : time .Hour ,
794+ Key : signer ,
795+ }
796+
797+ tc := JWTSessionClaims {
798+ RegisteredClaims : jwt.RegisteredClaims {
799+ Audience : jwt.ClaimStrings {audience },
800+ Issuer : audience ,
801+ Subject : "test-p521" ,
802+ IssuedAt : jwt .NewNumericDate (now ),
803+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
804+ NotBefore : jwt .NewNumericDate (now ),
805+ },
806+ SAMLSession : true ,
807+ }
808+
809+ encoded , err := codec .Encode (tc )
810+ assert .Check (t , err )
811+ assert .Assert (t , encoded != "" )
812+
813+ decoded , err := codec .Decode (encoded )
814+ assert .Check (t , err )
815+ decodedClaims := decoded .(JWTSessionClaims )
816+ assert .Equal (t , "test-p521" , decodedClaims .Subject )
817+ }
818+
819+ func TestJWTSessionCodec_EdDSACryptoSignerEncodeDecode (t * testing.T ) {
820+ now := time .Now ()
821+ saml .TimeNow = func () time.Time {
822+ return now
823+ }
824+
825+ _ , edKey , err := ed25519 .GenerateKey (rand .Reader )
826+ assert .Check (t , err )
827+
828+ signer := & mockSigner {signer : edKey }
829+
830+ audience := "https://example.com/"
831+ codec := JWTSessionCodec {
832+ SigningMethod : jwt .SigningMethodEdDSA ,
833+ Audience : audience ,
834+ Issuer : audience ,
835+ MaxAge : time .Hour ,
836+ Key : signer ,
837+ }
838+
839+ tc := JWTSessionClaims {
840+ RegisteredClaims : jwt.RegisteredClaims {
841+ Audience : jwt.ClaimStrings {audience },
842+ Issuer : audience ,
843+ Subject : "test-eddsa" ,
844+ IssuedAt : jwt .NewNumericDate (now ),
845+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
846+ NotBefore : jwt .NewNumericDate (now ),
847+ },
848+ SAMLSession : true ,
849+ }
850+
851+ encoded , err := codec .Encode (tc )
852+ assert .Check (t , err )
853+ assert .Assert (t , encoded != "" )
854+
855+ decoded , err := codec .Decode (encoded )
856+ assert .Check (t , err )
857+ decodedClaims := decoded .(JWTSessionClaims )
858+ assert .Equal (t , "test-eddsa" , decodedClaims .Subject )
859+ }
860+
861+ func TestJWTSessionCodec_UnsupportedAlgorithmReturnsError (t * testing.T ) {
862+ now := time .Now ()
863+ saml .TimeNow = func () time.Time {
864+ return now
865+ }
866+
867+ rsaKey , err := rsa .GenerateKey (rand .Reader , 2048 )
868+ assert .Check (t , err )
869+
870+ signer := & mockSigner {signer : rsaKey }
871+
872+ audience := "https://example.com/"
873+ codec := JWTSessionCodec {
874+ SigningMethod : jwt .SigningMethodNone ,
875+ Audience : audience ,
876+ Issuer : audience ,
877+ MaxAge : time .Hour ,
878+ Key : signer ,
879+ }
880+
881+ tc := JWTSessionClaims {
882+ RegisteredClaims : jwt.RegisteredClaims {
883+ Audience : jwt.ClaimStrings {audience },
884+ Issuer : audience ,
885+ Subject : "test" ,
886+ IssuedAt : jwt .NewNumericDate (now ),
887+ ExpiresAt : jwt .NewNumericDate (now .Add (time .Hour )),
888+ NotBefore : jwt .NewNumericDate (now ),
889+ },
890+ SAMLSession : true ,
891+ }
892+
893+ _ , err = codec .Encode (tc )
894+ assert .Check (t , is .ErrorContains (err , "unsupported signing algorithm for crypto.Signer" ))
895+ }
0 commit comments