Skip to content

Commit e89fe6c

Browse files
Add RSA crypto.Signer support for KMS/HSM
Check public key type instead of private key type to support crypto.Signer implementations (e.g. GCP KMS, AWS KMS, HSM) that aren't concrete *rsa.PrivateKey types. Only RSA keys are supported for crypto.Signer since major SAML IdPs (Azure AD, Auth0, Okta) use RSA signing. Non-RSA crypto.Signer keys return a clear error.
1 parent 744174f commit e89fe6c

File tree

5 files changed

+253
-15
lines changed

5 files changed

+253
-15
lines changed

samlsp/middleware_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package samlsp
22

33
import (
44
"bytes"
5+
"crypto"
6+
"crypto/ecdsa"
7+
"crypto/elliptic"
8+
"crypto/rand"
59
"crypto/rsa"
610
"crypto/sha256"
711
"crypto/x509"
@@ -17,6 +21,7 @@ import (
1721
"testing"
1822
"time"
1923

24+
"github.com/golang-jwt/jwt/v5"
2025
dsig "github.com/russellhaering/goxmldsig"
2126
"gotest.tools/assert"
2227
is "gotest.tools/assert/cmp"
@@ -520,3 +525,158 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
520525
assert.Check(t, is.Equal("", resp.Header().Get("Location")))
521526
assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
522527
}
528+
529+
type mockSigner struct {
530+
signer crypto.Signer
531+
}
532+
533+
func (m *mockSigner) Public() crypto.PublicKey {
534+
return m.signer.Public()
535+
}
536+
537+
func (m *mockSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
538+
return m.signer.Sign(rand, digest, opts)
539+
}
540+
541+
func newMockRSASigner(t *testing.T) crypto.Signer {
542+
key := mustParsePrivateKey(golden.Get(t, "key.pem"))
543+
return &mockSigner{signer: key.(crypto.Signer)}
544+
}
545+
546+
func TestMiddleware_WithCryptoSignerE2E(t *testing.T) {
547+
saml.TimeNow = func() time.Time {
548+
rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015")
549+
return rv
550+
}
551+
saml.Clock = dsig.NewFakeClockAt(saml.TimeNow())
552+
saml.RandReader = &testRandomReader{}
553+
554+
cert := mustParseCertificate(golden.Get(t, "cert.pem"))
555+
idpMetadata := golden.Get(t, "idp_metadata.xml")
556+
557+
var metadata saml.EntityDescriptor
558+
if err := xml.Unmarshal(idpMetadata, &metadata); err != nil {
559+
panic(err)
560+
}
561+
562+
mockSigner := newMockRSASigner(t)
563+
564+
opts := Options{
565+
URL: mustParseURL("https://15661444.ngrok.io/"),
566+
Key: mockSigner,
567+
Certificate: cert,
568+
IDPMetadata: &metadata,
569+
}
570+
571+
middleware, err := New(opts)
572+
assert.Check(t, err)
573+
574+
sessionProvider := DefaultSessionProvider(opts)
575+
sessionProvider.Name = "ttt"
576+
sessionProvider.MaxAge = 7200 * time.Second
577+
578+
sessionCodec := sessionProvider.Codec.(JWTSessionCodec)
579+
sessionCodec.MaxAge = 7200 * time.Second
580+
sessionProvider.Codec = sessionCodec
581+
582+
middleware.Session = sessionProvider
583+
middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata"
584+
middleware.ServiceProvider.AcsURL.Path = "/saml2/acs"
585+
middleware.ServiceProvider.SloURL.Path = "/saml2/slo"
586+
587+
t.Run("SessionEncodeDecode", func(t *testing.T) {
588+
var tc JWTSessionClaims
589+
if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil {
590+
t.Fatal(err)
591+
}
592+
593+
encoded, err := sessionProvider.Codec.Encode(tc)
594+
assert.Check(t, err)
595+
assert.Assert(t, encoded != "")
596+
597+
decoded, err := sessionProvider.Codec.Decode(encoded)
598+
assert.Check(t, err)
599+
decodedClaims := decoded.(JWTSessionClaims)
600+
assert.Equal(t, tc.Subject, decodedClaims.Subject)
601+
})
602+
603+
t.Run("TrackedRequestEncodeDecode", func(t *testing.T) {
604+
codec := middleware.RequestTracker.(CookieRequestTracker).Codec
605+
trackedReq := TrackedRequest{
606+
Index: "test-index",
607+
SAMLRequestID: "test-request-id",
608+
URI: "/test-uri",
609+
}
610+
611+
encoded, err := codec.Encode(trackedReq)
612+
assert.Check(t, err)
613+
assert.Assert(t, encoded != "")
614+
615+
decoded, err := codec.Decode(encoded)
616+
assert.Check(t, err)
617+
assert.Equal(t, trackedReq.Index, decoded.Index)
618+
assert.Equal(t, trackedReq.SAMLRequestID, decoded.SAMLRequestID)
619+
})
620+
621+
t.Run("RequireAccountFlow", func(t *testing.T) {
622+
handler := middleware.RequireAccount(
623+
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
624+
panic("not reached")
625+
}))
626+
627+
req, _ := http.NewRequest("GET", "/protected", nil)
628+
resp := httptest.NewRecorder()
629+
handler.ServeHTTP(resp, req)
630+
631+
assert.Check(t, is.Equal(http.StatusFound, resp.Code))
632+
assert.Assert(t, resp.Header().Get("Location") != "")
633+
assert.Assert(t, resp.Header().Get("Set-Cookie") != "")
634+
})
635+
636+
t.Run("Metadata", func(t *testing.T) {
637+
req, _ := http.NewRequest("GET", "/saml2/metadata", nil)
638+
resp := httptest.NewRecorder()
639+
middleware.ServeHTTP(resp, req)
640+
641+
assert.Check(t, is.Equal(http.StatusOK, resp.Code))
642+
assert.Check(t, is.Equal("application/samlmetadata+xml",
643+
resp.Header().Get("Content-type")))
644+
golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml")
645+
})
646+
}
647+
648+
func TestJWTSessionCodec_NonRSACryptoSignerReturnsError(t *testing.T) {
649+
now := time.Now()
650+
saml.TimeNow = func() time.Time {
651+
return now
652+
}
653+
654+
ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
655+
assert.Check(t, err)
656+
657+
signer := &mockSigner{signer: ecKey}
658+
659+
audience := "https://example.com/"
660+
codec := JWTSessionCodec{
661+
SigningMethod: jwt.SigningMethodES256,
662+
Audience: audience,
663+
Issuer: audience,
664+
MaxAge: time.Hour,
665+
Key: signer,
666+
}
667+
668+
tc := JWTSessionClaims{
669+
RegisteredClaims: jwt.RegisteredClaims{
670+
Audience: jwt.ClaimStrings{audience},
671+
Issuer: audience,
672+
Subject: "test",
673+
IssuedAt: jwt.NewNumericDate(now),
674+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
675+
NotBefore: jwt.NewNumericDate(now),
676+
},
677+
SAMLSession: true,
678+
}
679+
680+
_, err = codec.Encode(tc)
681+
assert.Check(t, is.ErrorContains(err, "crypto.Signer must hold an RSA key"))
682+
}

