Skip to content

Commit 388c550

Browse files
Add crypto.Signer support for KMS/HSM keys
Check public key type instead of private key type to support crypto.Signer implementations (GCP KMS, AWS KMS, HSM) that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types. Changes: - samlsp/new.go: Update defaultSigningMethodForKey() - samlsp/session_jwt.go: Add fallback signing with crypto.Signer - samlsp/request_tracker_jwt.go: Add fallback signing - service_provider.go: Update GetSigningContext() validation
1 parent 3167922 commit 388c550

File tree

5 files changed

+297
-15
lines changed

5 files changed

+297
-15
lines changed

samlsp/middleware_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package samlsp
22

33
import (
44
"bytes"
5+
"crypto"
6+
"crypto/ed25519"
7+
"crypto/rand"
58
"crypto/rsa"
69
"crypto/sha256"
710
"crypto/x509"
@@ -17,6 +20,7 @@ import (
1720
"testing"
1821
"time"
1922

23+
"github.com/golang-jwt/jwt/v5"
2024
dsig "github.com/russellhaering/goxmldsig"
2125
"gotest.tools/assert"
2226
is "gotest.tools/assert/cmp"
@@ -520,3 +524,174 @@ func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
520524
assert.Check(t, is.Equal("", resp.Header().Get("Location")))
521525
assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
522526
}
527+
528+
type mockSigner struct {
529+
signer crypto.Signer
530+
}
531+
532+
func (m *mockSigner) Public() crypto.PublicKey {
533+
return m.signer.Public()
534+
}
535+
536+
func (m *mockSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
537+
return m.signer.Sign(rand, digest, opts)
538+
}
539+
540+
func newMockRSASigner(t *testing.T) crypto.Signer {
541+
key := mustParsePrivateKey(golden.Get(t, "key.pem"))
542+
return &mockSigner{signer: key.(crypto.Signer)}
543+
}
544+
545+
func TestMiddleware_WithCryptoSignerE2E(t *testing.T) {
546+
saml.TimeNow = func() time.Time {
547+
rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015")
548+
return rv
549+
}
550+
saml.Clock = dsig.NewFakeClockAt(saml.TimeNow())
551+
saml.RandReader = &testRandomReader{}
552+
553+
cert := mustParseCertificate(golden.Get(t, "cert.pem"))
554+
idpMetadata := golden.Get(t, "idp_metadata.xml")
555+
556+
var metadata saml.EntityDescriptor
557+
if err := xml.Unmarshal(idpMetadata, &metadata); err != nil {
558+
panic(err)
559+
}
560+
561+
mockSigner := newMockRSASigner(t)
562+
563+
opts := Options{
564+
URL: mustParseURL("https://15661444.ngrok.io/"),
565+
Key: mockSigner,
566+
Certificate: cert,
567+
IDPMetadata: &metadata,
568+
}
569+
570+
middleware, err := New(opts)
571+
assert.Check(t, err)
572+
573+
sessionProvider := DefaultSessionProvider(opts)
574+
sessionProvider.Name = "ttt"
575+
sessionProvider.MaxAge = 7200 * time.Second
576+
577+
sessionCodec := sessionProvider.Codec.(JWTSessionCodec)
578+
sessionCodec.MaxAge = 7200 * time.Second
579+
sessionProvider.Codec = sessionCodec
580+
581+
middleware.Session = sessionProvider
582+
middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata"
583+
middleware.ServiceProvider.AcsURL.Path = "/saml2/acs"
584+
middleware.ServiceProvider.SloURL.Path = "/saml2/slo"
585+
586+
t.Run("SessionEncodeDecode", func(t *testing.T) {
587+
var tc JWTSessionClaims
588+
if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil {
589+
t.Fatal(err)
590+
}
591+
592+
encoded, err := sessionProvider.Codec.Encode(tc)
593+
assert.Check(t, err)
594+
assert.Assert(t, encoded != "")
595+
596+
decoded, err := sessionProvider.Codec.Decode(encoded)
597+
assert.Check(t, err)
598+
decodedClaims := decoded.(JWTSessionClaims)
599+
assert.Equal(t, tc.Subject, decodedClaims.Subject)
600+
})
601+
602+
t.Run("TrackedRequestEncodeDecode", func(t *testing.T) {
603+
codec := middleware.RequestTracker.(CookieRequestTracker).Codec
604+
trackedReq := TrackedRequest{
605+
Index: "test-index",
606+
SAMLRequestID: "test-request-id",
607+
URI: "/test-uri",
608+
}
609+
610+
encoded, err := codec.Encode(trackedReq)
611+
assert.Check(t, err)
612+
assert.Assert(t, encoded != "")
613+
614+
decoded, err := codec.Decode(encoded)
615+
assert.Check(t, err)
616+
assert.Equal(t, trackedReq.Index, decoded.Index)
617+
assert.Equal(t, trackedReq.SAMLRequestID, decoded.SAMLRequestID)
618+
})
619+
620+
t.Run("RequireAccountFlow", func(t *testing.T) {
621+
handler := middleware.RequireAccount(
622+
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
623+
panic("not reached")
624+
}))
625+
626+
req, _ := http.NewRequest("GET", "/protected", nil)
627+
resp := httptest.NewRecorder()
628+
handler.ServeHTTP(resp, req)
629+
630+
assert.Check(t, is.Equal(http.StatusFound, resp.Code))
631+
assert.Assert(t, resp.Header().Get("Location") != "")
632+
assert.Assert(t, resp.Header().Get("Set-Cookie") != "")
633+
})
634+
635+
t.Run("Metadata", func(t *testing.T) {
636+
req, _ := http.NewRequest("GET", "/saml2/metadata", nil)
637+
resp := httptest.NewRecorder()
638+
middleware.ServeHTTP(resp, req)
639+
640+
assert.Check(t, is.Equal(http.StatusOK, resp.Code))
641+
assert.Check(t, is.Equal("application/samlmetadata+xml",
642+
resp.Header().Get("Content-type")))
643+
golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml")
644+
})
645+
}
646+
647+
func TestJWTSessionCodec_Ed25519(t *testing.T) {
648+
now := time.Now()
649+
saml.TimeNow = func() time.Time {
650+
return now
651+
}
652+
653+
// Generate Ed25519 key pair
654+
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
655+
assert.Check(t, err)
656+
657+
audience := "https://example.com/"
658+
codec := JWTSessionCodec{
659+
SigningMethod: jwt.SigningMethodEdDSA,
660+
Audience: audience,
661+
Issuer: audience,
662+
MaxAge: time.Hour,
663+
Key: privateKey,
664+
}
665+
666+
// Create test claims directly
667+
tc := JWTSessionClaims{
668+
RegisteredClaims: jwt.RegisteredClaims{
669+
Audience: jwt.ClaimStrings{audience},
670+
Issuer: audience,
671+
Subject: "test-subject-123",
672+
IssuedAt: jwt.NewNumericDate(now),
673+
ExpiresAt: jwt.NewNumericDate(now.Add(time.Hour)),
674+
NotBefore: jwt.NewNumericDate(now),
675+
},
676+
Attributes: Attributes{
677+
"uid": []string{"testuser"},
678+
"givenName": []string{"Test User"},
679+
},
680+
SAMLSession: true,
681+
}
682+
683+
// Test encode
684+
encoded, err := codec.Encode(tc)
685+
assert.Check(t, err)
686+
assert.Assert(t, encoded != "", "encoded token should not be empty")
687+
688+
// Test decode
689+
decoded, err := codec.Decode(encoded)
690+
assert.Check(t, err)
691+
decodedClaims := decoded.(JWTSessionClaims)
692+
693+
// Verify claims match
694+
assert.Equal(t, tc.Subject, decodedClaims.Subject)
695+
assert.Check(t, decodedClaims.SAMLSession, "SAMLSession should be true")
696+
assert.Equal(t, tc.Attributes.Get("uid"), decodedClaims.Attributes.Get("uid"))
697+
}

