Skip to content

Commit 02da1b8

Browse files
committed
feat(keymanager): implement enumerate KEM keys API in Go
1 parent f26280f commit 02da1b8

File tree

7 files changed

+472
-17
lines changed

7 files changed

+472
-17
lines changed

keymanager/key_protection_service/key_custody_core/kps_key_custody_core_cgo.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package kpskcc
77
#cgo LDFLAGS: -L${SRCDIR}/../../target/release -L${SRCDIR}/../../target/debug -lkps_key_custody_core
88
#cgo LDFLAGS: -lcrypto -lssl
99
#cgo LDFLAGS: -lpthread -ldl -lm -lstdc++
10+
#include <stdbool.h>
1011
#include "include/kps_key_custody_core.h"
1112
*/
1213
import "C"
@@ -61,3 +62,53 @@ func GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, li
6162
copy(pubkey, pubkeyBuf[:pubkeyLen])
6263
return id, pubkey, nil
6364
}
65+
66+
// EnumerateKEMKeys retrieves active KEM key entries from the Rust KCC registry with pagination.
67+
// Returns a list of keys and a boolean indicating if there are more keys to fetch.
68+
func EnumerateKEMKeys(limit, offset int) ([]KEMKeyInfo, bool, error) {
69+
if limit <= 0 {
70+
return nil, false, fmt.Errorf("limit must be positive")
71+
}
72+
if offset < 0 {
73+
return nil, false, fmt.Errorf("offset must be non-negative")
74+
}
75+
76+
entries := make([]C.KpsKeyInfo, limit)
77+
var hasMore C.bool
78+
79+
rc := C.key_manager_enumerate_kem_keys(
80+
&entries[0],
81+
C.size_t(limit),
82+
C.size_t(offset),
83+
&hasMore,
84+
)
85+
if rc < 0 {
86+
return nil, false, fmt.Errorf("key_manager_enumerate_kem_keys failed with code %d", rc)
87+
}
88+
89+
count := int(rc)
90+
result := make([]KEMKeyInfo, count)
91+
for i, e := range entries[:count] {
92+
id, err := uuid.FromBytes(C.GoBytes(unsafe.Pointer(&e.uuid[0]), 16))
93+
if err != nil {
94+
return nil, false, fmt.Errorf("invalid UUID at index %d: %w", i, err)
95+
}
96+
97+
kemPubKey := C.GoBytes(unsafe.Pointer(&e.pub_key[0]), C.int(e.pub_key_len))
98+
99+
algoBytes := C.GoBytes(unsafe.Pointer(&e.algorithm[0]), C.int(e.algorithm_len))
100+
algo := &algorithms.HpkeAlgorithm{}
101+
if err := proto.Unmarshal(algoBytes, algo); err != nil {
102+
return nil, false, fmt.Errorf("failed to unmarshal algorithm for key %d: %w", i, err)
103+
}
104+
105+
result[i] = KEMKeyInfo{
106+
ID: id,
107+
Algorithm: algo,
108+
KEMPubKey: kemPubKey,
109+
RemainingLifespanSecs: uint64(e.remaining_lifespan_secs),
110+
}
111+
}
112+
113+
return result, bool(hasMore), nil
114+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package kpskcc
2+
3+
import (
4+
algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
5+
"github.com/google/uuid"
6+
)
7+
8+
// KEMKeyInfo holds metadata for a single KEM key returned by EnumerateKEMKeys.
9+
type KEMKeyInfo struct {
10+
ID uuid.UUID
11+
Algorithm *algorithms.HpkeAlgorithm
12+
KEMPubKey []byte
13+
RemainingLifespanSecs uint64
14+
}
Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
// Package keyprotectionservice implements the Key Orchestration Layer (KOL)
22
// for the Key Protection Service. It wraps the KPS Key Custody Core (KCC) FFI
3-
// to provide a Go-native interface for KEM key generation.
3+
// to provide a Go-native interface for KEM key generation and enumeration.
44
package keyprotectionservice
55

66
import (
7+
kpskcc "github.com/google/go-tpm-tools/keymanager/key_protection_service/key_custody_core"
78
"github.com/google/uuid"
89

910
algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
@@ -14,18 +15,35 @@ type KEMKeyGenerator interface {
1415
GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
1516
}
1617

17-
// Service implements KEMKeyGenerator by delegating to the KPS KCC FFI.
18+
// KEMKeyEnumerator enumerates active KEM keys in the KPS registry.
19+
type KEMKeyEnumerator interface {
20+
EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error)
21+
}
22+
23+
// Service implements KEMKeyGenerator and KEMKeyEnumerator by delegating to the KPS KCC FFI.
1824
type Service struct {
1925
generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
26+
enumerateKEMKeysFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error)
2027
}
2128

