Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .goreleaser.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ builds:
- arm64
env:
- CGO_ENABLED=0
overrides:
- goos: linux
goarch: amd64
env:
- CGO_ENABLED=1
id: "gotpm"
main: ./cmd/gotpm
binary: gotpm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ clean = true

[export]
item_types = ["functions", "structs", "constants", "enums"]
include = ["key_manager_generate_kem_keypair", "key_manager_destroy_kem_key", "key_manager_enumerate_kem_keys", "KpsKeyInfo", "MAX_ALGORITHM_LEN", "MAX_PUBLIC_KEY_LEN"]
include = ["key_manager_generate_kem_keypair", "key_manager_get_kem_key", "key_manager_destroy_kem_key", "key_manager_enumerate_kem_keys", "KpsKeyInfo", "MAX_ALGORITHM_LEN", "MAX_PUBLIC_KEY_LEN"]


Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ int32_t key_manager_decap_and_seal(const uint8_t *uuid_bytes,
uint8_t *out_ciphertext,
size_t out_ciphertext_len);

int32_t key_manager_get_kem_key(const uint8_t *uuid_bytes,
uint8_t *out_kem_pubkey,
size_t out_kem_pubkey_len,
uint8_t *out_binding_pubkey,
size_t out_binding_pubkey_len,
uint64_t *out_delete_after);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (

"github.com/google/uuid"

algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
keymanager "github.com/google/go-tpm-tools/keymanager/km_common/proto"
)

var defaultAlgo = &algorithms.HpkeAlgorithm{
Kem: algorithms.KemAlgorithm_KEM_ALGORITHM_DHKEM_X25519_HKDF_SHA256,
Kdf: algorithms.KdfAlgorithm_KDF_ALGORITHM_HKDF_SHA256,
Aead: algorithms.AeadAlgorithm_AEAD_ALGORITHM_AES_256_GCM,
var defaultAlgo = &keymanager.HpkeAlgorithm{
Kem: keymanager.KemAlgorithm_KEM_ALGORITHM_DHKEM_X25519_HKDF_SHA256,
Kdf: keymanager.KdfAlgorithm_KDF_ALGORITHM_HKDF_SHA256,
Aead: keymanager.AeadAlgorithm_AEAD_ALGORITHM_AES_256_GCM,
}

func TestIntegrationGenerateKEMKeypair(t *testing.T) {
Expand Down Expand Up @@ -61,3 +61,49 @@ func TestIntegrationGenerateKEMKeypairUniqueness(t *testing.T) {
t.Fatalf("expected unique UUIDs, got same: %s", id1)
}
}

func TestIntegrationGetKemKey(t *testing.T) {
bindingPK := make([]byte, 32)
for i := range bindingPK {
bindingPK[i] = byte(i + 50)
}

id, pubKey, err := GenerateKEMKeypair(defaultAlgo, bindingPK, 3600)
if err != nil {
t.Fatalf("GenerateKEMKeypair failed: %v", err)
}

retrievedKemPK, retrievedBindingPK, deleteAfter, err := GetKemKey(id)
if err != nil {
t.Fatalf("GetKemKey failed: %v", err)
}

if len(retrievedKemPK) != len(pubKey) {
t.Fatalf("expected KEM pubkey length %d, got %d", len(pubKey), len(retrievedKemPK))
}
for i := range pubKey {
if pubKey[i] != retrievedKemPK[i] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of iterating and comparing, we can use bytes.Equal

import "bytes"

// ...

if !bytes.Equal(pubKey, retrievedKemPK) {
	t.Fatalf("KEM pubkey mismatch: expected %x, got %x", pubKey, retrievedKemPK)
}

if !bytes.Equal(bindingPK, retrievedBindingPK) {
	t.Fatalf("binding pubkey mismatch: expected %x, got %x", bindingPK, retrievedBindingPK)
}

t.Fatalf("KEM pubkey mismatch at index %d", i)
}
}

if len(retrievedBindingPK) != len(bindingPK) {
t.Fatalf("expected binding pubkey length %d, got %d", len(bindingPK), len(retrievedBindingPK))
}
for i := range bindingPK {
if bindingPK[i] != retrievedBindingPK[i] {
t.Fatalf("binding pubkey mismatch at index %d", i)
}
}

if deleteAfter == 0 {
t.Fatal("expected non-zero deleteAfter timestamp")
}
}