samlsp/new.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,16 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
149149
}
150150

151151
func defaultSigningMethodForKey(key crypto.Signer) string {
152-
switch key.(type) {
153-
case *rsa.PrivateKey:
152+
if key == nil {
153+
return ""
154+
}
155+
// Check public key type to support crypto.Signer implementations (KMS/HSM)
156+
// that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types
157+
switch key.Public().(type) {
158+
case *rsa.PublicKey:
154159
return dsig.RSASHA1SignatureMethod
155-
case *ecdsa.PrivateKey:
160+
case *ecdsa.PublicKey:
156161
return dsig.ECDSASHA256SignatureMethod
157-
case nil:
158-
return ""
159162
default:
160163
panic(fmt.Sprintf("programming error: unsupported key type %T", key))
161164
}

samlsp/request_tracker_jwt.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package samlsp
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/rsa"
58
"fmt"
69
"time"
710

@@ -44,7 +47,18 @@ func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) {
4447
SAMLAuthnRequest: true,
4548
}
4649
token := jwt.NewWithClaims(s.SigningMethod, claims)
47-
return token.SignedString(s.Key)
50+
51+
// Check if key is a concrete private key type that jwt library can handle directly.
52+
// For crypto.Signer implementations (KMS/HSM), use custom signing.
53+
switch s.Key.(type) {
54+
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
55+
return token.SignedString(s.Key)
56+
default:
57+
if s.Key == nil {
58+
return "", fmt.Errorf("signing key is nil")
59+
}
60+
return signJWTWithCryptoSigner(token, s.Key, s.SigningMethod)
61+
}
4862
}
4963