22-
// NewService creates a new KPS KOL service with the given KCC function.
23-
func NewService(generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)) *Service {
24-
return &Service{generateKEMKeypairFn: generateKEMKeypairFn}
29+
// NewService creates a new KPS KOL service with the given KCC functions.
30+
func NewService(
31+
generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error),
32+
enumerateKEMKeysFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error),
33+
) *Service {
34+
return &Service{
35+
generateKEMKeypairFn: generateKEMKeypairFn,
36+
enumerateKEMKeysFn: enumerateKEMKeysFn,
37+
}
2538
}
2639

2740
// GenerateKEMKeypair generates a KEM keypair linked to the provided binding
2841
// public key by calling the KPS KCC FFI.
2942
func (s *Service) GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
3043
return s.generateKEMKeypairFn(algo, bindingPubKey, lifespanSecs)
3144
}
45+
46+
// EnumerateKEMKeys retrieves all active KEM key entries from the KPS KCC registry.
47+
func (s *Service) EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
48+
return s.enumerateKEMKeysFn(limit, offset)
49+
}

keymanager/key_protection_service/service_test.go

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"testing"
66

7+
kpskcc "github.com/google/go-tpm-tools/keymanager/key_protection_service/key_custody_core"
78
"github.com/google/uuid"
89

910
algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
@@ -16,15 +17,20 @@ func TestServiceGenerateKEMKeypairSuccess(t *testing.T) {
1617
expectedPubKey[i] = byte(i + 10)
1718
}
1819

19-
svc := NewService(func(_ *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
20-
if len(bindingPubKey) != 32 {
21-
t.Fatalf("expected 32-byte binding public key, got %d", len(bindingPubKey))
22-
}
23-
if lifespanSecs != 7200 {
24-
t.Fatalf("expected lifespanSecs 7200, got %d", lifespanSecs)
25-
}
26-
return expectedUUID, expectedPubKey, nil
27-
})
20+
svc := NewService(
21+
func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
22+
if len(bindingPubKey) != 32 {
23+
t.Fatalf("expected 32-byte binding public key, got %d", len(bindingPubKey))
24+
}
25+
if lifespanSecs != 7200 {
26+
t.Fatalf("expected lifespanSecs 7200, got %d", lifespanSecs)
27+
}
28+
return expectedUUID, expectedPubKey, nil
29+
},
30+
func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
31+
return nil, false, nil
32+
},
33+
)
2834

2935
id, pubKey, err := svc.GenerateKEMKeypair(&algorithms.HpkeAlgorithm{}, make([]byte, 32), 7200)
3036
if err != nil {
@@ -39,12 +45,71 @@ func TestServiceGenerateKEMKeypairSuccess(t *testing.T) {
3945
}
4046

4147
func TestServiceGenerateKEMKeypairError(t *testing.T) {
42-
svc := NewService(func(_ *algorithms.HpkeAlgorithm, _ []byte, _ uint64) (uuid.UUID, []byte, error) {
43-
return uuid.Nil, nil, fmt.Errorf("FFI error")
44-
})
48+
svc := NewService(
49+
func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
50+
return uuid.Nil, nil, fmt.Errorf("FFI error")
51+
},
52+
func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
53+
return nil, false, nil
54+
},
55+
)
4556

4657
_, _, err := svc.GenerateKEMKeypair(&algorithms.HpkeAlgorithm{}, make([]byte, 32), 3600)
4758
if err == nil {
4859
t.Fatal("expected error, got nil")
4960
}
5061
}
62+
63+
func TestServiceEnumerateKEMKeysSuccess(t *testing.T) {
64+
expectedKeys := []kpskcc.KEMKeyInfo{
65+
{
66+
ID: uuid.New(),
67+
Algorithm: &algorithms.HpkeAlgorithm{
68+
Kem: algorithms.KemAlgorithm_KEM_ALGORITHM_DHKEM_X25519_HKDF_SHA256,
69+
Kdf: algorithms.KdfAlgorithm_KDF_ALGORITHM_HKDF_SHA256,
70+
Aead: algorithms.AeadAlgorithm_AEAD_ALGORITHM_AES_256_GCM,
71+
},
72+
KEMPubKey: make([]byte, 32),
73+
RemainingLifespanSecs: 3500,
74+
},
75+
}
76+
77+
svc := NewService(
78+
func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
79+
return uuid.Nil, nil, nil
80+
},
81+
func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
82+
if limit != 100 || offset != 0 {
83+
return nil, false, fmt.Errorf("unexpected limit/offset: %d/%d", limit, offset)
84+
}
85+
return expectedKeys, false, nil
86+
},
87+
)
88+
89+
keys, _, err := svc.EnumerateKEMKeys(100, 0)
90+
if err != nil {
91+
t.Fatalf("unexpected error: %v", err)
92+
}
93+
if len(keys) != 1 {
94+
t.Fatalf("expected 1 key, got %d", len(keys))
95+
}
96+
if keys[0].ID != expectedKeys[0].ID {
97+
t.Fatalf("expected ID %s, got %s", expectedKeys[0].ID, keys[0].ID)
98+
}
99+
}
100+
101+
func TestServiceEnumerateKEMKeysError(t *testing.T) {
102+
svc := NewService(
103+
func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
104+
return uuid.Nil, nil, nil
105+
},
106+
func(limit, offset int) ([]kpskcc.KEMKeyInfo, bool, error) {
107+
return nil, false, fmt.Errorf("enumerate error")
108+
},
109+
)
110+
111+
_, _, err := svc.EnumerateKEMKeys(100, 0)
112+
if err == nil {
113+
t.Fatal("expected error, got nil")
114+
}
115+
}

