Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 49 additions & 24 deletions vault/cache_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (k ClientCacheKey) SameParent(other ClientCacheKey) (bool, error) {
//
// See computeClientCacheKey for more details on how the client cache is derived
func ComputeClientCacheKeyFromClient(c Client) (ClientCacheKey, error) {
return computeClientCacheKey(c.GetVaultAuthObj(), c.GetVaultConnectionObj(), c.GetCredentialProvider().GetUID())
return computeClientCacheKey(c.GetVaultAuthObj(), c.GetVaultConnectionObj(), c.GetCredentialProvider().GetUID(), false)
}

// ComputeClientCacheKeyFromObj for use in a ClientCache. It is derived from the configuration of obj.
Expand Down Expand Up @@ -101,7 +101,7 @@ func ComputeClientCacheKeyFromObj(ctx context.Context, client ctrlclient.Client,
return "", err
}

return computeClientCacheKey(authObj, connObj, provider.GetUID())
return computeClientCacheKey(authObj, connObj, provider.GetUID(), false)
}

// ComputeClientCacheKeyFromMeta for use in a ClientCache. It is derived from the configuration of obj.
Expand Down Expand Up @@ -134,7 +134,7 @@ func ComputeClientCacheKeyFromMeta(ctx context.Context, client ctrlclient.Client
return "", err
}

return computeClientCacheKey(authObj, connObj, provider.GetUID())
return computeClientCacheKey(authObj, connObj, provider.GetUID(), false)
}

