Skip to content

Commit 4a22109

Browse files
Add crypto.Signer support for KMS/HSM keys
Check public key type instead of private key type to support crypto.Signer implementations (GCP KMS, AWS KMS, HSM) that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types. Changes: - samlsp/new.go: Update defaultSigningMethodForKey() - samlsp/session_jwt.go: Add fallback signing with crypto.Signer - samlsp/request_tracker_jwt.go: Add fallback signing - service_provider.go: Update GetSigningContext() validation
1 parent 3465403 commit 4a22109

File tree

4 files changed

+116
-15
lines changed

4 files changed

+116
-15
lines changed

samlsp/new.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,16 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
149149
}
150150

151151
func defaultSigningMethodForKey(key crypto.Signer) string {
152-
switch key.(type) {
153-
case *rsa.PrivateKey:
152+
if key == nil {
153+
return ""
154+
}
155+
// Check public key type to support crypto.Signer implementations (KMS/HSM)
156+
// that aren't concrete *rsa.PrivateKey or *ecdsa.PrivateKey types
157+
switch key.Public().(type) {
158+
case *rsa.PublicKey:
154159
return dsig.RSASHA1SignatureMethod
155-
case *ecdsa.PrivateKey:
160+
case *ecdsa.PublicKey:
156161
return dsig.ECDSASHA256SignatureMethod
157-
case nil:
158-
return ""
159162
default:
160163
panic(fmt.Sprintf("programming error: unsupported key type %T", key))
161164
}

samlsp/request_tracker_jwt.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package samlsp
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/rsa"
58
"fmt"
69
"time"
710

@@ -44,7 +47,15 @@ func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) {
4447
SAMLAuthnRequest: true,
4548
}
4649
token := jwt.NewWithClaims(s.SigningMethod, claims)
47-
return token.SignedString(s.Key)
50+
51+
// Check if key is a concrete private key type that jwt library can handle directly.
52+
// For crypto.Signer implementations (KMS/HSM), use custom signing.
53+
switch s.Key.(type) {
54+
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
55+
return token.SignedString(s.Key)
56+
default:
57+
return signJWTWithCryptoSigner(token, s.Key, s.SigningMethod)
58+
}
4859
}
4960

5061
// Decode returns a Tracked request from an encoded string.

samlsp/session_jwt.go

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@ package samlsp
22