samlsp/new.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,18 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
149149
}
150150

151151
func defaultSigningMethodForKey(key crypto.Signer) string {
152-
switch key.(type) {
153-
case *rsa.PrivateKey:
152+
if key == nil {
153+
return ""
154+
}
155+
// Check public key type to support crypto.Signer implementations (KMS/HSM)
156+
// that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types
157+
switch key.Public().(type) {
158+
case *rsa.PublicKey:
154159
return dsig.RSASHA1SignatureMethod
155-
case *ecdsa.PrivateKey:
160+
case *ecdsa.PublicKey:
156161
return dsig.ECDSASHA256SignatureMethod
157-
case nil:
158-
return ""
159162
default:
160-
panic(fmt.Sprintf("programming error: unsupported key type %T", key))
163+
panic(fmt.Sprintf("programming error: unsupported public key type %T", key.Public()))
161164
}
162165
}
163166

samlsp/request_tracker_jwt.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package samlsp
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/rsa"
58
"fmt"
69
"time"
710

@@ -44,7 +47,19 @@ func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) {
4447
SAMLAuthnRequest: true,
4548
}
4649
token := jwt.NewWithClaims(s.SigningMethod, claims)
47-
return token.SignedString(s.Key)
50+
51+
if s.Key == nil {
52+
return "", fmt.Errorf("signing key is nil")
53+
}
54+
55+
// Check if key is a concrete private key type that jwt library can handle directly.
56+
// For crypto.Signer implementations (KMS/HSM), use custom signing.
57+
switch s.Key.(type) {
58+
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
59+
return token.SignedString(s.Key)
60+
default:
61+
return signJWTWithCryptoSigner(token, s.Key, s.SigningMethod)
62+
}
4863
}
4964

5065
// Decode returns a Tracked request from an encoded string.

samlsp/session_jwt.go

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@ package samlsp
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/rand"
8+
"crypto/rsa"
9+
"encoding/base64"
510
"errors"
11+
"fmt"
12+
"strings"
613
"time"
714