keymanager/workload_service/proto_enums.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,88 @@ func (k KemAlgorithm) ToHpkeAlgorithm() (*algorithms.HpkeAlgorithm, error) {
9797
return nil, fmt.Errorf("unsupported algorithm: %s", k)
9898
}
9999
}
100+
101+
// KdfAlgorithm represents the requested KDF algorithm.
102+
type KdfAlgorithm int32
103+
104+
const (
105+
KdfAlgorithmUnspecified KdfAlgorithm = 0
106+
// Corrected from HKDF_SHA384 to HKDF_SHA256 based on ToHpkeAlgorithm usage which maps to HKDF_SHA256 (val 1)
107+
KdfAlgorithmHKDFSHA256 KdfAlgorithm = 1
108+
)
109+
110+
var (
111+
kdfAlgorithmToString = map[KdfAlgorithm]string{
112+
KdfAlgorithmUnspecified: "KDF_ALGORITHM_UNSPECIFIED",
113+
KdfAlgorithmHKDFSHA256: "HKDF_SHA256",
114+
}
115+
stringToKdfAlgorithm = map[string]KdfAlgorithm{
116+
"KDF_ALGORITHM_UNSPECIFIED": KdfAlgorithmUnspecified,
117+
"HKDF_SHA256": KdfAlgorithmHKDFSHA256,
118+
}
119+
)
120+
121+
func (k KdfAlgorithm) String() string {
122+
if s, ok := kdfAlgorithmToString[k]; ok {
123+
return s
124+
}
125+
return fmt.Sprintf("KDF_ALGORITHM_UNKNOWN(%d)", k)
126+
}
127+
128+
func (k KdfAlgorithm) MarshalJSON() ([]byte, error) {
129+
return json.Marshal(k.String())
130+
}
131+
132+
func (k *KdfAlgorithm) UnmarshalJSON(data []byte) error {
133+
var s string
134+
if err := json.Unmarshal(data, &s); err != nil {
135+
return fmt.Errorf("KdfAlgorithm must be a string")
136+
}
137+
if v, ok := stringToKdfAlgorithm[s]; ok {
138+
*k = v
139+
return nil
140+
}
141+
return fmt.Errorf("unknown KdfAlgorithm: %q", s)
142+
}
143+
144+
// AeadAlgorithm represents the requested AEAD algorithm.
145+
type AeadAlgorithm int32
146+
147+
const (
148+
AeadAlgorithmUnspecified AeadAlgorithm = 0
149+
AeadAlgorithmAES256GCM AeadAlgorithm = 1
150+
)
151+
152+
var (
153+
aeadAlgorithmToString = map[AeadAlgorithm]string{
154+
AeadAlgorithmUnspecified: "AEAD_ALGORITHM_UNSPECIFIED",
155+
AeadAlgorithmAES256GCM: "AES_256_GCM",
156+
}
157+
stringToAeadAlgorithm = map[string]AeadAlgorithm{
158+
"AEAD_ALGORITHM_UNSPECIFIED": AeadAlgorithmUnspecified,
159+
"AES_256_GCM": AeadAlgorithmAES256GCM,
160+
}
161+
)
162+
163+
func (k AeadAlgorithm) String() string {
164+
if s, ok := aeadAlgorithmToString[k]; ok {
165+
return s
166+
}
167+
return fmt.Sprintf("AEAD_ALGORITHM_UNKNOWN(%d)", k)
168+
}
169+
170+
func (k AeadAlgorithm) MarshalJSON() ([]byte, error) {
171+
return json.Marshal(k.String())
172+
}
173+
174+
func (k *AeadAlgorithm) UnmarshalJSON(data []byte) error {
175+
var s string
176+
if err := json.Unmarshal(data, &s); err != nil {
177+
return fmt.Errorf("AeadAlgorithm must be a string")
178+
}
179+
if v, ok := stringToAeadAlgorithm[s]; ok {
180+
*k = v
181+
return nil
182+
}
183+
return fmt.Errorf("unknown AeadAlgorithm: %q", s)
184+
}

0 commit comments

Comments
 (0)