Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 16 additions & 34 deletions cng/mlkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,11 @@ const (
// encapsulationKeySizeMLKEM768 is the size of an ML-KEM-768 encapsulation key (raw key material).
encapsulationKeySizeMLKEM768 = 1184

// decapsulationKeySizeMLKEM768 is the size of the ML-KEM-768 decapsulation key data (raw key material).
decapsulationKeySizeMLKEM768 = 2400

// ciphertextSizeMLKEM1024 is the size of a ciphertext produced by ML-KEM-1024.
ciphertextSizeMLKEM1024 = 1568

// encapsulationKeySizeMLKEM1024 is the size of an ML-KEM-1024 encapsulation key (raw key material).
encapsulationKeySizeMLKEM1024 = 1568

// decapsulationKeySizeMLKEM1024 is the size of the ML-KEM-1024 decapsulation key data (raw key material).
decapsulationKeySizeMLKEM1024 = 3168
)

// putUint32LE puts a uint32 in little-endian byte order.
Expand Down Expand Up @@ -102,13 +96,13 @@ func generateMLKEMKey(paramSet string, dst []byte) error {

// Export the private key blob
var size uint32
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_BLOB), nil, &size, 0)
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), nil, &size, 0)
if err != nil {
return err
}

blob := make([]byte, size)
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_BLOB), blob, &size, 0)
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), blob, &size, 0)
if err != nil {
return err
}
Expand Down Expand Up @@ -153,7 +147,7 @@ func extractMLKEMKeyBytes(blob []byte, dst []byte) error {
}

// mlkemDecapsulate is a shared helper for decapsulating with ML-KEM keys.
func mlkemDecapsulate(paramSet string, keyBytes []byte, ciphertext []byte, expectedCiphertextSize int) ([]byte, error) {
func mlkemDecapsulate(paramSet string, seed []byte, ciphertext []byte, expectedCiphertextSize int) ([]byte, error) {
if len(ciphertext) != expectedCiphertextSize {
return nil, errors.New("mlkem: invalid ciphertext size")
}
Expand All @@ -164,13 +158,13 @@ func mlkemDecapsulate(paramSet string, keyBytes []byte, ciphertext []byte, expec
}

// Construct blob from raw key bytes
blob, err := newMLKEMKeyBlob(paramSet, keyBytes, bcrypt.MLKEM_PRIVATE_MAGIC)
blob, err := newMLKEMKeyBlob(paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
if err != nil {
return nil, err
}

var hKey bcrypt.KEY_HANDLE
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_BLOB), &hKey, blob, 0)
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), &hKey, blob, 0)
if err != nil {
return nil, err
}
Expand All @@ -187,20 +181,20 @@ func mlkemDecapsulate(paramSet string, keyBytes []byte, ciphertext []byte, expec
}

// mlkemEncapsulationKey is a shared helper for extracting the encapsulation key from a decapsulation key.
func mlkemEncapsulationKey(paramSet string, keyBytes []byte, dst []byte) error {
func mlkemEncapsulationKey(paramSet string, seed []byte, dst []byte) error {
alg, err := loadMLKEM()
if err != nil {
return err
}

// Construct blob from raw key bytes
blob, err := newMLKEMKeyBlob(paramSet, keyBytes, bcrypt.MLKEM_PRIVATE_MAGIC)
blob, err := newMLKEMKeyBlob(paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
if err != nil {
return err
}

var hKey bcrypt.KEY_HANDLE
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_BLOB), &hKey, blob, 0)
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), &hKey, blob, 0)
if err != nil {
return err
}
Expand Down Expand Up @@ -258,9 +252,7 @@ func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize i

// DecapsulationKeyMLKEM768 is the secret key used to decapsulate a shared key
// from a ciphertext. It includes various precomputed values.
// Note: Unlike the stdlib crypto/mlkem implementation which uses a 64-byte seed,
// the Windows CNG implementation stores the full 2400-byte expanded key material.
type DecapsulationKeyMLKEM768 [decapsulationKeySizeMLKEM768]byte
type DecapsulationKeyMLKEM768 [seedSizeMLKEM]byte

// GenerateKeyMLKEM768 generates a new decapsulation key, drawing random bytes from
// the default crypto/rand source. The decapsulation key must be kept secret.
Expand All @@ -273,21 +265,17 @@ func GenerateKeyMLKEM768() (DecapsulationKeyMLKEM768, error) {
}

