Skip to content

Commit 86fa06d

Browse files
committed
refactor: interface based DI for mocking
1 parent 356e0b6 commit 86fa06d

File tree

2 files changed

+43
-35
lines changed

2 files changed

+43
-35
lines changed

keymanager/key_protection_service/service.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,31 @@ type KEMKeyEnumerator interface {
2020
EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error)
2121
}
2222

23+
// KeyCustodyCore defines the required FFI interactions for KPS.
24+
type KeyCustodyCore interface {
25+
GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
26+
EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error)
27+
}
28+
2329
// Service implements KEMKeyGenerator and KEMKeyEnumerator by delegating to the KPS KCC FFI.
2430
type Service struct {
25-
generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
26-
enumerateKEMKeysFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error)
31+
kcc KeyCustodyCore
2732
}
2833

29-
// NewService creates a new KPS KOL service with the given KCC functions.
30-
func NewService(
31-
generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error),
32-
enumerateKEMKeysFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error),
33-
) *Service {
34+
// NewService creates a new KPS KOL service with the given KCC implementation.
35+
func NewService(kcc KeyCustodyCore) *Service {
3436
return &Service{
35-
generateKEMKeypairFn: generateKEMKeypairFn,
36-
enumerateKEMKeysFn: enumerateKEMKeysFn,
37+
kcc: kcc,
3738
}
3839
}
3940

4041
// GenerateKEMKeypair generates a KEM keypair linked to the provided binding
4142
// public key by calling the KPS KCC FFI.
4243
func (s *Service) GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
43-
return s.generateKEMKeypairFn(algo, bindingPubKey, lifespanSecs)
44+
return s.kcc.GenerateKEMKeypair(algo, bindingPubKey, lifespanSecs)
4445
}
4546

4647
// EnumerateKEMKeys retrieves all active KEM key entries from the KPS KCC registry.
4748
func (s *Service) EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
48-
return s.enumerateKEMKeysFn(limit, offset)
49+
return s.kcc.EnumerateKEMKeys(limit, offset)
4950
}

keymanager/key_protection_service/service_test.go

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,34 @@ import (
1010
algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
1111
)
1212

13+
type mockKCC struct {
14+
generateFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
15+
enumerateFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error)
16+
}
17+
18+
func (m *mockKCC) GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
19+
if m.generateFn != nil {
20+
return m.generateFn(algo, bindingPubKey, lifespanSecs)
21+
}
22+
return uuid.Nil, nil, nil
23+
}
24+
25+
func (m *mockKCC) EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
26+
if m.enumerateFn != nil {
27+
return m.enumerateFn(limit, offset)
28+
}
29+
return nil, false, nil
30+
}
31+
1332
func TestServiceGenerateKEMKeypairSuccess(t *testing.T) {
1433
expectedUUID := uuid.New()
1534
expectedPubKey := make([]byte, 32)
1635
for i := range expectedPubKey {
1736
expectedPubKey[i] = byte(i + 10)
1837
}
1938

20-
svc := NewService(
21-
func(_ *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
39+
svc := NewService(&mockKCC{
40+
generateFn: func(_ *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
2241
if len(bindingPubKey) != 32 {
2342
t.Fatalf("expected 32-byte binding public key, got %d", len(bindingPubKey))
2443
}
@@ -27,10 +46,7 @@ func TestServiceGenerateKEMKeypairSuccess(t *testing.T) {
2746
}
2847
return expectedUUID, expectedPubKey, nil
2948
},
30-
func(_, _ int) ([]kpskcc.KEMKeyInfo, bool, error) {
31-
return nil, false, nil
32-
},
33-
)
49+
})
3450

3551
id, pubKey, err := svc.GenerateKEMKeypair(&algorithms.HpkeAlgorithm{}, make([]byte, 32), 7200)
3652
if err != nil {
@@ -45,14 +61,11 @@ func TestServiceGenerateKEMKeypairSuccess(t *testing.T) {
4561
}
4662

4763
func TestServiceGenerateKEMKeypairError(t *testing.T) {
48-
svc := NewService(
49-
func(_ *algorithms.HpkeAlgorithm, _ []byte, _ uint64) (uuid.UUID, []byte, error) {
64+
svc := NewService(&mockKCC{
65+
generateFn: func(_ *algorithms.HpkeAlgorithm, _ []byte, _ uint64) (uuid.UUID, []byte, error) {
5066
return uuid.Nil, nil, fmt.Errorf("FFI error")
5167
},
52-
func(_, _ int) ([]kpskcc.KEMKeyInfo, bool, error) {
53-
return nil, false, nil
54-
},
55-
)
68+
})
5669

5770
_, _, err := svc.GenerateKEMKeypair(&algorithms.HpkeAlgorithm{}, make([]byte, 32), 3600)
5871
if err == nil {
@@ -74,17 +87,14 @@ func TestServiceEnumerateKEMKeysSuccess(t *testing.T) {
7487
},
7588
}
7689

77-
svc := NewService(
78-
func(_ *algorithms.HpkeAlgorithm, _ []byte, _ uint64) (uuid.UUID, []byte, error) {
79-
return uuid.Nil, nil, nil
80-
},
81-
func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
90+
svc := NewService(&mockKCC{
91+
enumerateFn: func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
8292
if limit != 100 || offset != 0 {
8393
return nil, false, fmt.Errorf("unexpected limit/offset: %d/%d", limit, offset)
8494
}
8595
return expectedKeys, false, nil
8696
},
87-
)
97+
})
8898

8999
keys, _, err := svc.EnumerateKEMKeys(100, 0)
90100
if err != nil {
@@ -99,14 +109,11 @@ func TestServiceEnumerateKEMKeysSuccess(t *testing.T) {
99109
}
100110

101111
func TestServiceEnumerateKEMKeysError(t *testing.T) {
102-
svc := NewService(
103-
func(_ *algorithms.HpkeAlgorithm, _ []byte, _ uint64) (uuid.UUID, []byte, error) {
104-
return uuid.Nil, nil, nil
105-
},
106-
func(_, _ int) ([]kpskcc.KEMKeyInfo, bool, error) {
112+
svc := NewService(&mockKCC{
113+
enumerateFn: func(_, _ int) ([]kpskcc.KEMKeyInfo, bool, error) {
107114
return nil, false, fmt.Errorf("enumerate error")
108115
},
109-
)
116+
})
110117

111118
_, _, err := svc.EnumerateKEMKeys(100, 0)
112119
if err == nil {

0 commit comments

Comments
 (0)