Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 47 additions & 69 deletions cng/mlkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package cng

import (
"errors"
"runtime"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)
Expand All @@ -33,6 +32,11 @@ const (
encapsulationKeySizeMLKEM1024 = 1568
)

const (
sizeOfPrivateSeedMLKEM1024 = 4 + 4 + 4 + 10 + seedSizeMLKEM // dwMagic (4) + cbParameterSet (4) + cbKey (4) + ParameterSet (8 "1024\0") + Key (64)
sizeOfPublicKeyMLKEM1024 = 4 + 4 + 4 + 10 + encapsulationKeySizeMLKEM1024 // dwMagic (4) + cbParameterSet (4) + cbKey (4) + ParameterSet (8 "1024\0") + Key (1184)
)

// putUint32LE puts a uint32 in little-endian byte order.
func putUint32LE(b []byte, v uint32) {
b[0] = byte(v)
Expand Down Expand Up @@ -95,41 +99,38 @@ func generateMLKEMKey(paramSet string, dst []byte) error {
}

// Export the private key blob
blob := make([]byte, sizeOfPrivateSeedMLKEM1024) // use the larger size to be safe and avoid an allocation
var size uint32
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), nil, &size, 0)
if err != nil {
return err
}

blob := make([]byte, size)
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), blob, &size, 0)
if err != nil {
return err
}

// Extract raw key bytes into destination
return extractMLKEMKeyBytes(blob, dst)
return extractMLKEMKeyBytes(dst, blob[:size])
}

// newMLKEMKeyBlob creates a key blob from raw key bytes.
func newMLKEMKeyBlob(paramSet string, keyBytes []byte, magic bcrypt.KeyBlobMagicNumber) ([]byte, error) {
func newMLKEMKeyBlob(dst []byte, paramSet string, keyBytes []byte, magic bcrypt.KeyBlobMagicNumber) error {
paramSetUTF16 := utf16FromString(paramSet)
paramSetByteLen := len(paramSetUTF16) * 2

blob := make([]byte, 12+paramSetByteLen+len(keyBytes))
putUint32LE(blob[0:4], uint32(magic))
putUint32LE(blob[4:8], uint32(paramSetByteLen)) // cbParameterSet
putUint32LE(blob[8:12], uint32(len(keyBytes))) // cbKey
if len(dst) < 12+paramSetByteLen+len(keyBytes) {
return errors.New("mlkem: destination blob too small")
}
putUint32LE(dst[0:4], uint32(magic))
putUint32LE(dst[4:8], uint32(paramSetByteLen)) // cbParameterSet
putUint32LE(dst[8:12], uint32(len(keyBytes))) // cbKey
for i, v := range paramSetUTF16 {
putUint16LE(blob[12+i*2:], v)
putUint16LE(dst[12+i*2:], v)
}
copy(blob[12+paramSetByteLen:], keyBytes)
copy(dst[12+paramSetByteLen:], keyBytes)

return blob, nil
return nil
}

// extractMLKEMKeyBytes extracts the raw key bytes from a blob into the provided destination slice.
func extractMLKEMKeyBytes(blob []byte, dst []byte) error {
func extractMLKEMKeyBytes(dst, blob []byte) error {
if len(blob) < 12 {
return errors.New("mlkem: blob too small")
}
Expand Down Expand Up @@ -157,8 +158,9 @@ func mlkemDecapsulate(paramSet string, seed []byte, ciphertext []byte, expectedC
return nil, err
}

// Construct blob from raw key bytes
blob, err := newMLKEMKeyBlob(paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
// Construct blob from seed
blob := make([]byte, sizeOfPrivateSeedMLKEM1024) // use the larger size to be safe and avoid an allocation
err = newMLKEMKeyBlob(blob, paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
if err != nil {
return nil, err
}
Expand All @@ -181,59 +183,57 @@ func mlkemDecapsulate(paramSet string, seed []byte, ciphertext []byte, expectedC
}

// mlkemEncapsulationKey is a shared helper for extracting the encapsulation key from a decapsulation key.
func mlkemEncapsulationKey(paramSet string, seed []byte, dst []byte) error {
func mlkemEncapsulationKey(paramSet string, seed []byte, dst []byte) {
alg, err := loadMLKEM()
if err != nil {
return err
panic(err)
}

// Construct blob from raw key bytes
blob, err := newMLKEMKeyBlob(paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
// Construct blob from seed
blob := make([]byte, sizeOfPrivateSeedMLKEM1024) // use the larger size to be safe and avoid an allocation
err = newMLKEMKeyBlob(blob, paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
if err != nil {
return err
panic(err)
}

var hKey bcrypt.KEY_HANDLE
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), &hKey, blob, 0)
if err != nil {
return err
panic(err)
}
defer bcrypt.DestroyKey(hKey)

// Export the public key blob
pubBlob := make([]byte, sizeOfPublicKeyMLKEM1024) // use the larger size to be safe and avoid an allocation
var size uint32
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PUBLIC_BLOB), nil, &size, 0)
if err != nil {
return err
}

pubBlob := make([]byte, size)
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PUBLIC_BLOB), pubBlob, &size, 0)
if err != nil {
return err
panic(err)
}

// Extract raw public key bytes from blob
return extractMLKEMKeyBytes(pubBlob, dst)
if err := extractMLKEMKeyBytes(dst, pubBlob[:size]); err != nil {
panic(err)
}
}

// mlkemEncapsulate is a shared helper for encapsulating with ML-KEM keys.
func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize int) ([]byte, []byte, error) {
func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize int) ([]byte, []byte) {
alg, err := loadMLKEM()
if err != nil {
return nil, nil, err
panic(err)
}

// Construct blob from raw key bytes
blob, err := newMLKEMKeyBlob(paramSet, keyBytes, bcrypt.MLKEM_PUBLIC_MAGIC)
blob := make([]byte, sizeOfPublicKeyMLKEM1024) // use the larger size to be safe and avoid an allocation
err = newMLKEMKeyBlob(blob, paramSet, keyBytes, bcrypt.MLKEM_PUBLIC_MAGIC)
if err != nil {
return nil, nil, err
panic(err)
}

var hKey bcrypt.KEY_HANDLE
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PUBLIC_BLOB), &hKey, blob, 0)
if err != nil {
return nil, nil, err
panic(err)
}
defer bcrypt.DestroyKey(hKey)

