Skip to content

Commit d1718fa

Browse files
Add crypto.Signer support for KMS/HSM
Check public key type instead of private key type to support crypto.Signer implementations (e.g. GCP KMS, AWS KMS, HSM) that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types. Supports RSA (RS256/RS384/RS512), RSA-PSS (PS256/PS384/PS512), ECDSA (ES256/ES384/ES512), and EdDSA signing methods via crypto.Signer for JWT session and tracked request signing.
1 parent 06ae334 commit d1718fa

File tree

5 files changed

+568
-18
lines changed

5 files changed

+568
-18
lines changed

samlsp/middleware_test.go

Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ package samlsp
22

33
import (
44
"bytes"
5+
"crypto"
6+
"crypto/ecdsa"
7+
"crypto/ed25519"
8+
"crypto/elliptic"
9+
"crypto/rand"
510
"crypto/rsa"
611
"crypto/sha256"
712
"crypto/x509"
@@ -17,6 +22,7 @@ import (
1722
"testing"
1823
"time"
1924

25+
"github.com/golang-jwt/jwt/v5"
2026
dsig "github.com/russellhaering/goxmldsig"
2127
"gotest.tools/assert"
2228
is "gotest.tools/assert/cmp"
@@ -520,3 +526,370 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
520526
assert.Check(t, is.Equal("", resp.Header().Get("Location")))
521527
assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
522528
}
529+
530+
type mockSigner struct {
531+
signer crypto.Signer
532+
}
533+
534+
func (m *mockSigner) Public() crypto.PublicKey {
535+
return m.signer.Public()
536+
}
537+
538+
func (m *mockSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
539+
return m.signer.Sign(rand, digest, opts)
540+
}
541+
542+
func newMockRSASigner(t *testing.T) crypto.Signer {
543+
key := mustParsePrivateKey(golden.Get(t, "key.pem"))
544+
return &mockSigner{signer: key.(crypto.Signer)}
545+
}
546+
547+
func TestMiddleware_WithCryptoSignerE2E(t *testing.T) {
548+
saml.TimeNow = func() time.Time {
549+
rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015")
550+
return rv
551+
}
552+
saml.Clock = dsig.NewFakeClockAt(saml.TimeNow())
553+
saml.RandReader = &testRandomReader{}
554+
555+
cert := mustParseCertificate(golden.Get(t, "cert.pem"))
556+
idpMetadata := golden.Get(t, "idp_metadata.xml")
557+
558+
var metadata saml.EntityDescriptor
559+
if err := xml.Unmarshal(idpMetadata, &metadata); err != nil {
560+
panic(err)
561+
}
562+
563+
mockSigner := newMockRSASigner(t)
564+
565+
opts := Options{
566+
URL: mustParseURL("https://15661444.ngrok.io/"),
567+
Key: mockSigner,
568+
Certificate: cert,
569+
IDPMetadata: &metadata,
570+
}
571+
572+
middleware, err := New(opts)
573+
assert.Check(t, err)
574+
575+
sessionProvider := DefaultSessionProvider(opts)
576+
sessionProvider.Name = "ttt"
577+
sessionProvider.MaxAge = 7200 * time.Second
578+
579+
sessionCodec := sessionProvider.Codec.(JWTSessionCodec)
580+
sessionCodec.MaxAge = 7200 * time.Second
581+
sessionProvider.Codec = sessionCodec
582+
583+
middleware.Session = sessionProvider
584+
middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata"
585+
middleware.ServiceProvider.AcsURL.Path = "/saml2/acs"
586+
middleware.ServiceProvider.SloURL.Path = "/saml2/slo"
587+
588+
t.Run("SessionEncodeDecode", func(t *testing.T) {
589+
var tc JWTSessionClaims
590+
if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil {
591+
t.Fatal(err)
592+
}
593+
594+
encoded, err := sessionProvider.Codec.Encode(tc)
595+
assert.Check(t, err)
596+
assert.Assert(t, encoded != "")
597+
598+
decoded, err := sessionProvider.Codec.Decode(encoded)
599+
assert.Check(t, err)
600+
decodedClaims := decoded.(JWTSessionClaims)
601+
assert.Equal(t, tc.Subject, decodedClaims.Subject)
602+
})
603+
604+
t.Run("TrackedRequestEncodeDecode", func(t *testing.T) {
605+
codec := middleware.RequestTracker.(CookieRequestTracker).Codec
606+
trackedReq := TrackedRequest{
607+
Index: "test-index",
608+
SAMLRequestID: "test-request-id",
609+
URI: "/test-uri",
610+
}
611+
612+
encoded, err := codec.Encode(trackedReq)
613+
assert.Check(t, err)
614+
assert.Assert(t, encoded != "")
615+
616+
decoded, err := codec.Decode(encoded)
617+
assert.Check(t, err)
618+
assert.Equal(t, trackedReq.Index, decoded.Index)
619+
assert.Equal(t, trackedReq.SAMLRequestID, decoded.SAMLRequestID)
620+
})
621+
622+
t.Run("RequireAccountFlow", func(t *testing.T) {
623+
handler := middleware.RequireAccount(
624+
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
625+
panic("not reached")
626+
}))
627+
628+
req, _ := http.NewRequest("GET", "/protected", nil)
629+
resp := httptest.NewRecorder()
630+
handler.ServeHTTP(resp, req)
631+
632+
assert.Check(t, is.Equal(http.StatusFound, resp.Code))
633+
assert.Assert(t, resp.Header().Get("Location") != "")
634+
assert.Assert(t, resp.Header().Get("Set-Cookie") != "")
635+
})
636+
637+
t.Run("Metadata", func(t *testing.T) {
638+
req, _ := http.NewRequest("GET", "/saml2/metadata", nil)
639+
resp := httptest.NewRecorder()
640+
middleware.ServeHTTP(resp, req)
641+
642+
assert.Check(t, is.Equal(http.StatusOK, resp.Code))
643+
assert.Check(t, is.Equal("application/samlmetadata+xml",
644+
resp.Header().Get("Content-type")))
645+
golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml")
646+
})
647+
}
648+
649+
func TestJWTSessionCodec_ECDSACryptoSignerEncodeDecode(t *testing.T) {
650+
now := time.Now()
651+
saml.TimeNow = func() time.Time {
652+
return now
653+
}
654+
655+
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
656+
assert.Check(t, err)
657+
658+
signer := &mockSigner{signer: ecKey}
659+
660+
audience := "https://example.com/"
661+
codec := JWTSessionCodec{
662+
SigningMethod: jwt.SigningMethodES256,
663+
Audience: audience,
664+
Issuer: audience,
665+
MaxAge: time.Hour,
666+
Key: signer,
667+
}
668+
669+
tc := JWTSessionClaims{
670+
RegisteredClaims: jwt.RegisteredClaims{
671+
Audience: jwt.ClaimStrings{audience},
672+
Issuer: audience,
673+
Subject: "test-ecdsa-subject",
674+
IssuedAt: jwt.NewNumericDate(now),
675+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
676+
NotBefore: jwt.NewNumericDate(now),
677+
},
678+
SAMLSession: true,
679+
Attributes: map[string][]string{"email": {"test@example.com"}},
680+
}
681+
682+
encoded, err := codec.Encode(tc)
683+
assert.Check(t, err)
684+
assert.Assert(t, encoded != "")
685+
686+
decoded, err := codec.Decode(encoded)
687+
assert.Check(t, err)
688+
decodedClaims := decoded.(JWTSessionClaims)
689+
assert.Equal(t, tc.Subject, decodedClaims.Subject)
690+
assert.Equal(t, "test@example.com", decodedClaims.Attributes.Get("email"))
691+
}
692+
693+
func TestJWTSessionCodec_ECDSACryptoSignerP384(t *testing.T) {
694+
now := time.Now()
695+
saml.TimeNow = func() time.Time {
696+
return now
697+
}
698+
699+
ecKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
700+
assert.Check(t, err)
701+
702+
signer := &mockSigner{signer: ecKey}
703+
704+
audience := "https://example.com/"
705+
codec := JWTSessionCodec{
706+
SigningMethod: jwt.SigningMethodES384,
707+
Audience: audience,
708+
Issuer: audience,
709+
MaxAge: time.Hour,
710+
Key: signer,
711+
}
712+
713+
tc := JWTSessionClaims{
714+
RegisteredClaims: jwt.RegisteredClaims{
715+
Audience: jwt.ClaimStrings{audience},
716+
Issuer: audience,
717+
Subject: "test-p384",
718+
IssuedAt: jwt.NewNumericDate(now),
719+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
720+
NotBefore: jwt.NewNumericDate(now),
721+
},
722+
SAMLSession: true,
723+
}
724+
725+
encoded, err := codec.Encode(tc)
726+
assert.Check(t, err)
727+
assert.Assert(t, encoded != "")
728+
729+
decoded, err := codec.Decode(encoded)
730+
assert.Check(t, err)
731+
decodedClaims := decoded.(JWTSessionClaims)
732+
assert.Equal(t, "test-p384", decodedClaims.Subject)
733+
}
734+
735+
func TestJWTSessionCodec_RSAPSSCryptoSigner(t *testing.T) {
736+
now := time.Now()
737+
saml.TimeNow = func() time.Time {
738+
return now
739+
}
740+
741+
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
742+
assert.Check(t, err)
743+
744+
signer := &mockSigner{signer: rsaKey}
745+
746+
audience := "https://example.com/"
747+
codec := JWTSessionCodec{
748+
SigningMethod: jwt.SigningMethodPS256,
749+
Audience: audience,
750+
Issuer: audience,
751+
MaxAge: time.Hour,
752+
Key: signer,
753+
}
754+
755+
tc := JWTSessionClaims{
756+
RegisteredClaims: jwt.RegisteredClaims{
757+
Audience: jwt.ClaimStrings{audience},
758+
Issuer: audience,
759+
Subject: "test-pss",
760+
IssuedAt: jwt.NewNumericDate(now),
761+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
762+
NotBefore: jwt.NewNumericDate(now),
763+
},
764+
SAMLSession: true,
765+
}
766+
767+
encoded, err := codec.Encode(tc)
768+
assert.Check(t, err)
769+
assert.Assert(t, encoded != "")
770+
771+
decoded, err := codec.Decode(encoded)
772+
assert.Check(t, err)
773+
decodedClaims := decoded.(JWTSessionClaims)
774+
assert.Equal(t, "test-pss", decodedClaims.Subject)
775+
}
776+
777+
func TestJWTSessionCodec_ECDSACryptoSignerP521(t *testing.T) {
778+
now := time.Now()
779+
saml.TimeNow = func() time.Time {
780+
return now
781+
}
782+
783+
ecKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
784+
assert.Check(t, err)
785+
786+
signer := &mockSigner{signer: ecKey}
787+
788+
audience := "https://example.com/"
789+
codec := JWTSessionCodec{
790+
SigningMethod: jwt.SigningMethodES512,
791+
Audience: audience,
792+
Issuer: audience,
793+
MaxAge: time.Hour,
794+
Key: signer,
795+
}
796+
797+
tc := JWTSessionClaims{
798+
RegisteredClaims: jwt.RegisteredClaims{
799+
Audience: jwt.ClaimStrings{audience},
800+
Issuer: audience,
801+
Subject: "test-p521",
802+
IssuedAt: jwt.NewNumericDate(now),
803+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
804+
NotBefore: jwt.NewNumericDate(now),
805+
},
806+
SAMLSession: true,
807+
}
808+
809+
encoded, err := codec.Encode(tc)
810+
assert.Check(t, err)
811+
assert.Assert(t, encoded != "")
812+
813+
decoded, err := codec.Decode(encoded)
814+
assert.Check(t, err)
815+
decodedClaims := decoded.(JWTSessionClaims)
816+
assert.Equal(t, "test-p521", decodedClaims.Subject)
817+
}
818+
819+
func TestJWTSessionCodec_EdDSACryptoSignerEncodeDecode(t *testing.T) {
820+
now := time.Now()
821+
saml.TimeNow = func() time.Time {
822+
return now
823+
}
824+
825+
_, edKey, err := ed25519.GenerateKey(rand.Reader)
826+
assert.Check(t, err)
827+
828+
signer := &mockSigner{signer: edKey}
829+
830+
audience := "https://example.com/"
831+
codec := JWTSessionCodec{
832+
SigningMethod: jwt.SigningMethodEdDSA,
833+
Audience: audience,
834+
Issuer: audience,
835+
MaxAge: time.Hour,
836+
Key: signer,
837+
}
838+
839+
tc := JWTSessionClaims{
840+
RegisteredClaims: jwt.RegisteredClaims{
841+
Audience: jwt.ClaimStrings{audience},
842+
Issuer: audience,
843+
Subject: "test-eddsa",
844+
IssuedAt: jwt.NewNumericDate(now),
845+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
846+
NotBefore: jwt.NewNumericDate(now),
847+
},
848+
SAMLSession: true,
849+
}
850+
851+
encoded, err := codec.Encode(tc)
852+
assert.Check(t, err)
853+
assert.Assert(t, encoded != "")
854+
855+
decoded, err := codec.Decode(encoded)
856+
assert.Check(t, err)
857+
decodedClaims := decoded.(JWTSessionClaims)
858+
assert.Equal(t, "test-eddsa", decodedClaims.Subject)
859+
}
860+
861+
func TestJWTSessionCodec_UnsupportedAlgorithmReturnsError(t *testing.T) {
862+
now := time.Now()
863+
saml.TimeNow = func() time.Time {
864+
return now
865+
}
866+
867+
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
868+
assert.Check(t, err)
869+
870+
signer := &mockSigner{signer: rsaKey}
871+
872+
audience := "https://example.com/"
873+
codec := JWTSessionCodec{
874+
SigningMethod: jwt.SigningMethodNone,
875+
Audience: audience,
876+
Issuer: audience,
877+
MaxAge: time.Hour,
878+
Key: signer,
879+
}
880+
881+
tc := JWTSessionClaims{
882+
RegisteredClaims: jwt.RegisteredClaims{
883+
Audience: jwt.ClaimStrings{audience},
884+
Issuer: audience,
885+
Subject: "test",
886+
IssuedAt: jwt.NewNumericDate(now),
887+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
888+
NotBefore: jwt.NewNumericDate(now),
889+
},
890+
SAMLSession: true,
891+
}
892+
893+
_, err = codec.Encode(tc)
894+
assert.Check(t, is.ErrorContains(err, "unsupported signing algorithm for crypto.Signer"))
895+
}

0 commit comments

Comments
 (0)