Skip to content

Commit 91213ee

Browse files
Add crypto.Signer support for KMS/HSM
Check public key type instead of private key type to support crypto.Signer implementations (e.g. GCP KMS, AWS KMS, HSM) that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types. Supports RSA (RS256/RS384/RS512), RSA-PSS (PS256/PS384/PS512), ECDSA (ES256/ES384/ES512), and EdDSA signing methods via crypto.Signer for JWT session and tracked request signing.
1 parent 06ae334 commit 91213ee

File tree

5 files changed

+467
-18
lines changed

5 files changed

+467
-18
lines changed

samlsp/middleware_test.go

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ package samlsp
22

33
import (
44
"bytes"
5+
"crypto"
6+
"crypto/ecdsa"
7+
"crypto/ed25519"
8+
"crypto/elliptic"
9+
"crypto/rand"
510
"crypto/rsa"
611
"crypto/sha256"
712
"crypto/x509"
@@ -17,6 +22,7 @@ import (
1722
"testing"
1823
"time"
1924

25+
"github.com/golang-jwt/jwt/v5"
2026
dsig "github.com/russellhaering/goxmldsig"
2127
"gotest.tools/assert"
2228
is "gotest.tools/assert/cmp"
@@ -520,3 +526,268 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
520526
assert.Check(t, is.Equal("", resp.Header().Get("Location")))
521527
assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
522528
}
529+
530+
type mockSigner struct {
531+
signer crypto.Signer
532+
}
533+
534+
func (m *mockSigner) Public() crypto.PublicKey {
535+
return m.signer.Public()
536+
}
537+
538+
func (m *mockSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
539+
return m.signer.Sign(rand, digest, opts)
540+
}
541+
542+
func newMockRSASigner(t *testing.T) crypto.Signer {
543+
key := mustParsePrivateKey(golden.Get(t, "key.pem"))
544+
return &mockSigner{signer: key.(crypto.Signer)}
545+
}
546+
547+
func TestMiddleware_WithCryptoSignerE2E(t *testing.T) {
548+
origTimeNow := saml.TimeNow
549+
origClock := saml.Clock
550+
origRandReader := saml.RandReader
551+
t.Cleanup(func() {
552+
saml.TimeNow = origTimeNow
553+
saml.Clock = origClock
554+
saml.RandReader = origRandReader
555+
})
556+
557+
saml.TimeNow = func() time.Time {
558+
rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015")
559+
return rv
560+
}
561+
saml.Clock = dsig.NewFakeClockAt(saml.TimeNow())
562+
saml.RandReader = &testRandomReader{}
563+
564+
cert := mustParseCertificate(golden.Get(t, "cert.pem"))
565+
idpMetadata := golden.Get(t, "idp_metadata.xml")
566+
567+
var metadata saml.EntityDescriptor
568+
if err := xml.Unmarshal(idpMetadata, &metadata); err != nil {
569+
panic(err)
570+
}
571+
572+
mockSigner := newMockRSASigner(t)
573+
574+
opts := Options{
575+
URL: mustParseURL("https://15661444.ngrok.io/"),
576+
Key: mockSigner,
577+
Certificate: cert,
578+
IDPMetadata: &metadata,
579+
}
580+
581+
middleware, err := New(opts)
582+
assert.Check(t, err)
583+
584+
sessionProvider := DefaultSessionProvider(opts)
585+
sessionProvider.Name = "ttt"
586+
sessionProvider.MaxAge = 7200 * time.Second
587+
588+
sessionCodec := sessionProvider.Codec.(JWTSessionCodec)
589+
sessionCodec.MaxAge = 7200 * time.Second
590+
sessionProvider.Codec = sessionCodec
591+
592+
middleware.Session = sessionProvider
593+
middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata"
594+
middleware.ServiceProvider.AcsURL.Path = "/saml2/acs"
595+
middleware.ServiceProvider.SloURL.Path = "/saml2/slo"
596+
597+
t.Run("SessionEncodeDecode", func(t *testing.T) {
598+
var tc JWTSessionClaims
599+
if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil {
600+
t.Fatal(err)
601+
}
602+
603+
encoded, err := sessionProvider.Codec.Encode(tc)
604+
assert.Check(t, err)
605+
assert.Assert(t, encoded != "")
606+
607+
decoded, err := sessionProvider.Codec.Decode(encoded)
608+
assert.Check(t, err)
609+
decodedClaims := decoded.(JWTSessionClaims)
610+
assert.Equal(t, tc.Subject, decodedClaims.Subject)
611+
})
612+
613+
t.Run("TrackedRequestEncodeDecode", func(t *testing.T) {
614+
codec := middleware.RequestTracker.(CookieRequestTracker).Codec
615+
trackedReq := TrackedRequest{
616+
Index: "test-index",
617+
SAMLRequestID: "test-request-id",
618+
URI: "/test-uri",
619+
}
620+
621+
encoded, err := codec.Encode(trackedReq)
622+
assert.Check(t, err)
623+
assert.Assert(t, encoded != "")
624+
625+
decoded, err := codec.Decode(encoded)
626+
assert.Check(t, err)
627+
assert.Equal(t, trackedReq.Index, decoded.Index)
628+
assert.Equal(t, trackedReq.SAMLRequestID, decoded.SAMLRequestID)
629+
})
630+
631+
t.Run("RequireAccountFlow", func(t *testing.T) {
632+
handler := middleware.RequireAccount(
633+
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
634+
panic("not reached")
635+
}))
636+
637+
req, _ := http.NewRequest("GET", "/protected", nil)
638+
resp := httptest.NewRecorder()
639+
handler.ServeHTTP(resp, req)
640+
641+
assert.Check(t, is.Equal(http.StatusFound, resp.Code))
642+
assert.Assert(t, resp.Header().Get("Location") != "")
643+
assert.Assert(t, resp.Header().Get("Set-Cookie") != "")
644+
})
645+
646+
t.Run("Metadata", func(t *testing.T) {
647+
req, _ := http.NewRequest("GET", "/saml2/metadata", nil)
648+
resp := httptest.NewRecorder()
649+
middleware.ServeHTTP(resp, req)
650+
651+
assert.Check(t, is.Equal(http.StatusOK, resp.Code))
652+
assert.Check(t, is.Equal("application/samlmetadata+xml",
653+
resp.Header().Get("Content-type")))
654+
golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml")
655+
})
656+
}
657+
658+
func TestJWTSessionCodec_CryptoSignerEncodeDecode(t *testing.T) {
659+
tests := []struct {
660+
name string
661+
method jwt.SigningMethod
662+
genKey func(t *testing.T) crypto.Signer
663+
subject string
664+
}{
665+
{
666+
name: "ECDSA-P256",
667+
method: jwt.SigningMethodES256,
668+
genKey: func(t *testing.T) crypto.Signer {
669+
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
670+
assert.Check(t, err)
671+
return k
672+
},
673+
subject: "test-ecdsa-p256",
674+
},
675+
{
676+
name: "ECDSA-P384",
677+
method: jwt.SigningMethodES384,
678+
genKey: func(t *testing.T) crypto.Signer {
679+
k, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
680+
assert.Check(t, err)
681+
return k
682+
},
683+
subject: "test-ecdsa-p384",
684+
},
685+
{
686+
name: "ECDSA-P521",
687+
method: jwt.SigningMethodES512,
688+
genKey: func(t *testing.T) crypto.Signer {
689+
k, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
690+
assert.Check(t, err)
691+
return k
692+
},
693+
subject: "test-ecdsa-p521",
694+
},
695+
{
696+
name: "RSA-PSS",
697+
method: jwt.SigningMethodPS256,
698+
genKey: func(t *testing.T) crypto.Signer {
699+
k, err := rsa.GenerateKey(rand.Reader, 2048)
700+
assert.Check(t, err)
701+
return k
702+
},
703+
subject: "test-rsa-pss",
704+
},
705+
{
706+
name: "EdDSA",
707+
method: jwt.SigningMethodEdDSA,
708+
genKey: func(t *testing.T) crypto.Signer {
709+
_, k, err := ed25519.GenerateKey(rand.Reader)
710+
assert.Check(t, err)
711+
return k
712+
},
713+
subject: "test-eddsa",
714+
},
715+
}
716+
717+
for _, tt := range tests {
718+
t.Run(tt.name, func(t *testing.T) {
719+
now := time.Now()
720+
origTimeNow := saml.TimeNow
721+
t.Cleanup(func() { saml.TimeNow = origTimeNow })
722+
saml.TimeNow = func() time.Time { return now }
723+
724+
signer := &mockSigner{signer: tt.genKey(t)}
725+
726+
audience := "https://example.com/"
727+
codec := JWTSessionCodec{
728+
SigningMethod: tt.method,
729+
Audience: audience,
730+
Issuer: audience,
731+
MaxAge: time.Hour,
732+
Key: signer,
733+
}
734+
735+
tc := JWTSessionClaims{
736+
RegisteredClaims: jwt.RegisteredClaims{
737+
Audience: jwt.ClaimStrings{audience},
738+
Issuer: audience,
739+
Subject: tt.subject,
740+
IssuedAt: jwt.NewNumericDate(now),
741+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
742+
NotBefore: jwt.NewNumericDate(now),
743+
},
744+
SAMLSession: true,
745+
}
746+
747+
encoded, err := codec.Encode(tc)
748+
assert.Check(t, err)
749+
assert.Assert(t, encoded != "")
750+
751+
decoded, err := codec.Decode(encoded)
752+
assert.Check(t, err)
753+
decodedClaims := decoded.(JWTSessionClaims)
754+
assert.Equal(t, tt.subject, decodedClaims.Subject)
755+
})
756+
}
757+
}
758+
759+
func TestJWTSessionCodec_UnsupportedAlgorithmReturnsError(t *testing.T) {
760+
now := time.Now()
761+
origTimeNow := saml.TimeNow
762+
t.Cleanup(func() { saml.TimeNow = origTimeNow })
763+
saml.TimeNow = func() time.Time { return now }
764+
765+
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
766+
assert.Check(t, err)
767+
768+
signer := &mockSigner{signer: rsaKey}
769+
770+
audience := "https://example.com/"
771+
codec := JWTSessionCodec{
772+
SigningMethod: jwt.SigningMethodNone,
773+
Audience: audience,
774+
Issuer: audience,
775+
MaxAge: time.Hour,
776+
Key: signer,
777+
}
778+
779+
tc := JWTSessionClaims{
780+
RegisteredClaims: jwt.RegisteredClaims{
781+
Audience: jwt.ClaimStrings{audience},
782+
Issuer: audience,
783+
Subject: "test",
784+
IssuedAt: jwt.NewNumericDate(now),
785+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
786+
NotBefore: jwt.NewNumericDate(now),
787+
},
788+
SAMLSession: true,
789+
}
790+
791+
_, err = codec.Encode(tc)
792+
assert.Check(t, is.ErrorContains(err, "unsupported algorithm for crypto.Signer"))
793+
}