func TestIntegrationGetKemKeyNotFound(t *testing.T) {
_, _, _, err := GetKemKey(uuid.New())
if err == nil {
t.Fatal("expected error for non-existent UUID")
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also check the expected error string:

expectedErrMsg := "key_manager_get_kem_key failed with code -1"
if !strings.Contains(err.Error(), expectedErrMsg) {
    t.Fatalf("expected error containing %q, got: %v", expectedErrMsg, err)
}

if you decide to export typed sentinel errors ErrKeyNotFound then you could also use errors.Is but comparing string is also fine.

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ import (
"github.com/google/uuid"
"google.golang.org/protobuf/proto"

algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
keymanager "github.com/google/go-tpm-tools/keymanager/km_common/proto"
)

// GenerateKEMKeypair generates an X25519 HPKE KEM keypair linked to the
// provided binding public key via Rust FFI.
// Returns the UUID key handle and the KEM public key bytes.
func GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
func GenerateKEMKeypair(algo *keymanager.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
if len(bindingPubKey) == 0 {
return uuid.Nil, nil, fmt.Errorf("binding public key must not be empty")
}
Expand Down Expand Up @@ -61,3 +61,34 @@ func GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, li
copy(pubkey, pubkeyBuf[:pubkeyLen])
return id, pubkey, nil
}

// GetKemKey retrieves KEM and binding public keys and delete_after timestamp via Rust FFI.
func GetKemKey(id uuid.UUID) ([]byte, []byte, uint64, error) {
uuidBytes, err := id.MarshalBinary()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we know the size of uuid, should we create a fixed size array?

var uuidBytes [16]byte
copy(uuidBytes[:], id[:])

if err != nil {
return nil, nil, 0, fmt.Errorf("failed to marshal UUID: %v", err)
}

var kemPubkeyBuf [32]byte
var bindingPubkeyBuf [32]byte
var deleteAfter C.uint64_t

rc := C.key_manager_get_kem_key(
(*C.uint8_t)(unsafe.Pointer(&uuidBytes[0])),
(*C.uint8_t)(unsafe.Pointer(&kemPubkeyBuf[0])),
C.size_t(len(kemPubkeyBuf)),
(*C.uint8_t)(unsafe.Pointer(&bindingPubkeyBuf[0])),
C.size_t(len(bindingPubkeyBuf)),
&deleteAfter,
)
if rc != 0 {
return nil, nil, 0, fmt.Errorf("key_manager_get_kem_key failed with code %d", rc)
}

kemPubkey := make([]byte, len(kemPubkeyBuf))
copy(kemPubkey, kemPubkeyBuf[:])
bindingPubkey := make([]byte, len(bindingPubkeyBuf))
copy(bindingPubkey, bindingPubkeyBuf[:])

return kemPubkey, bindingPubkey, uint64(deleteAfter), nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ import (
func GenerateKEMKeypair(_ *algorithms.HpkeAlgorithm, _ []byte, _ uint64) (uuid.UUID, []byte, error) {
return uuid.Nil, nil, fmt.Errorf("GenerateKEMKeypair is not supported on this architecture")
}

// GetKemKey is a stub for architectures where the Rust library is not supported.
func GetKemKey(id uuid.UUID) ([]byte, []byte, uint64, error) {
return nil, nil, 0, fmt.Errorf("GetKemKey is not supported on this architecture")
}
173 changes: 173 additions & 0 deletions keymanager/key_protection_service/key_custody_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,105 @@ pub unsafe extern "C" fn key_manager_decap_and_seal(
.unwrap_or(-1)
}

/// Internal function to retrieve a KEM key's public keys and expiration.
fn get_kem_key_internal(uuid: Uuid) -> Result<(PublicKey, PublicKey, u64), i32> {
let record = KEY_REGISTRY.get_key(&uuid).ok_or(-1)?;
match &record.meta.spec {
KeySpec::KemWithBindingPub {
kem_public_key,
binding_public_key,
..
} => {
let remaining = record
.meta
.delete_after
.saturating_duration_since(std::time::Instant::now());
let unix_expiry = std::time::SystemTime::now() + remaining;
let unix_secs = unix_expiry
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or(std::time::Duration::from_secs(0))
.as_secs();
Ok((
kem_public_key.clone(),
binding_public_key.clone(),
unix_secs,
))
}
_ => Err(-1),
}
}

/// Retrieves the KEM and binding public keys associated with the given UUID.
///
/// ## Arguments
/// * `uuid_bytes` - A pointer to a 16-byte buffer containing the key UUID.
/// * `out_kem_pubkey` - A pointer to a buffer where the KEM public key will be written.
/// * `out_kem_pubkey_len` - The size of `out_kem_pubkey` buffer.
/// * `out_binding_pubkey` - A pointer to a buffer where the binding public key will be written.
/// * `out_binding_pubkey_len` - The size of `out_binding_pubkey` buffer.
/// * `out_delete_after` - A pointer to a u64 where the UNIX expiration timestamp will be written.
///
/// ## Safety
/// This function is unsafe because it dereferences raw pointers.
///
/// ## Returns
/// * `0` on success.
/// * `-1` if arguments are invalid or key is not found.
/// * `-2` if either public key buffer is too small.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're checking != so it fails with -2 even if buffer is too large

#[unsafe(no_mangle)]
pub unsafe extern "C" fn key_manager_get_kem_key(
uuid_bytes: *const u8,
out_kem_pubkey: *mut u8,
out_kem_pubkey_len: usize,
out_binding_pubkey: *mut u8,
out_binding_pubkey_len: usize,
out_delete_after: *mut u64,
) -> i32 {
std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
if uuid_bytes.is_null()
|| out_kem_pubkey.is_null()
|| out_kem_pubkey_len == 0
|| out_binding_pubkey.is_null()
|| out_binding_pubkey_len == 0
|| out_delete_after.is_null()
{
return -1;
}

// Convert to Safe Types
let uuid_slice = unsafe { std::slice::from_raw_parts(uuid_bytes, 16) };
let out_kem_pubkey_slice =
unsafe { std::slice::from_raw_parts_mut(out_kem_pubkey, out_kem_pubkey_len) };
Comment on lines +470 to +471
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're checking the out_kem_pubkey_len value much later after we're already using it to create a slice.

from_raw_parts_mut asserts to the compiler that valid len bytes exists. And this could result in undefined behaviour.

check the length first and create the slice only after we know the length is correct.

let out_binding_pubkey_slice = unsafe {
std::slice::from_raw_parts_mut(out_binding_pubkey, out_binding_pubkey_len)
};
let out_delete_after_ref = unsafe { &mut *out_delete_after };

let uuid = match Uuid::from_slice(uuid_slice) {
Ok(u) => u,
Err(_) => return -1,
};

// Call Safe Internal Function
match get_kem_key_internal(uuid) {
Ok((kem_pubkey, binding_pubkey, delete_after)) => {
if out_kem_pubkey_len != kem_pubkey.as_bytes().len()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we've followed != size enforcement at many places which is fine for now since we own the cgo layer too.

But would < be a more generalized check? that way caller can pass larger buffer too if they don't know the exact size.

|| out_binding_pubkey_len != binding_pubkey.as_bytes().len()
{
return -2;
}

out_kem_pubkey_slice.copy_from_slice(kem_pubkey.as_bytes());
out_binding_pubkey_slice.copy_from_slice(binding_pubkey.as_bytes());
*out_delete_after_ref = delete_after;
0 // Success
}
Err(e) => e,
}
}))
.unwrap_or(-1)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -992,4 +1091,78 @@ mod tests {

assert_eq!(result, -2);
}

#[test]
fn test_get_kem_key_success() {
let binding_pubkey = [42u8; 32];
let mut uuid_bytes = [0u8; 16];
let mut generated_kem_pubkey_bytes = [0u8; 32];
let pubkey_len = generated_kem_pubkey_bytes.len();
let algo = HpkeAlgorithm {
kem: KemAlgorithm::DhkemX25519HkdfSha256 as i32,
kdf: KdfAlgorithm::HkdfSha256 as i32,
aead: AeadAlgorithm::Aes256Gcm as i32,
};
let algo_bytes = algo.encode_to_vec();

// Generate a key to retrieve.
let res = unsafe {
key_manager_generate_kem_keypair(
algo_bytes.as_ptr(),
algo_bytes.len(),
binding_pubkey.as_ptr(),
binding_pubkey.len(),
3600,
uuid_bytes.as_mut_ptr(),
generated_kem_pubkey_bytes.as_mut_ptr(),
pubkey_len,
)
};
assert_eq!(res, 0);

// Now, retrieve it.
let mut retrieved_kem_pubkey_bytes = [0u8; 32];
let mut retrieved_binding_pubkey_bytes = [0u8; 32];
let mut delete_after: u64 = 0;

let result = unsafe {
key_manager_get_kem_key(
uuid_bytes.as_ptr(),
retrieved_kem_pubkey_bytes.as_mut_ptr(),
retrieved_kem_pubkey_bytes.len(),
retrieved_binding_pubkey_bytes.as_mut_ptr(),
retrieved_binding_pubkey_bytes.len(),
&mut delete_after,
)
};

assert_eq!(result, 0);
assert_eq!(generated_kem_pubkey_bytes, retrieved_kem_pubkey_bytes);
assert_eq!(binding_pubkey, retrieved_binding_pubkey_bytes);
let now_unix = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
assert!(delete_after > now_unix);
}

#[test]
fn test_get_kem_key_not_found() {
let uuid_bytes = [42u8; 16]; // Some non-existent UUID.
let mut kem_pubkey_bytes = [0u8; 32];
let mut binding_pubkey_bytes = [0u8; 32];
let mut delete_after: u64 = 0;

let result = unsafe {
key_manager_get_kem_key(
uuid_bytes.as_ptr(),
kem_pubkey_bytes.as_mut_ptr(),
kem_pubkey_bytes.len(),
binding_pubkey_bytes.as_mut_ptr(),
binding_pubkey_bytes.len(),
&mut delete_after,
)
};
assert_eq!(result, -1);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a test test_get_kem_key_invalid_buffer_len that passes retrieved_kem_pubkey_bytes.len() + 1 or - 1 to ensure the C-API correctly rejects it with a -2 status code

}
31 changes: 19 additions & 12 deletions keymanager/key_protection_service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,33 @@ package keyprotectionservice
import (
"github.com/google/uuid"

algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto"
keymanager "github.com/google/go-tpm-tools/keymanager/km_common/proto"
)

// KEMKeyGenerator generates KEM keypairs linked to a binding public key.
type KEMKeyGenerator interface {
GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
// KeyProtectionService defines the interface for the underlying Key Custody Core operations.
type KeyProtectionService interface {
GenerateKEMKeypair(algo *keymanager.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
GetKemKey(id uuid.UUID) ([]byte, []byte, uint64, error)
}

// Service implements KEMKeyGenerator by delegating to the KPS KCC FFI.
// Service implements KEM keypair operations by delegating to a KeyProtectionService backend.
type Service struct {
generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)
kps KeyProtectionService
}

// NewService creates a new KPS KOL service with the given KCC function.
func NewService(generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)) *Service {
return &Service{generateKEMKeypairFn: generateKEMKeypairFn}
// NewService creates a new Service with the given KeyProtectionService backend.
func NewService(kps KeyProtectionService) *Service {
return &Service{kps: kps}
}

// GenerateKEMKeypair generates a KEM keypair linked to the provided binding
// public key by calling the KPS KCC FFI.
func (s *Service) GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
return s.generateKEMKeypairFn(algo, bindingPubKey, lifespanSecs)
// public key by calling the underlying KeyProtectionService backend.
func (s *Service) GenerateKEMKeypair(algo *keymanager.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) {
return s.kps.GenerateKEMKeypair(algo, bindingPubKey, lifespanSecs)
}

// GetKemKey retrieves KEM and binding public keys and delete_after timestamp
// by calling the underlying KeyProtectionService backend.
func (s *Service) GetKemKey(id uuid.UUID) ([]byte, []byte, uint64, error) {
return s.kps.GetKemKey(id)
}
Loading
Loading