815
"github.com/golang-jwt/jwt/v5"
@@ -77,12 +84,19 @@ func (c JWTSessionCodec) Encode(s Session) (string, error) {
7784
claims := s.(JWTSessionClaims) // this will panic if you pass the wrong kind of session
7885

7986
token := jwt.NewWithClaims(c.SigningMethod, claims)
80-
signedString, err := token.SignedString(c.Key)
81-
if err != nil {
82-
return "", err
87+
88+
if c.Key == nil {
89+
return "", fmt.Errorf("signing key is nil")
8390
}
8491

85-
return signedString, nil
92+
// Check if key is a concrete private key type that jwt library can handle directly.
93+
// For crypto.Signer implementations (KMS/HSM), use custom signing.
94+
switch c.Key.(type) {
95+
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
96+
return token.SignedString(c.Key)
97+
default:
98+
return signJWTWithCryptoSigner(token, c.Key, c.SigningMethod)
99+
}
86100
}
87101

88102
// Decode parses the serialized session that may have been returned by Encode
@@ -137,3 +151,45 @@ func (a Attributes) Get(key string) string {
137151
}
138152
return v[0]
139153
}
154+
155+
// signJWTWithCryptoSigner signs a JWT token using the crypto.Signer interface.
156+
// Only RSA signing methods are supported since major SAML IdPs (Azure AD, Auth0,
157+
// Okta) use RSA. This allows KMS/HSM keys that implement crypto.Signer to sign JWTs.
158+
func signJWTWithCryptoSigner(token *jwt.Token, signer crypto.Signer, method jwt.SigningMethod) (string, error) {
159+
if _, ok := signer.Public().(*rsa.PublicKey); !ok {
160+
return "", fmt.Errorf("crypto.Signer must hold an RSA key, got %T", signer.Public())
161+
}
162+
163+
// Get the signing string (header.payload)
164+
signingString, err := token.SigningString()
165+
if err != nil {
166+
return "", err
167+
}
168+
169+
// Determine hash algorithm based on signing method
170+
var hashFunc crypto.Hash
171+
switch method.Alg() {
172+
case "RS256":
173+
hashFunc = crypto.SHA256
174+
case "RS384":
175+
hashFunc = crypto.SHA384
176+
case "RS512":
177+
hashFunc = crypto.SHA512
178+
default:
179+
return "", fmt.Errorf("unsupported signing algorithm for crypto.Signer: %s", method.Alg())
180+
}
181+
182+
// Hash the signing string
183+
hasher := hashFunc.New()
184+
hasher.Write([]byte(signingString))
185+
digest := hasher.Sum(nil)
186+
187+
// Sign using crypto.Signer
188+
sig, err := signer.Sign(rand.Reader, digest, hashFunc)
189+
if err != nil {
190+
return "", fmt.Errorf("signing with crypto.Signer: %w", err)
191+
}
192+
193+
// Encode signature and return complete JWT
194+
return strings.Join([]string{signingString, base64.RawURLEncoding.EncodeToString(sig)}, "."), nil
195+
}

service_provider.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,21 +567,25 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) {
567567
// keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
568568
// }
569569

570+
// Validate that the key type matches the signature method.
571+
// We check the public key type to support crypto.Signer implementations
572+
// (like KMS/HSM signers) that aren't literal *rsa.PrivateKey or *ecdsa.PrivateKey.
573+
pubKey := sp.Key.Public()
570574
switch sp.SignatureMethod {
571575
case dsig.RSASHA1SignatureMethod,
572576
dsig.RSASHA256SignatureMethod,
573577
dsig.RSASHA384SignatureMethod,
574578
dsig.RSASHA512SignatureMethod:
575-
if _, ok := sp.Key.(*rsa.PrivateKey); !ok {
576-
return nil, fmt.Errorf("signature method %s requires a key of type rsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key)
579+
if _, ok := pubKey.(*rsa.PublicKey); !ok {
580+
return nil, fmt.Errorf("signature method %s requires an RSA key, got %T", sp.SignatureMethod, pubKey)
577581
}
578582

579583
case dsig.ECDSASHA1SignatureMethod,
580584
dsig.ECDSASHA256SignatureMethod,
581585
dsig.ECDSASHA384SignatureMethod,
582586
dsig.ECDSASHA512SignatureMethod:
583-
if _, ok := sp.Key.(*ecdsa.PrivateKey); !ok {
584-
return nil, fmt.Errorf("signature method %s requires a key of type ecdsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key)
587+
if _, ok := pubKey.(*ecdsa.PublicKey); !ok {
588+
return nil, fmt.Errorf("signature method %s requires an ECDSA key, got %T", sp.SignatureMethod, pubKey)
585589
}
586590
default:
587591
return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod)

0 commit comments

Comments
 (0)