Expand All @@ -244,10 +244,10 @@ func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize i

err = bcrypt.Encapsulate(hKey, sharedKey, &cbResult, ciphertext, &cbCiphertextResult, 0)
if err != nil {
return nil, nil, err
panic(err)
}

return sharedKey[:cbResult], ciphertext[:cbCiphertextResult], nil
return sharedKey[:cbResult], ciphertext[:cbCiphertextResult]
}

// DecapsulationKeyMLKEM768 is the secret key used to decapsulate a shared key
Expand Down Expand Up @@ -287,19 +287,14 @@ func (dk DecapsulationKeyMLKEM768) Bytes() []byte {
//
// The shared key must be kept secret.
func (dk DecapsulationKeyMLKEM768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
sharedKey, err = mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ciphertext, ciphertextSizeMLKEM768)
runtime.KeepAlive(dk)
return
return mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ciphertext, ciphertextSizeMLKEM768)
}

// EncapsulationKey returns the public encapsulation key necessary to produce
// ciphertexts.
func (dk DecapsulationKeyMLKEM768) EncapsulationKey() EncapsulationKeyMLKEM768 {
var ek EncapsulationKeyMLKEM768
if err := mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ek[:]); err != nil {
panic(err)
}
runtime.KeepAlive(dk)
mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ek[:])
return ek
}

Expand Down Expand Up @@ -329,13 +324,7 @@ func (ek EncapsulationKeyMLKEM768) Bytes() []byte {
//
// The shared key must be kept secret.
func (ek EncapsulationKeyMLKEM768) Encapsulate() (sharedKey, ciphertext []byte) {
var err error
sharedKey, ciphertext, err = mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_768, ek[:], ciphertextSizeMLKEM768)
if err != nil {
panic(err)
}
runtime.KeepAlive(ek)
return
return mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_768, ek[:], ciphertextSizeMLKEM768)
}

// DecapsulationKeyMLKEM1024 is the secret key used to decapsulate a shared key
Expand Down Expand Up @@ -375,19 +364,14 @@ func (dk DecapsulationKeyMLKEM1024) Bytes() []byte {
//
// The shared key must be kept secret.
func (dk DecapsulationKeyMLKEM1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
sharedKey, err = mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ciphertext, ciphertextSizeMLKEM1024)
runtime.KeepAlive(dk)
return
return mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ciphertext, ciphertextSizeMLKEM1024)
}