samlsp/new.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type Options struct {
3838
}
3939

4040
func getDefaultSigningMethod(signer crypto.Signer) jwt.SigningMethod {
41-
if signer != nil {
41+
if !saml.IsSignerNil(signer) {
4242
switch signer.Public().(type) {
4343
case *ecdsa.PublicKey:
4444
return jwt.SigningMethodES256
@@ -149,15 +149,18 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
149149
}
150150

151151
func defaultSigningMethodForKey(key crypto.Signer) string {
152-
switch key.(type) {
153-
case *rsa.PrivateKey:
152+
if saml.IsSignerNil(key) {
153+
return ""
154+
}
155+
// Check public key type to support crypto.Signer implementations (KMS/HSM)
156+
// that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types
157+
switch key.Public().(type) {
158+
case *rsa.PublicKey:
154159
return dsig.RSASHA1SignatureMethod
155-
case *ecdsa.PrivateKey:
160+
case *ecdsa.PublicKey:
156161
return dsig.ECDSASHA256SignatureMethod
157-
case nil:
158-
return ""
159162
default:
160-
panic(fmt.Sprintf("programming error: unsupported key type %T", key))
163+
panic(fmt.Sprintf("programming error: unsupported public key type %T", key.Public()))
161164
}
162165
}
163166

samlsp/request_tracker_jwt.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,14 @@ func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) {
4444
SAMLAuthnRequest: true,
4545
}
4646
token := jwt.NewWithClaims(s.SigningMethod, claims)
47-
return token.SignedString(s.Key)
47+
return signToken(token, s.Key, s.SigningMethod)
4848
}
4949

5050
// Decode returns a Tracked request from an encoded string.
5151
func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) {
52+
if saml.IsSignerNil(s.Key) {
53+
return nil, fmt.Errorf("decoding key is nil")
54+
}
5255
parser := jwt.NewParser(
5356
jwt.WithValidMethods([]string{s.SigningMethod.Alg()}),
5457
jwt.WithTimeFunc(saml.TimeNow),

0 commit comments

Comments
 (0)