33
import (
44
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/rand"
8+
"crypto/rsa"
9+
"encoding/asn1"
10+
"encoding/base64"
511
"errors"
12+
"fmt"
13+
"math/big"
14+
"strings"
615
"time"
716

817
"github.com/golang-jwt/jwt/v5"
@@ -77,12 +86,15 @@ func (c JWTSessionCodec) Encode(s Session) (string, error) {
7786
claims := s.(JWTSessionClaims) // this will panic if you pass the wrong kind of session
7887

7988
token := jwt.NewWithClaims(c.SigningMethod, claims)
80-
signedString, err := token.SignedString(c.Key)
81-
if err != nil {
82-
return "", err
83-
}
8489

85-
return signedString, nil
90+
// Check if key is a concrete private key type that jwt library can handle directly.
91+
// For crypto.Signer implementations (KMS/HSM), use custom signing.
92+
switch c.Key.(type) {
93+
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
94+
return token.SignedString(c.Key)
95+
default:
96+
return signJWTWithCryptoSigner(token, c.Key, c.SigningMethod)
97+
}
8698
}
8799

88100
// Decode parses the serialized session that may have been returned by Encode
@@ -137,3 +149,74 @@ func (a Attributes) Get(key string) string {
137149
}
138150
return v[0]
139151
}
152+
153+
// signJWTWithCryptoSigner signs a JWT token using the crypto.Signer interface.
154+
// This allows KMS/HSM keys that implement crypto.Signer to sign JWTs.
155+
func signJWTWithCryptoSigner(token *jwt.Token, signer crypto.Signer, method jwt.SigningMethod) (string, error) {
156+
// Get the signing string (header.payload)
157+
signingString, err := token.SigningString()
158+
if err != nil {
159+
return "", err
160+
}
161+
162+
// Determine hash algorithm based on signing method
163+
var hashFunc crypto.Hash
164+
switch method.Alg() {
165+
case "RS256", "ES256", "PS256":
166+
hashFunc = crypto.SHA256
167+
case "RS384", "ES384", "PS384":
168+
hashFunc = crypto.SHA384
169+
case "RS512", "ES512", "PS512":
170+
hashFunc = crypto.SHA512
171+
default:
172+
hashFunc = crypto.SHA256
173+
}
174+
175+
// Hash the signing string
176+
hasher := hashFunc.New()
177+
hasher.Write([]byte(signingString))
178+
digest := hasher.Sum(nil)
179+
180+
// Sign using crypto.Signer
181+
sig, err := signer.Sign(rand.Reader, digest, hashFunc)
182+
if err != nil {
183+
return "", fmt.Errorf("signing with crypto.Signer: %w", err)
184+
}
185+
186+
// For ECDSA, the signature from crypto.Signer is ASN.1 DER encoded,
187+
// but JWT expects raw R||S format
188+
if _, ok := signer.Public().(*ecdsa.PublicKey); ok {
189+
sig, err = convertECDSASignatureToJWT(sig, signer.Public().(*ecdsa.PublicKey))
190+
if err != nil {
191+
return "", err
192+
}
193+
}
194+
195+
// Encode signature and return complete JWT
196+
return strings.Join([]string{signingString, base64.RawURLEncoding.EncodeToString(sig)}, "."), nil
197+
}
198+
199+
// convertECDSASignatureToJWT converts ASN.1 DER encoded ECDSA signature to JWT format (R||S)
200+
func convertECDSASignatureToJWT(derSig []byte, pubKey *ecdsa.PublicKey) ([]byte, error) {
201+
// Parse ASN.1 DER signature
202+
var sig struct {
203+
R, S *big.Int
204+
}
205+
if _, err := asn1.Unmarshal(derSig, &sig); err != nil {
206+
return nil, fmt.Errorf("parsing ECDSA signature: %w", err)
207+
}
208+
209+
// Calculate key size in bytes
210+
keyBytes := (pubKey.Curve.Params().BitSize + 7) / 8
211+
212+
// Create R||S format
213+
rBytes := sig.R.Bytes()
214+
sBytes := sig.S.Bytes()
215+
216+
// Pad to key size
217+
result := make([]byte, 2*keyBytes)
218+
copy(result[keyBytes-len(rBytes):keyBytes], rBytes)
219+
copy(result[2*keyBytes-len(sBytes):], sBytes)
220+
221+
return result, nil
222+
}

service_provider.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,21 +567,25 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) {
567567
// keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
568568
// }
569569

570+
// Validate that the key type matches the signature method.
571+
// We check the public key type to support crypto.Signer implementations
572+
// (like KMS/HSM signers) that aren't literal *rsa.PrivateKey or *ecdsa.PrivateKey.
573+
pubKey := sp.Key.Public()
570574
switch sp.SignatureMethod {
571575
case dsig.RSASHA1SignatureMethod,
572576
dsig.RSASHA256SignatureMethod,
573577
dsig.RSASHA384SignatureMethod,
574578
dsig.RSASHA512SignatureMethod:
575-
if _, ok := sp.Key.(*rsa.PrivateKey); !ok {
576-
return nil, fmt.Errorf("signature method %s requires a key of type rsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key)
579+
if _, ok := pubKey.(*rsa.PublicKey); !ok {
580+
return nil, fmt.Errorf("signature method %s requires an RSA key, got %T", sp.SignatureMethod, pubKey)
577581
}
578582

579583
case dsig.ECDSASHA1SignatureMethod,
580584
dsig.ECDSASHA256SignatureMethod,
581585
dsig.ECDSASHA384SignatureMethod,
582586
dsig.ECDSASHA512SignatureMethod:
583-
if _, ok := sp.Key.(*ecdsa.PrivateKey); !ok {
584-
return nil, fmt.Errorf("signature method %s requires a key of type ecdsa.PrivateKey, not %T", sp.SignatureMethod, sp.Key)
587+
if _, ok := pubKey.(*ecdsa.PublicKey); !ok {
588+
return nil, fmt.Errorf("signature method %s requires an ECDSA key, got %T", sp.SignatureMethod, pubKey)
585589
}
586590
default:
587591
return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod)

0 commit comments

Comments
 (0)