// EncapsulationKey returns the public encapsulation key necessary to produce
// ciphertexts.
func (dk DecapsulationKeyMLKEM1024) EncapsulationKey() EncapsulationKeyMLKEM1024 {
var ek EncapsulationKeyMLKEM1024
if err := mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ek[:]); err != nil {
panic(err)
}
runtime.KeepAlive(dk)
mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ek[:])
return ek
}

Expand Down Expand Up @@ -417,11 +401,5 @@ func (ek EncapsulationKeyMLKEM1024) Bytes() []byte {
//
// The shared key must be kept secret.
func (ek EncapsulationKeyMLKEM1024) Encapsulate() (sharedKey, ciphertext []byte) {
var err error
sharedKey, ciphertext, err = mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, ek[:], ciphertextSizeMLKEM1024)
if err != nil {
panic(err)
}
runtime.KeepAlive(ek)
return
return mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, ek[:], ciphertextSizeMLKEM1024)
}
12 changes: 12 additions & 0 deletions cng/mlkem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestMLKEMRoundTrip(t *testing.T) {
if !cng.SupportsMLKEM() {
t.Skip("ML-KEM not supported on this platform")
}
t.Parallel()
t.Run("768", func(t *testing.T) {
testRoundTrip(t, cng.GenerateKeyMLKEM768, cng.NewEncapsulationKeyMLKEM768, cng.NewDecapsulationKeyMLKEM768)
})
Expand All @@ -42,6 +43,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
t *testing.T, generateKey func() (D, error),
newEncapsulationKey func([]byte) (E, error),
newDecapsulationKey func([]byte) (D, error)) {
t.Parallel()
dk, err := generateKey()
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -103,6 +105,7 @@ func TestMLKEMBadLengths(t *testing.T) {
if !cng.SupportsMLKEM() {
t.Skip("ML-KEM not supported on this platform")
}
t.Parallel()
t.Run("768", func(t *testing.T) {
testBadLengths(t, cng.GenerateKeyMLKEM768, cng.NewEncapsulationKeyMLKEM768, cng.NewDecapsulationKeyMLKEM768)
})
Expand All @@ -115,6 +118,7 @@ func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
t *testing.T, generateKey func() (D, error),
newEncapsulationKey func([]byte) (E, error),
newDecapsulationKey func([]byte) (D, error)) {
t.Parallel()
dk, err := generateKey()
dkBytes := dk.Bytes()
if err != nil {
Expand Down Expand Up @@ -269,6 +273,7 @@ func BenchmarkMLKEMRoundTrip(b *testing.B) {

// Test that the constants match the ML-KEM specification (NIST FIPS 203).
func TestMLKEMConstantSizes(t *testing.T) {
t.Parallel()
if cng.SharedKeySizeMLKEM != mlkem.SharedKeySize {
t.Errorf("SharedKeySize mismatch: got %d, want %d", cng.SharedKeySizeMLKEM, mlkem.SharedKeySize)
}
Expand Down Expand Up @@ -299,8 +304,10 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
if !cng.SupportsMLKEM() {
t.Skip("ML-KEM not supported on this platform")
}
t.Parallel()

t.Run("768_CNG_to_Stdlib", func(t *testing.T) {
t.Parallel()
// Generate key with CNG
cngDK, err := cng.GenerateKeyMLKEM768()
if err != nil {
Expand Down Expand Up @@ -330,6 +337,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
})

t.Run("768_Stdlib_to_CNG", func(t *testing.T) {
t.Parallel()
// Generate key with stdlib
stdlibDK, err := mlkem.GenerateKey768()
if err != nil {
Expand Down Expand Up @@ -359,6 +367,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
})

t.Run("768_Bidirectional", func(t *testing.T) {
t.Parallel()
// Generate keys with both implementations
cngDK, err := cng.GenerateKeyMLKEM768()
if err != nil {
Expand Down Expand Up @@ -414,6 +423,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
})

t.Run("1024_CNG_to_Stdlib", func(t *testing.T) {
t.Parallel()
// Generate key with CNG
cngDK, err := cng.GenerateKeyMLKEM1024()
if err != nil {
Expand Down Expand Up @@ -443,6 +453,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
})

t.Run("1024_Stdlib_to_CNG", func(t *testing.T) {
t.Parallel()
// Generate key with stdlib
stdlibDK, err := mlkem.GenerateKey1024()
if err != nil {
Expand Down Expand Up @@ -472,6 +483,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
})

t.Run("1024_Bidirectional", func(t *testing.T) {
t.Parallel()
// Generate keys with both implementations
cngDK, err := cng.GenerateKeyMLKEM1024()
if err != nil {
Expand Down
Loading