5064
// Decode returns a Tracked request from an encoded string.

samlsp/session_jwt.go

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@ package samlsp
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/rand"
8+
"crypto/rsa"
9+
"encoding/asn1"
10+
"encoding/base64"
511
"errors"
12+
"fmt"
13+
"math/big"
14+
"strings"
615
"time"
716

817
"github.com/golang-jwt/jwt/v5"
@@ -77,12 +86,18 @@ func (c JWTSessionCodec) Encode(s Session) (string, error) {
7786
claims := s.(JWTSessionClaims) // this will panic if you pass the wrong kind of session
7887

7988
token := jwt.NewWithClaims(c.SigningMethod, claims)
80-
signedString, err := token.SignedString(c.Key)
81-
if err != nil {
82-
return "", err
83-
}
8489

85-
return signedString, nil
90+
// Check if key is a concrete private key type that jwt library can handle directly.
91+
// For crypto.Signer implementations (KMS/HSM), use custom signing.
92+
switch c.Key.(type) {
93+
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
94+
return token.SignedString(c.Key)
95+
default:
96+
if c.Key == nil {
97+
return "", fmt.Errorf("signing key is nil")
98+
}
99+
return signJWTWithCryptoSigner(token, c.Key, c.SigningMethod)
100+
}
86101
}
87102

