Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions keymanager/km_common/src/key_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,33 +213,6 @@ mod tests {
use super::*;
use crate::algorithms::{AeadAlgorithm, KdfAlgorithm, KemAlgorithm};

fn create_key_record<F>(
algo: HpkeAlgorithm,
expiry_secs: u64,
spec_builder: F,
) -> Result<KeyRecord, crypto::Error>
where
F: FnOnce(HpkeAlgorithm, PublicKey) -> KeySpec,
{
let (pub_key, priv_key) = crypto::generate_keypair(KemAlgorithm::DhkemX25519HkdfSha256)?;
let id = Uuid::new_v4();
let vault = Vault::new(secret_box::SecretBox::from(priv_key))
.map_err(|_| crypto::Error::CryptoError)?;
let now = Instant::now();
let delete_after = now
.checked_add(Duration::from_secs(expiry_secs))
.ok_or(crypto::Error::UnsupportedAlgorithm)?;
Ok(KeyRecord {
meta: KeyMetadata {
id,
created_at: now,
delete_after,
spec: spec_builder(algo, pub_key),
},
private_key: vault,
})
}

#[test]
fn test_create_binding_key_success() {
let algo = HpkeAlgorithm {
Expand Down
22 changes: 10 additions & 12 deletions keymanager/workload_service/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ func TestIntegrationGenerateKeysEndToEnd(t *testing.T) {
kpsSvc := kps.NewService(kpskcc.GenerateKEMKeypair)
srv := NewServer(kpsSvc, &realWorkloadService{})

reqBody, err := json.Marshal(GenerateKemRequest{
Algorithm: KemAlgorithmDHKEMX25519HKDFSHA256,
KeyProtectionMechanism: KeyProtectionMechanismVM,
Lifespan: ProtoDuration{Seconds: 3600},
reqBody, err := json.Marshal(GenerateKeyRequest{
Algorithm: AlgorithmDetails{Type: "kem", Params: AlgorithmParams{KemID: KemAlgorithmDHKEMX25519HKDFSHA256}},
Lifespan: ProtoDuration{Seconds: 3600},
})
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_kem", bytes.NewReader(reqBody))
req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_key", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
srv.Handler().ServeHTTP(w, req)
Expand All @@ -46,7 +45,7 @@ func TestIntegrationGenerateKeysEndToEnd(t *testing.T) {
t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
}

var resp GenerateKemResponse
var resp GenerateKeyResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
Expand Down Expand Up @@ -78,15 +77,14 @@ func TestIntegrationGenerateKeysUniqueMappings(t *testing.T) {
// Generate two key sets.
var kemUUIDs [2]uuid.UUID
for i := 0; i < 2; i++ {
reqBody, err := json.Marshal(GenerateKemRequest{
Algorithm: KemAlgorithmDHKEMX25519HKDFSHA256,
KeyProtectionMechanism: KeyProtectionMechanismVM,
Lifespan: ProtoDuration{Seconds: 3600},
reqBody, err := json.Marshal(GenerateKeyRequest{
Algorithm: AlgorithmDetails{Type: "kem", Params: AlgorithmParams{KemID: KemAlgorithmDHKEMX25519HKDFSHA256}},
Lifespan: ProtoDuration{Seconds: 3600},
})
if err != nil {
t.Fatalf("call %d: failed to marshal request: %v", i+1, err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_kem", bytes.NewReader(reqBody))
req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_key", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
srv.Handler().ServeHTTP(w, req)
Expand All @@ -95,7 +93,7 @@ func TestIntegrationGenerateKeysUniqueMappings(t *testing.T) {
t.Fatalf("call %d: expected status 200, got %d: %s", i+1, w.Code, w.Body.String())
}

var resp GenerateKemResponse
var resp GenerateKeyResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("call %d: failed to decode response: %v", i+1, err)
}
Expand Down
64 changes: 1 addition & 63 deletions keymanager/workload_service/proto_enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,64 +57,12 @@ func (k *KemAlgorithm) UnmarshalJSON(data []byte) error {
return fmt.Errorf("unknown KemAlgorithm: %q", s)
}

// KeyProtectionMechanism represents the requested key protection backend.
type KeyProtectionMechanism int32

const (
// KeyProtectionMechanismDefault is the default but invalid value.
KeyProtectionMechanismDefault KeyProtectionMechanism = 1
// KeyProtectionMechanismVM specifies that the key is protected by the VM.
KeyProtectionMechanismVM KeyProtectionMechanism = 2
)

var (
keyProtectionMechanismToString = map[KeyProtectionMechanism]string{
KeyProtectionMechanismDefault: "DEFAULT",
KeyProtectionMechanismVM: "KEY_PROTECTION_VM",
}
stringToKeyProtectionMechanism = map[string]KeyProtectionMechanism{
"DEFAULT": KeyProtectionMechanismDefault,
"KEY_PROTECTION_VM": KeyProtectionMechanismVM,
}
)

func (k KeyProtectionMechanism) String() string {
if s, ok := keyProtectionMechanismToString[k]; ok {
return s
}
return fmt.Sprintf("KEY_PROTECTION_MECHANISM_UNKNOWN(%d)", k)
}

// MarshalJSON converts a KeyProtectionMechanism enum value to its JSON string representation.
func (k KeyProtectionMechanism) MarshalJSON() ([]byte, error) {
return json.Marshal(k.String())
}

// UnmarshalJSON parses a JSON string into a KeyProtectionMechanism enum value.
func (k *KeyProtectionMechanism) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("KeyProtectionMechanism must be a string")
}
if v, ok := stringToKeyProtectionMechanism[s]; ok {
*k = v
return nil
}
return fmt.Errorf("unknown KeyProtectionMechanism: %q", s)
}

// Supported algorithms and mechanisms.
var (
// SupportedKemAlgorithms is the source of truth for supported algorithms.
SupportedKemAlgorithms = []KemAlgorithm{
KemAlgorithmDHKEMX25519HKDFSHA256,
}

// SupportedKeyProtectionMechanisms is the source of truth for supported mechanisms.
SupportedKeyProtectionMechanisms = []KeyProtectionMechanism{
KeyProtectionMechanismDefault,
KeyProtectionMechanismVM,
}
)

// IsSupported returns true if the KEM algorithm is supported.
Expand All @@ -127,17 +75,7 @@ func (k KemAlgorithm) IsSupported() bool {
return false
}

// IsSupported returns true if the key protection mechanism is supported.
func (k KeyProtectionMechanism) IsSupported() bool {
for _, supported := range SupportedKeyProtectionMechanisms {
if k == supported {
return true
}
}
return false
}

// SupportedKemAlgorithmsString returns a comma-separated list of supported KEM keymanager.
// SupportedKemAlgorithmsString returns a comma-separated list of supported KEM algorithms.
func SupportedKemAlgorithmsString() string {
var names []string
for _, k := range SupportedKemAlgorithms {
Expand Down
46 changes: 24 additions & 22 deletions keymanager/workload_service/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,14 @@ func (d ProtoDuration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.Seconds)
}

// GenerateKemRequest is the JSON body for POST /v1/keys:generate_kem.
type GenerateKemRequest struct {
Algorithm KemAlgorithm `json:"algorithm"`
KeyProtectionMechanism KeyProtectionMechanism `json:"key_protection_mechanism"`
Lifespan ProtoDuration `json:"lifespan"`
// GenerateKeyRequest is the JSON body for POST /v1/keys:generate_key.
type GenerateKeyRequest struct {
Algorithm AlgorithmDetails `json:"algorithm"`
Lifespan ProtoDuration `json:"lifespan"`
}

// GenerateKemResponse is returned by POST /v1/keys:generate_kem.
type GenerateKemResponse struct {
// GenerateKeyResponse is returned by POST /v1/keys:generate_key.
type GenerateKeyResponse struct {
KeyHandle KeyHandle `json:"key_handle"`
}

Expand Down Expand Up @@ -129,7 +128,7 @@ func NewServer(keyProtectionService KeyProtectionService, workloadService Worklo
}

mux := http.NewServeMux()
mux.HandleFunc("POST /v1/keys:generate_kem", s.handleGenerateKem)
mux.HandleFunc("POST /v1/keys:generate_key", s.handleGenerateKey)
mux.HandleFunc("GET /v1/capabilities", s.handleGetCapabilities)

s.httpServer = &http.Server{Handler: mux}
Expand Down Expand Up @@ -166,34 +165,37 @@ func (s *Server) LookupBindingUUID(kemUUID uuid.UUID) (uuid.UUID, bool) {
return id, ok
}

func (s *Server) handleGenerateKem(w http.ResponseWriter, r *http.Request) {
var req GenerateKemRequest
func (s *Server) handleGenerateKey(w http.ResponseWriter, r *http.Request) {
var req GenerateKeyRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, fmt.Sprintf("invalid request body: %v", err), http.StatusBadRequest)
return
}

// Validate algorithm.
if !req.Algorithm.IsSupported() {
writeError(w, fmt.Sprintf("unsupported algorithm: %s. Supported algorithms: %s", req.Algorithm, SupportedKemAlgorithmsString()), http.StatusBadRequest)
// Validate lifespan is positive.
if req.Lifespan.Seconds == 0 {
writeError(w, "lifespan must be greater than 0s", http.StatusBadRequest)
return
}

// Validate keyProtectionMechanism.
if !req.KeyProtectionMechanism.IsSupported() {
writeError(w, fmt.Sprintf("unsupported keyProtectionMechanism: %s", req.KeyProtectionMechanism), http.StatusBadRequest)
return
switch req.Algorithm.Type {
case "kem":
s.generateKEMKey(w, req)
default:
writeError(w, fmt.Sprintf("unsupported algorithm type: %q. Only 'kem' is supported.", req.Algorithm.Type), http.StatusBadRequest)
}
}

// Validate lifespan is positive.
if req.Lifespan.Seconds == 0 {
writeError(w, "lifespan must be greater than 0s", http.StatusBadRequest)
func (s *Server) generateKEMKey(w http.ResponseWriter, req GenerateKeyRequest) {
// Validate algorithm.
if !req.Algorithm.Params.KemID.IsSupported() {
writeError(w, fmt.Sprintf("unsupported algorithm: %s. Supported algorithms: %s", req.Algorithm.Params.KemID, SupportedKemAlgorithmsString()), http.StatusBadRequest)
return
}

// Construct the full HPKE algorithm suite based on the requested KEM.
// We currently only support one suite.
algo, err := req.Algorithm.ToHpkeAlgorithm()
algo, err := req.Algorithm.Params.KemID.ToHpkeAlgorithm()
if err != nil {
writeError(w, err.Error(), http.StatusBadRequest)
return
Expand All @@ -219,7 +221,7 @@ func (s *Server) handleGenerateKem(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()

// Step 4: Return KEM UUID to workload.
resp := GenerateKemResponse{
resp := GenerateKeyResponse{
KeyHandle: KeyHandle{Handle: kemUUID.String()},
}
writeJSON(w, resp, http.StatusOK)
Expand Down
Loading
Loading