Skip to content

Commit 6738b07

Browse files
authored
Merge e5b91ac into 5869d13
2 parents 5869d13 + e5b91ac commit 6738b07

File tree

2 files changed

+120
-4
lines changed

2 files changed

+120
-4
lines changed

internal/command/client.go

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ package command
1818

1919
import (
2020
"fmt"
21+
"strings"
2122

2223
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
2324
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
2425
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
2526
commandsdk "github.com/Keyfactor/keyfactor-go-client/v3/api"
27+
"github.com/go-logr/logr"
28+
"github.com/golang-jwt/jwt/v5"
2629
"golang.org/x/net/context"
2730
"golang.org/x/oauth2"
2831
"golang.org/x/oauth2/google"
@@ -95,6 +98,11 @@ type azure struct {
9598

9699
// GetAccessToken implements TokenCredential.
97100
func (a *azure) GetAccessToken(ctx context.Context) (string, error) {
101+
log := log.FromContext(ctx)
102+
103+
// To prevent clogging logs every time JWT is generated
104+
initializing := a.cred == nil
105+
98106
// Lazily create the credential if needed
99107
if a.cred == nil {
100108
c, err := azidentity.NewDefaultAzureCredential(nil)
@@ -104,6 +112,8 @@ func (a *azure) GetAccessToken(ctx context.Context) (string, error) {
104112
a.cred = c
105113
}
106114

115+
log.Info(fmt.Sprintf("generating Default Azure Credentials with scopes %s", strings.Join(a.scopes, " ")))
116+
107117
// Request a token with the provided scopes
108118
token, err := a.cred.GetToken(ctx, policy.TokenRequestOptions{
109119
Scopes: a.scopes,
@@ -112,8 +122,20 @@ func (a *azure) GetAccessToken(ctx context.Context) (string, error) {
112122
return "", fmt.Errorf("%w: failed to fetch token: %w", errTokenFetchFailure, err)
113123
}
114124

115-
log.FromContext(ctx).Info("fetched token using Azure DefaultAzureCredential")
116-
return token.Token, nil
125+
tokenString := token.Token
126+
127+
if initializing {
128+
// Only want to output this once, don't want to output this every time the JWT is generated
129+
130+
log.Info("==== BEGIN DEBUG: DefaultAzureCredential JWT ======")
131+
132+
printClaims(log, tokenString, []string{"aud", "azp", "iss", "sub", "oid"})
133+
134+
log.Info("==== END DEBUG: DefaultAzureCredential JWT ======")
135+
}
136+
137+
log.Info("fetched token using Azure DefaultAzureCredential")
138+
return tokenString, nil
117139
}
118140

119141
func newAzureDefaultCredentialSource(ctx context.Context, scopes []string) (*azure, error) {
@@ -142,17 +164,28 @@ type gcp struct {
142164

143165
// GetAccessToken implements TokenCredential.
144166
func (g *gcp) GetAccessToken(ctx context.Context) (string, error) {
145-
// Lazily create the TokenSource if it's nil.
146167
log := log.FromContext(ctx)
168+
169+
// To prevent clogging logs every time JWT is generated
170+
initializing := g.tokenSource == nil
171+
172+
// Lazily create the TokenSource if it's nil.
147173
if g.tokenSource == nil {
174+
log.Info(fmt.Sprintf("generating default Google credentials with scopes: %s", strings.Join(g.scopes, " ")))
175+
148176
credentials, err := google.FindDefaultCredentials(ctx, g.scopes...)
149177
if err != nil {
150178
return "", fmt.Errorf("%w: failed to find GCP ADC: %w", errTokenFetchFailure, err)
151179
}
152180
log.Info(fmt.Sprintf("generating a Google OIDC ID token..."))
153181

182+
// Default audience to "command" if not provided
183+
aud := getValueOrDefault(g.audience, "command")
184+
185+
log.Info(fmt.Sprintf("generating Google id token with audience %s", aud))
186+
154187
// Use credentials to generate a JWT (requires a service account)
155-
tokenSource, err := idtoken.NewTokenSource(ctx, getValueOrDefault(g.audience, "command"), idtoken.WithCredentialsJSON(credentials.JSON))
188+
tokenSource, err := idtoken.NewTokenSource(ctx, aud, idtoken.WithCredentialsJSON(credentials.JSON))
156189
if err != nil {
157190
return "", fmt.Errorf("%w: failed to get GCP ID Token Source: %w", errTokenFetchFailure, err)
158191
}
@@ -171,6 +204,14 @@ func (g *gcp) GetAccessToken(ctx context.Context) (string, error) {
171204
return "", fmt.Errorf("%w: failed to fetch token from GCP ADC token source: %w", errTokenFetchFailure, err)
172205
}
173206

207+
if initializing {
208+
// Only want to output this once, don't want to output this every time the JWT is generated
209+
210+
log.Info("==== BEGIN DEBUG: Default Google ID Token JWT ======")
211+
printClaims(log, token.AccessToken, []string{"aud", "iss", "sub", "email"})
212+
log.Info("==== END DEBUG: Default Google ID Token JWT ======")
213+
}
214+
174215
log.Info("fetched token using GCP ApplicationDefaultCredential")
175216

176217
return token.AccessToken, nil
@@ -188,3 +229,21 @@ func newGCPDefaultCredentialSource(ctx context.Context, audience string, scopes
188229
tokenCredentialSource = source
189230
return source, nil
190231
}
232+
233+
func printClaims(log logr.Logger, token string, claimsToPrint []string) error {
234+
tokenRaw, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
235+
if err != nil {
236+
log.Error(err, "failed to parse JWT")
237+
return fmt.Errorf("failed to parse JWT: %w", err)
238+
}
239+
240+
claims, _ := tokenRaw.Claims.(jwt.MapClaims)
241+
242+
for _, key := range claimsToPrint {
243+
if value, ok := claims[key]; ok {
244+
log.Info(fmt.Sprintf(" %s: %s", key, value))
245+
}
246+
}
247+
248+
return nil
249+
}

internal/command/client_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package command
2+
3+
import (
4+
"testing"
5+
6+
"github.com/go-logr/logr/testr"
7+
"github.com/golang-jwt/jwt/v5"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestPrintClaims(t *testing.T) {
12+
t.Run("valid jwt returns no error", func(t *testing.T) {
13+
// Sample JWT with dummy claims (no signature needed for ParseUnverified)
14+
claims := jwt.MapClaims{
15+
"aud": "api://1234",
16+
"iss": "https://sts.windows.net/tenant-id/",
17+
"sub": "user-id",
18+
}
19+
token := createUnsignedJWT(t, claims)
20+
21+
// Use testr logger
22+
testLogger := testr.New(t)
23+
24+
// Call the function
25+
err := printClaims(testLogger, token, []string{"aud", "iss", "sub"})
26+
assert.NoError(t, err)
27+
})
28+
29+
t.Run("invalid jwt returns an error", func(t *testing.T) {
30+
// Use testr logger
31+
testLogger := testr.New(t)
32+
33+
// Call the function
34+
err := printClaims(testLogger, "abcdefghijklmnop", []string{"aud", "iss", "sub"})
35+
assert.Error(t, err)
36+
})
37+
38+
t.Run("jwt with no claims returns error", func(t *testing.T) {
39+
// Use testr logger
40+
testLogger := testr.New(t)
41+
42+
// Call the function
43+
err := printClaims(testLogger, "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0..", []string{"aud", "iss", "sub"})
44+
assert.Error(t, err)
45+
})
46+
}
47+
48+
func createUnsignedJWT(t *testing.T, claims jwt.MapClaims) string {
49+
t.Helper()
50+
51+
token := jwt.NewWithClaims(jwt.SigningMethodNone, claims)
52+
str, err := token.SignedString(jwt.UnsafeAllowNoneSignatureType)
53+
if err != nil {
54+
t.Fatalf("failed to create test token: %v", err)
55+
}
56+
return str
57+
}

0 commit comments

Comments
 (0)