88103
// Decode parses the serialized session that may have been returned by Encode
@@ -137,3 +152,74 @@ func (a Attributes) Get(key string) string {
137152
}
138153
return v[0]
139154
}
155+
156+
// signJWTWithCryptoSigner signs a JWT token using the crypto.Signer interface.
157+
// This allows KMS/HSM keys that implement crypto.Signer to sign JWTs.
158+
func signJWTWithCryptoSigner(token *jwt.Token, signer crypto.Signer, method jwt.SigningMethod) (string, error) {
159+
// Get the signing string (header.payload)
160+
signingString, err := token.SigningString()
161+
if err != nil {
162+
return "", err
163+
}
164+
165+
// Determine hash algorithm based on signing method
166+
var hashFunc crypto.Hash
167+
switch method.Alg() {
168+
case "RS256", "ES256", "PS256":
169+
hashFunc = crypto.SHA256
170+
case "RS384", "ES384", "PS384":
171+
hashFunc = crypto.SHA384
172+
case "RS512", "ES512", "PS512":
173+
hashFunc = crypto.SHA512
174+
default:
175+
return "", fmt.Errorf("unsupported signing algorithm: %s", method.Alg())
176+
}
177+
178+
// Hash the signing string
179+
hasher := hashFunc.New()
180+
hasher.Write([]byte(signingString))
181+
digest := hasher.Sum(nil)
182+
183+
// Sign using crypto.Signer
184+
sig, err := signer.Sign(rand.Reader, digest, hashFunc)
185+
if err != nil {
186+
return "", fmt.Errorf("signing with crypto.Signer: %w", err)
187+
}
188+
189+
// For ECDSA, the signature from crypto.Signer is ASN.1 DER encoded,
190+
// but JWT expects raw R||S format
191+
if _, ok := signer.Public().(*ecdsa.PublicKey); ok {
192+
sig, err = convertECDSASignatureToJWT(sig, signer.Public().(*ecdsa.PublicKey))
193+
if err != nil {
194+
return "", err
195+
}
196+
}
197+
198+
// Encode signature and return complete JWT
199+
return strings.Join([]string{signingString, base64.RawURLEncoding.EncodeToString(sig)}, "."), nil
200+
}
201+
202+
// convertECDSASignatureToJWT converts ASN.1 DER encoded ECDSA signature to JWT format (R||S)
203+
func convertECDSASignatureToJWT(derSig []byte, pubKey *ecdsa.PublicKey) ([]byte, error) {
204+
// Parse ASN.1 DER signature
205+
var sig struct {
206+
R, S *big.Int
207+
}
208+
if _, err := asn1.Unmarshal(derSig, &sig); err != nil {
209+
return nil, fmt.Errorf("parsing ECDSA signature: %w", err)
210+
}
211+
212+
// Calculate key size in bytes
213+
keyBytes := (pubKey.Curve.Params().BitSize + 7) / 8
214+
215+
// Create R||S format
216+
rBytes := sig.R.Bytes()
217+
sBytes := sig.S.Bytes()
218+
219+
// Pad to key size
220+
result := make([]byte, 2*keyBytes)
221+
copy(result[keyBytes-len(rBytes):keyBytes], rBytes)
222+
copy(result[2*keyBytes-len(sBytes):], sBytes)
223+
224+
return result, nil
225+
}

service_provider.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,21 +567,25 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) {
567567
// keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
568568
// }
569569

570+
// Validate that the key type matches the signature method.
571+
// We check the public key type to support crypto.Signer implementations
572+
// (like KMS/HSM signers) that aren't literal *rsa.PrivateKey or *ecdsa.PrivateKey.
573+
pubKey := sp.Key.Public()
570574
switch sp.SignatureMethod {
571575
case dsig.RSASHA1SignatureMethod,
572576
dsig.RSASHA256SignatureMethod,
573577
dsig.RSASHA384SignatureMethod,
574578
dsig.RSASHA512SignatureMethod:
575-
if _, ok := sp.Key.(*rsa.PrivateKey); !ok {
576-
return nil, fmt.Errorf("signature method %s requires a key of type rsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key)
579+
if _, ok := pubKey.(*rsa.PublicKey); !ok {
580+
return nil, fmt.Errorf("signature method %s requires an RSA key, got %T", sp.SignatureMethod, pubKey)
577581
}
578582

579583
case dsig.ECDSASHA1SignatureMethod,
580584
dsig.ECDSASHA256SignatureMethod,
581585
dsig.ECDSASHA384SignatureMethod,
582586
dsig.ECDSASHA512SignatureMethod:
583-
if _, ok := sp.Key.(*ecdsa.PrivateKey); !ok {
584-
return nil, fmt.Errorf("signature method %s requires a key of type ecdsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key)
587+
if _, ok := pubKey.(*ecdsa.PublicKey); !ok {
588+
return nil, fmt.Errorf("signature method %s requires an ECDSA key, got %T", sp.SignatureMethod, pubKey)
585589
}
586590
default:
587591
return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod)

0 commit comments

Comments
 (0)