Skip to content

Commit fc99084

Browse files
committed
Refactor token exchange to remove unused custom claims and enhance thumbprint calculation options
1 parent 7fff2dd commit fc99084

File tree

4 files changed

+74
-17
lines changed

4 files changed

+74
-17
lines changed

_example/main.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ func main() {
4848

4949
// Parse the JSON payload
5050
var payload struct {
51-
IDToken string `json:"id_token"`
52-
CustomClaims map[string]any `json:"custom_claims"`
51+
IDToken string `json:"id_token"`
5352
}
5453

5554
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
@@ -64,7 +63,7 @@ func main() {
6463
}
6564

6665
// Exchange the token
67-
accessToken, err := tokenBridge.ExchangeToken(r.Context(), payload.IDToken, payload.CustomClaims)
66+
accessToken, err := tokenBridge.ExchangeToken(r.Context(), payload.IDToken)
6867
if err != nil {
6968
http.Error(w, fmt.Sprintf("Failed to exchange token: %v", err), http.StatusInternalServerError)
7069
return

oidc.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ type OIDCVerifierOptions struct {
3737
// Now is a function that returns the current time, which can be used for expiry and validity checks.
3838
// If not provided, the default time function (time.Now) is used.
3939
Now func() time.Time
40+
41+
// CalculateThumbprintOptions is an optional configuration for calculating the thumbprint.
42+
// It allows customization of the TLS configuration and dialer used to establish the connection.
43+
// This is useful for scenarios where you need to specify custom TLS settings or a custom dialer.
44+
// If nil, default settings will be used.
45+
CalculateThumbprintOptions *CalculateThumbprintOptions
4046
}
4147

4248
// OIDCVerifier is responsible for verifying OpenID Connect ID tokens.
@@ -102,8 +108,9 @@ func NewOIDCVerifier(ctx context.Context, issuerURL *url.URL, clientIDs []string
102108
transport := opts.Transport
103109
if len(opts.Thumbprints) > 0 {
104110
transport = &thumbprintValidatingTransport{
105-
transport: opts.Transport,
106-
thumbprints: opts.Thumbprints,
111+
transport: opts.Transport,
112+
thumbprints: opts.Thumbprints,
113+
calculateThumbprintOptions: opts.CalculateThumbprintOptions,
107114
}
108115
}
109116

transport.go

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ type thumbprintValidatingTransport struct {
2727
// one of the valid thumbprints, the response is rejected.
2828
thumbprints []string
2929

30-
// tlsConfig is a customizable TLS configuration used when establishing the TLS connection.
31-
tlsConfig *tls.Config
32-
33-
// dialer is a custom network dialer that will be used to establish the network connection.
34-
dialer *net.Dialer
30+
// calculateThumbprintOptions is an optional configuration for calculating the thumbprint.
31+
// It allows customization of the TLS configuration and dialer used to establish the connection.
32+
// This is useful for scenarios where you need to specify custom TLS settings or a custom dialer.
33+
// If nil, default settings will be used.
34+
calculateThumbprintOptions *CalculateThumbprintOptions
3535
}
3636

3737
// RoundTrip executes the HTTP request and processes the response. It validates the certificate
@@ -80,12 +80,14 @@ func (t *thumbprintValidatingTransport) RoundTrip(req *http.Request) (*http.Resp
8080
}
8181

8282
thumbprint, err := CalculateThumbprintFromJWKS(jwksURL, func(o *CalculateThumbprintOptions) {
83-
if t.tlsConfig != nil {
84-
o.TLSConfig = t.tlsConfig
85-
}
86-
87-
if t.dialer != nil {
88-
o.Dialer = t.dialer
83+
if t.calculateThumbprintOptions != nil {
84+
if t.calculateThumbprintOptions.TLSConfig != nil {
85+
o.TLSConfig = t.calculateThumbprintOptions.TLSConfig
86+
}
87+
88+
if t.calculateThumbprintOptions.Dialer != nil {
89+
o.Dialer = t.calculateThumbprintOptions.Dialer
90+
}
8991
}
9092
})
9193
if err != nil {
@@ -117,8 +119,11 @@ func (t *thumbprintValidatingTransport) isThumbprintValid(thumbprint string) boo
117119

118120
// CalculateThumbprintOptions holds configuration for CalculateThumbprintFromJWKS
119121
type CalculateThumbprintOptions struct {
122+
// TLSConfig is a customizable TLS configuration used when establishing the TLS connection.
120123
TLSConfig *tls.Config
121-
Dialer *net.Dialer
124+
125+
// Dialer is a custom network dialer that will be used to establish the network connection.
126+
Dialer *net.Dialer
122127
}
123128

124129
// CalculateThumbprintFromJWKS retrieves the certificate chain from the JWKS URI, extracts the last certificate,

transport_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package tokenbridge
2+
3+
import (
4+
"crypto/x509"
5+
"encoding/pem"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestCalculateThumbprint(t *testing.T) {
12+
// Sample PEM-encoded certificate
13+
certPEM := `-----BEGIN CERTIFICATE-----
14+
MIICiTCCAfICCQD6m7oRw0uXOjANBgkqhkiG9w0BAQUFADCBiDELMAkGA1UEBhMC
15+
VVMxCzAJBgNVBAgTAldBMRAwDgYDVQQHEwdTZWF0dGxlMQ8wDQYDVQQKEwZBbWF6
16+
b24xFDASBgNVBAsTC0lBTSBDb25zb2xlMRIwEAYDVQQDEwlUZXN0Q2lsYWMxHzAd
17+
BgkqhkiG9w0BCQEWEG5vb25lQGFtYXpvbi5jb20wHhcNMTEwNDI1MjA0NTIxWhcN
18+
MTIwNDI0MjA0NTIxWjCBiDELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAldBMRAwDgYD
19+
VQQHEwdTZWF0dGxlMQ8wDQYDVQQKEwZBbWF6b24xFDASBgNVBAsTC0lBTSBDb25z
20+
b2xlMRIwEAYDVQQDEwlUZXN0Q2lsYWMxHzAdBgkqhkiG9w0BCQEWEG5vb25lQGFt
21+
YXpvbi5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMaK0dn+a4GmWIWJ
22+
21uUSfwfEvySWtC2XADZ4nB+BLYgVIk60CpiwsZ3G93vUEIO3IyNoH/f0wYK8m9T
23+
rDHudUZg3qX4waLG5M43q7Wgc/MbQITxOUSQv7c7ugFFDzQGBzZswY6786m86gpE
24+
Ibb3OhjZnzcvQAaRHhdlQWIMm2nrAgMBAAEwDQYJKoZIhvcNAQEFBQADgYEAtCu4
25+
nUhVVxYUntneD9+h8Mg9q6q+auNKyExzyLwaxlAoo7TJHidbtS4J5iNmZgXL0Fkb
26+
FFBjvSfpJIlJ00zbhNYS5f6GuoEDmFJl0ZxBHjJnyp378OD8uTs7fLvjx79LjSTb
27+
NYiytVbZPQUQ5Yaxu2jXnimvw3rrszlaEXAMPLE=
28+
-----END CERTIFICATE-----`
29+
30+
// Decode the PEM-encoded certificate
31+
block, _ := pem.Decode([]byte(certPEM))
32+
if block == nil {
33+
t.Fatalf("failed to decode PEM block")
34+
}
35+
36+
// Parse the certificate
37+
cert, err := x509.ParseCertificate(block.Bytes)
38+
if err != nil {
39+
t.Fatalf("failed to parse certificate: %v", err)
40+
}
41+
42+
// Calculate the thumbprint
43+
thumbprint, err := calculateThumbprint(cert)
44+
assert.NoError(t, err, "Expected no error while calculating thumbprint")
45+
assert.Equal(t, "990F4193972F2BECF12DDEDA5237F9C952F20D9E", thumbprint, "Thumbprint should match expected value")
46+
}

0 commit comments

Comments
 (0)