// NewDecapsulationKeyMLKEM768 constructs a decapsulation key from its serialized form.
// Note: Unlike the stdlib crypto/mlkem which expects a 64-byte seed, this function
// expects the full 2400-byte expanded key material as returned by Bytes().
func NewDecapsulationKeyMLKEM768(keyBytes []byte) (DecapsulationKeyMLKEM768, error) {
if len(keyBytes) != decapsulationKeySizeMLKEM768 {
func NewDecapsulationKeyMLKEM768(seed []byte) (DecapsulationKeyMLKEM768, error) {
if len(seed) != seedSizeMLKEM {
return DecapsulationKeyMLKEM768{}, errors.New("mlkem: invalid decapsulation key size")
}

var dk DecapsulationKeyMLKEM768
copy(dk[:], keyBytes)
copy(dk[:], seed)
return dk, nil
}

// Bytes returns the decapsulation key in its serialized form.
// Note: Unlike the stdlib crypto/mlkem which returns a 64-byte seed, this returns
// the full 2400-byte expanded key material.
//
// The decapsulation key must be kept secret.
func (dk DecapsulationKeyMLKEM768) Bytes() []byte {
Expand Down Expand Up @@ -352,9 +340,7 @@ func (ek EncapsulationKeyMLKEM768) Encapsulate() (sharedKey, ciphertext []byte)

// DecapsulationKeyMLKEM1024 is the secret key used to decapsulate a shared key
// from a ciphertext. It includes various precomputed values.
// Note: Unlike the stdlib crypto/mlkem implementation which uses a 64-byte seed,
// the Windows CNG implementation stores the full 3168-byte expanded key material.
type DecapsulationKeyMLKEM1024 [decapsulationKeySizeMLKEM1024]byte
type DecapsulationKeyMLKEM1024 [seedSizeMLKEM]byte

// GenerateKeyMLKEM1024 generates a new decapsulation key, drawing random bytes from
// the default crypto/rand source. The decapsulation key must be kept secret.
Expand All @@ -367,21 +353,17 @@ func GenerateKeyMLKEM1024() (DecapsulationKeyMLKEM1024, error) {
}

// NewDecapsulationKeyMLKEM1024 constructs a decapsulation key from its serialized form.
// Note: Unlike the stdlib crypto/mlkem which expects a 64-byte seed, this function
// expects the full 3168-byte expanded key material as returned by Bytes().
func NewDecapsulationKeyMLKEM1024(keyBytes []byte) (DecapsulationKeyMLKEM1024, error) {
if len(keyBytes) != decapsulationKeySizeMLKEM1024 {
func NewDecapsulationKeyMLKEM1024(seed []byte) (DecapsulationKeyMLKEM1024, error) {
if len(seed) != seedSizeMLKEM {
return DecapsulationKeyMLKEM1024{}, errors.New("mlkem: invalid decapsulation key size")
}

var dk DecapsulationKeyMLKEM1024
copy(dk[:], keyBytes)
copy(dk[:], seed)
return dk, nil
}

// Bytes returns the decapsulation key in its serialized form.
// Note: Unlike the stdlib crypto/mlkem which returns a 64-byte seed, this returns
// the full 3168-byte expanded key material.
//
// The decapsulation key must be kept secret.
func (dk DecapsulationKeyMLKEM1024) Bytes() []byte {
Expand Down
9 changes: 8 additions & 1 deletion cng/mlkem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ package cng_test
import (
"bytes"
"crypto/mlkem"
"math/rand"
"crypto/rand"
"testing"

"github.com/microsoft/go-crypto-winnative/cng"
Expand Down Expand Up @@ -369,6 +369,10 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
t.Fatal(err)
}

if len(cngDK.Bytes()) != len(stdlibDK.Bytes()) {
t.Fatalf("decapsulation key sizes don't match: CNG=%d, stdlib=%d", len(cngDK.Bytes()), len(stdlibDK.Bytes()))
}

// Test CNG encapsulation key -> stdlib
cngEKBytes := cngDK.EncapsulationKey().Bytes()
stdlibEK, err := mlkem.NewEncapsulationKey768(cngEKBytes)
Expand Down Expand Up @@ -477,6 +481,9 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(cngDK.Bytes()) != len(stdlibDK.Bytes()) {
t.Fatalf("decapsulation key sizes don't match: CNG=%d, stdlib=%d", len(cngDK.Bytes()), len(stdlibDK.Bytes()))
}

// Test CNG encapsulation key -> stdlib
cngEKBytes := cngDK.EncapsulationKey().Bytes()
Expand Down
21 changes: 11 additions & 10 deletions internal/bcrypt/bcrypt_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ const (
)

const (
RSAPUBLIC_KEY_BLOB = "RSAPUBLICBLOB"
RSAFULLPRIVATE_BLOB = "RSAFULLPRIVATEBLOB"
ECCPUBLIC_BLOB = "ECCPUBLICBLOB"
ECCPRIVATE_BLOB = "ECCPRIVATEBLOB"
DSA_PUBLIC_BLOB = "DSAPUBLICBLOB"
DSA_PRIVATE_BLOB = "DSAPRIVATEBLOB"
MLKEM_PUBLIC_BLOB = "MLKEMPUBLICBLOB"
MLKEM_PRIVATE_BLOB = "MLKEMPRIVATEBLOB"
RSAPUBLIC_KEY_BLOB = "RSAPUBLICBLOB"
RSAFULLPRIVATE_BLOB = "RSAFULLPRIVATEBLOB"
ECCPUBLIC_BLOB = "ECCPUBLICBLOB"
ECCPRIVATE_BLOB = "ECCPRIVATEBLOB"
DSA_PUBLIC_BLOB = "DSAPUBLICBLOB"
DSA_PRIVATE_BLOB = "DSAPRIVATEBLOB"
MLKEM_PUBLIC_BLOB = "MLKEMPUBLICBLOB"
MLKEM_PRIVATE_SEED_BLOB = "MLKEMPRIVATESEEDBLOB"
)

const (
Expand Down Expand Up @@ -212,8 +212,9 @@ const (
DSA_PUBLIC_MAGIC_V2 KeyBlobMagicNumber = 0x32425044
DSA_PRIVATE_MAGIC_V2 KeyBlobMagicNumber = 0x32565044

MLKEM_PUBLIC_MAGIC KeyBlobMagicNumber = 0x504b4c4d
MLKEM_PRIVATE_MAGIC KeyBlobMagicNumber = 0x524b4c4d
MLKEM_PUBLIC_MAGIC KeyBlobMagicNumber = 0x504B4C4D
MLKEM_PRIVATE_MAGIC KeyBlobMagicNumber = 0x524B4C4D
MLKEM_PRIVATE_SEED_MAGIC KeyBlobMagicNumber = 0x534B4C4D
)

type (
Expand Down
Loading