// ComputeClientCacheKey for use in a ClientCache. It is derived by combining instances of
Expand All @@ -153,38 +153,63 @@ func ComputeClientCacheKeyFromMeta(ctx context.Context, client ctrlclient.Client
// allowed for Kubernetes resources, which is 63 characters.
//
// If the computed cache-key exceeds 63 characters, the limit imposed for Kubernetes resource names,
// or if any of the inputs do not coform in any way, and error will be returned.
func computeClientCacheKey(authObj *secretsv1beta1.VaultAuth, connObj *secretsv1beta1.VaultConnection, providerUID types.UID) (ClientCacheKey, error) {
// or if any of the inputs do not conform in any way, an error will be returned.
//
// Cache key generation is simpler when isStandalone is true (indicating a client without access to k8s resources):
// - Uses content-based hashes of authObj.Spec and connObj.Spec instead of UIDs
// - Generation is always 1 since objects aren't actual k8s resources
func computeClientCacheKey(authObj *secretsv1beta1.VaultAuth, connObj *secretsv1beta1.VaultConnection, providerUID types.UID, isStandalone bool) (ClientCacheKey, error) {
var errs error
method := authObj.Spec.Method
if method == "" {
errs = errors.Join(errs, fmt.Errorf("auth method is empty"))
}

// only used for duplicate UID detection, all values are ignored
seen := make(map[types.UID]int)
requireUIDLen := 36
validateUID := func(name string, uid types.UID) {
if len(uid) != requireUIDLen {
errs = errors.Join(errs, fmt.Errorf("%w %d, must be %d", errorInvalidUIDLength, len(uid), requireUIDLen))
var input string
if isStandalone {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if you may have considered setting the UIDs from the Spec, before passing them into this function. That might obviate the need for the special casing here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it, but got a bit scared off by the "read-only" description of the UID field (since traditionally the field is set by the API server).

I like your idea about making this whole thing a standaloneClientFactory instead though; that may clean up all of this and also allow us to more definitively avoid breaking existing code.

// Standalone mode: use content-based hashes instead of K8s UIDs
if len(providerUID) == 0 {
errs = errors.Join(errs, fmt.Errorf("providerUID cannot be empty"))
}
if _, ok := seen[uid]; ok {
errs = errors.Join(errs, fmt.Errorf("%w %s", errorDuplicateUID, uid))

if errs != nil {
return "", errs
}

authSpecHash := helpers.HashString(fmt.Sprintf("%+v", authObj.Spec))
connSpecHash := helpers.HashString(fmt.Sprintf("%+v", connObj.Spec))

Comment on lines +179 to +181
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to validate the various Spec? Could the "repr" be empty or missing required values?

// Format: "authHash-1.connHash-1.providerUID"
// (generation is always 1 for standalone since we didn't fetch the auth and conn objects from a K8s resource)
input = fmt.Sprintf("%s-%d.%s-%d.%s",
authSpecHash, 1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is there a particular reason to interpolate the generation since it is always 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just trying to match the format of the original client cache key, but I suppose it's not necessary!

connSpecHash, 1, providerUID)
} else {
// Normal VSO operation: use K8s resource UIDs with strict validation
seen := make(map[types.UID]int)
requireUIDLen := 36
validateUID := func(name string, uid types.UID) {
if len(uid) != requireUIDLen {
errs = errors.Join(errs, fmt.Errorf("%w %d, must be %d", errorInvalidUIDLength, len(uid), requireUIDLen))
}
if _, ok := seen[uid]; ok {
errs = errors.Join(errs, fmt.Errorf("%w %s", errorDuplicateUID, uid))
}
seen[uid] = 1
}
seen[uid] = 1
}

validateUID("authUID", authObj.GetUID())
validateUID("connUID", connObj.GetUID())
validateUID("providerUID", providerUID)
validateUID("authUID", authObj.GetUID())
validateUID("connUID", connObj.GetUID())
validateUID("providerUID", providerUID)

if errs != nil {
return "", errs
}
if errs != nil {
return "", errs
}

input := fmt.Sprintf("%s-%d.%s-%d.%s",
authObj.GetUID(), authObj.GetGeneration(),
connObj.GetUID(), connObj.GetGeneration(), providerUID)
input = fmt.Sprintf("%s-%d.%s-%d.%s",
authObj.GetUID(), authObj.GetGeneration(),
connObj.GetUID(), connObj.GetGeneration(), providerUID)
}

key := strings.ToLower(method + "-" + helpers.HashString(input))
if len(key) > 63 {
Expand Down
203 changes: 200 additions & 3 deletions vault/cache_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,16 +192,213 @@ func Test_computeClientCacheKey(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := computeClientCacheKey(tt.authObj, tt.connObj, tt.providerUID)
if !tt.wantErr(t, err, fmt.Sprintf("computeClientCacheKey(%v, %v, %v)",
got, err := computeClientCacheKey(tt.authObj, tt.connObj, tt.providerUID, false)
if !tt.wantErr(t, err, fmt.Sprintf("computeClientCacheKey(%v, %v, %v, false)",
tt.authObj, tt.connObj, tt.providerUID)) {
return
}
assert.Equalf(t, tt.want, got, "computeClientCacheKey(%v, %v, %v)", tt.authObj, tt.connObj, tt.providerUID)
assert.Equalf(t, tt.want, got, "computeClientCacheKey(%v, %v, %v, false)", tt.authObj, tt.connObj, tt.providerUID)
})
}
}

func Test_computeClientCacheKey_standalone(t *testing.T) {
t.Parallel()
tests := []struct {
name string
authObj *secretsv1beta1.VaultAuth
connObj *secretsv1beta1.VaultConnection
providerUID types.UID
wantErr assert.ErrorAssertionFunc
wantPrefix string // expected method prefix in cache key
}{
{
name: "standalone-empty-provideruid-fails",
authObj: &secretsv1beta1.VaultAuth{
Spec: secretsv1beta1.VaultAuthSpec{
Method: consts.ProviderMethodAppRole,
},
},
connObj: &secretsv1beta1.VaultConnection{},
providerUID: "",
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.Error(t, err, i...)
},
},
{
name: "standalone-different-auth-specs-different-keys",
authObj: &secretsv1beta1.VaultAuth{
Spec: secretsv1beta1.VaultAuthSpec{
Method: consts.ProviderMethodAppRole,
AppRole: &secretsv1beta1.VaultAuthConfigAppRole{
RoleID: "role-1",
},
},
},
connObj: &secretsv1beta1.VaultConnection{
Spec: secretsv1beta1.VaultConnectionSpec{
Address: "http://vault:8200",
},
},
providerUID: providerUID,
wantErr: assert.NoError,
wantPrefix: "approle-",
},
{
name: "standalone-different-conn-specs-different-keys",
authObj: &secretsv1beta1.VaultAuth{
Spec: secretsv1beta1.VaultAuthSpec{
Method: consts.ProviderMethodJWT,
},
},
connObj: &secretsv1beta1.VaultConnection{
Spec: secretsv1beta1.VaultConnectionSpec{
Address: "http://vault:9200",
},
},
providerUID: providerUID,
wantErr: assert.NoError,
wantPrefix: "jwt-",
},
{
name: "standalone-empty-method-fails",
authObj: &secretsv1beta1.VaultAuth{
Spec: secretsv1beta1.VaultAuthSpec{
Method: "",
},
},
connObj: &secretsv1beta1.VaultConnection{
ObjectMeta: metav1.ObjectMeta{},
},
providerUID: providerUID,
wantErr: func(t assert.TestingT, err error, i ...interface{}) bool {
return assert.Error(t, err, i...)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := computeClientCacheKey(tt.authObj, tt.connObj, tt.providerUID, true)
if !tt.wantErr(t, err, fmt.Sprintf("computeClientCacheKey(%v, %v, %v, true)",
tt.authObj, tt.connObj, tt.providerUID)) {
return
}
if tt.wantPrefix != "" {
assert.True(t, strings.HasPrefix(got.String(), tt.wantPrefix),
"expected key to start with %q, got %q", tt.wantPrefix, got.String())
}
})
}
}

// Test that standalone mode produces deterministic, content-based cache keys
func Test_computeClientCacheKey_standalone_deterministic(t *testing.T) {
t.Parallel()

authObj := &secretsv1beta1.VaultAuth{
Spec: secretsv1beta1.VaultAuthSpec{
Method: consts.ProviderMethodAppRole,
AppRole: &secretsv1beta1.VaultAuthConfigAppRole{
RoleID: "test-role",
},
},
}
connObj := &secretsv1beta1.VaultConnection{
Spec: secretsv1beta1.VaultConnectionSpec{
Address: "http://vault:8200",
},
}

// Call twice with identical inputs
key1, err1 := computeClientCacheKey(authObj, connObj, providerUID, true)
require.NoError(t, err1)

key2, err2 := computeClientCacheKey(authObj, connObj, providerUID, true)
require.NoError(t, err2)

// Should produce identical cache keys
assert.Equal(t, key1, key2, "identical specs should produce identical cache keys")
}

// Test that standalone and normal mode produce different cache keys
func Test_computeClientCacheKey_standalone_vs_normal(t *testing.T) {
t.Parallel()

authObjWithUID := &secretsv1beta1.VaultAuth{
ObjectMeta: metav1.ObjectMeta{
UID: authUID,
Generation: 1,
},
Spec: secretsv1beta1.VaultAuthSpec{
Method: consts.ProviderMethodAppRole,
AppRole: &secretsv1beta1.VaultAuthConfigAppRole{
RoleID: "test-role",
},
},
}
connObjWithUID := &secretsv1beta1.VaultConnection{
ObjectMeta: metav1.ObjectMeta{
UID: connUID,
Generation: 2,
},
Spec: secretsv1beta1.VaultConnectionSpec{
Address: "http://vault:8200",
},
}

authObjNoUID := &secretsv1beta1.VaultAuth{
ObjectMeta: metav1.ObjectMeta{
UID: "",
Generation: 0,
},
Spec: secretsv1beta1.VaultAuthSpec{
Method: consts.ProviderMethodAppRole,
AppRole: &secretsv1beta1.VaultAuthConfigAppRole{
RoleID: "test-role",
},
},
}
connObjNoUID := &secretsv1beta1.VaultConnection{
ObjectMeta: metav1.ObjectMeta{
UID: "",
Generation: 0,
},
Spec: secretsv1beta1.VaultConnectionSpec{
Address: "http://vault:8200",
},
}

// Normal mode with UIDs - should succeed
normalKey, err := computeClientCacheKey(authObjWithUID, connObjWithUID, providerUID, false)
require.NoError(t, err)

// Standalone mode without UIDs - should succeed
standaloneKey, err := computeClientCacheKey(authObjNoUID, connObjNoUID, providerUID, true)
require.NoError(t, err)

// Keys should be different because:
// - Normal mode uses UID + generation
// - Standalone mode uses spec hash + generation=1
assert.NotEqual(t, normalKey, standaloneKey,
"standalone and normal mode should produce different cache keys even with same specs")

assert.True(t, strings.HasPrefix(normalKey.String(), "approle-"))
assert.True(t, strings.HasPrefix(standaloneKey.String(), "approle-"))

// Normal mode with empty UIDs - should fail (requires UIDs)
_, err = computeClientCacheKey(authObjNoUID, connObjNoUID, providerUID, false)
require.Error(t, err, "normal mode should fail with empty UIDs")

// Standalone mode with UIDs - should succeed (UIDs allowed but not required)
standaloneKeyWithUIDs, err := computeClientCacheKey(authObjWithUID, connObjWithUID, providerUID, true)
require.NoError(t, err, "standalone mode should succeed even with UIDs present")

// Standalone mode with and without UIDs should produce different keys
assert.NotEqual(t, standaloneKey, standaloneKeyWithUIDs,
"standalone mode should produce different keys with vs without UIDs due to different spec hashes")
}

func TestComputeClientCacheKeyFromClient(t *testing.T) {
t.Parallel()
tests := []computeClientCacheKeyTest{
Expand Down
11 changes: 10 additions & 1 deletion vault/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type ClientOptions struct {
GlobalVaultAuthOptions *common.GlobalVaultAuthOptions
CredentialProviderFactory credentials.CredentialProviderFactory
UserAgent string
IsStandalone bool
}

func defaultClientOptions() *ClientOptions {
Expand Down Expand Up @@ -199,6 +200,7 @@ var _ Client = (*defaultClient)(nil)
type defaultClient struct {
client *api.Client
isClone bool
isStandalone bool
authObj *secretsv1beta1.VaultAuth
connObj *secretsv1beta1.VaultConnection
authSecret *api.Secret
Expand Down Expand Up @@ -356,7 +358,13 @@ func (c *defaultClient) GetCacheKey() (ClientCacheKey, error) {
}

func (c *defaultClient) getCacheKey() (ClientCacheKey, error) {
cacheKey, err := ComputeClientCacheKeyFromClient(c)
var cacheKey ClientCacheKey
var err error
if c.isStandalone {
cacheKey, err = computeClientCacheKey(c.authObj, c.connObj, c.credentialProvider.GetUID(), true)
} else {
cacheKey, err = ComputeClientCacheKeyFromClient(c)
}
if err != nil {
return "", err
}
Expand Down Expand Up @@ -824,6 +832,7 @@ func (c *defaultClient) init(ctx context.Context, client ctrlclient.Client,
}

c.skipRenewal = opts.SkipRenewal
c.isStandalone = opts.IsStandalone
c.credentialProvider = credentialProvider
c.client = vc
c.authObj = authObj
Expand Down
Loading