Skip to content

Commit c07a4ba

Browse files
committed
mlkem: reduce allocations
1 parent 280f60c commit c07a4ba

File tree

2 files changed

+59
-69
lines changed

2 files changed

+59
-69
lines changed

cng/mlkem.go

Lines changed: 47 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ package cng
88

99
import (
1010
"errors"
11-
"runtime"
1211

1312
"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
1413
)
@@ -33,6 +32,11 @@ const (
3332
encapsulationKeySizeMLKEM1024 = 1568
3433
)
3534

35+
const (
36+
sizeOfPrivateSeedMLKEM1024 = 4 + 4 + 4 + 10 + seedSizeMLKEM // dwMagic (4) + cbParameterSet (4) + cbKey (4) + ParameterSet (8 "1024\0") + Key (64)
37+
sizeOfPublicKeyMLKEM1024 = 4 + 4 + 4 + 10 + encapsulationKeySizeMLKEM1024 // dwMagic (4) + cbParameterSet (4) + cbKey (4) + ParameterSet (8 "1024\0") + Key (1184)
38+
)
39+
3640
// putUint32LE puts a uint32 in little-endian byte order.
3741
func putUint32LE(b []byte, v uint32) {
3842
b[0] = byte(v)
@@ -95,41 +99,38 @@ func generateMLKEMKey(paramSet string, dst []byte) error {
9599
}
96100

97101
// Export the private key blob
102+
blob := make([]byte, sizeOfPrivateSeedMLKEM1024) // use the larger size to be safe and avoid an allocation
98103
var size uint32
99-
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), nil, &size, 0)
100-
if err != nil {
101-
return err
102-
}
103-
104-
blob := make([]byte, size)
105104
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), blob, &size, 0)
106105
if err != nil {
107106
return err
108107
}
109108

110109
// Extract raw key bytes into destination
111-
return extractMLKEMKeyBytes(blob, dst)
110+
return extractMLKEMKeyBytes(dst, blob[:size])
112111
}
113112

114113
// newMLKEMKeyBlob creates a key blob from raw key bytes.
115-
func newMLKEMKeyBlob(paramSet string, keyBytes []byte, magic bcrypt.KeyBlobMagicNumber) ([]byte, error) {
114+
func newMLKEMKeyBlob(dst []byte, paramSet string, keyBytes []byte, magic bcrypt.KeyBlobMagicNumber) error {
116115
paramSetUTF16 := utf16FromString(paramSet)
117116
paramSetByteLen := len(paramSetUTF16) * 2
118117

119-
blob := make([]byte, 12+paramSetByteLen+len(keyBytes))
120-
putUint32LE(blob[0:4], uint32(magic))
121-
putUint32LE(blob[4:8], uint32(paramSetByteLen)) // cbParameterSet
122-
putUint32LE(blob[8:12], uint32(len(keyBytes))) // cbKey
118+
if len(dst) < 12+paramSetByteLen+len(keyBytes) {
119+
return errors.New("mlkem: destination blob too small")
120+
}
121+
putUint32LE(dst[0:4], uint32(magic))
122+
putUint32LE(dst[4:8], uint32(paramSetByteLen)) // cbParameterSet
123+
putUint32LE(dst[8:12], uint32(len(keyBytes))) // cbKey
123124
for i, v := range paramSetUTF16 {
124-
putUint16LE(blob[12+i*2:], v)
125+
putUint16LE(dst[12+i*2:], v)
125126
}
126-
copy(blob[12+paramSetByteLen:], keyBytes)
127+
copy(dst[12+paramSetByteLen:], keyBytes)
127128

128-
return blob, nil
129+
return nil
129130
}
130131

131132
// extractMLKEMKeyBytes extracts the raw key bytes from a blob into the provided destination slice.
132-
func extractMLKEMKeyBytes(blob []byte, dst []byte) error {
133+
func extractMLKEMKeyBytes(dst, blob []byte) error {
133134
if len(blob) < 12 {
134135
return errors.New("mlkem: blob too small")
135136
}
@@ -157,8 +158,9 @@ func mlkemDecapsulate(paramSet string, seed []byte, ciphertext []byte, expectedC
157158
return nil, err
158159
}
159160

160-
// Construct blob from raw key bytes
161-
blob, err := newMLKEMKeyBlob(paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
161+
// Construct blob from seed
162+
blob := make([]byte, sizeOfPrivateSeedMLKEM1024) // use the larger size to be safe and avoid an allocation
163+
err = newMLKEMKeyBlob(blob, paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
162164
if err != nil {
163165
return nil, err
164166
}
@@ -181,59 +183,57 @@ func mlkemDecapsulate(paramSet string, seed []byte, ciphertext []byte, expectedC
181183
}
182184

183185
// mlkemEncapsulationKey is a shared helper for extracting the encapsulation key from a decapsulation key.
184-
func mlkemEncapsulationKey(paramSet string, seed []byte, dst []byte) error {
186+
func mlkemEncapsulationKey(paramSet string, seed []byte, dst []byte) {
185187
alg, err := loadMLKEM()
186188
if err != nil {
187-
return err
189+
panic(err)
188190
}
189191

190-
// Construct blob from raw key bytes
191-
blob, err := newMLKEMKeyBlob(paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
192+
// Construct blob from seed
193+
blob := make([]byte, sizeOfPrivateSeedMLKEM1024) // use the larger size to be safe and avoid an allocation
194+
err = newMLKEMKeyBlob(blob, paramSet, seed, bcrypt.MLKEM_PRIVATE_SEED_MAGIC)
192195
if err != nil {
193-
return err
196+
panic(err)
194197
}
195198

196199
var hKey bcrypt.KEY_HANDLE
197200
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PRIVATE_SEED_BLOB), &hKey, blob, 0)
198201
if err != nil {
199-
return err
202+
panic(err)
200203
}
201204
defer bcrypt.DestroyKey(hKey)
202205

203206
// Export the public key blob
207+
pubBlob := make([]byte, sizeOfPublicKeyMLKEM1024) // use the larger size to be safe and avoid an allocation
204208
var size uint32
205-
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PUBLIC_BLOB), nil, &size, 0)
206-
if err != nil {
207-
return err
208-
}
209-
210-
pubBlob := make([]byte, size)
211209
err = bcrypt.ExportKey(hKey, 0, utf16PtrFromString(bcrypt.MLKEM_PUBLIC_BLOB), pubBlob, &size, 0)
212210
if err != nil {
213-
return err
211+
panic(err)
214212
}
215-
216213
// Extract raw public key bytes from blob
217-
return extractMLKEMKeyBytes(pubBlob, dst)
214+
if err := extractMLKEMKeyBytes(dst, pubBlob[:size]); err != nil {
215+
panic(err)
216+
}
218217
}
219218

220219
// mlkemEncapsulate is a shared helper for encapsulating with ML-KEM keys.
221-
func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize int) ([]byte, []byte, error) {
220+
func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize int) ([]byte, []byte) {
222221
alg, err := loadMLKEM()
223222
if err != nil {
224-
return nil, nil, err
223+
panic(err)
225224
}
226225

227226
// Construct blob from raw key bytes
228-
blob, err := newMLKEMKeyBlob(paramSet, keyBytes, bcrypt.MLKEM_PUBLIC_MAGIC)
227+
blob := make([]byte, sizeOfPublicKeyMLKEM1024) // use the larger size to be safe and avoid an allocation
228+
err = newMLKEMKeyBlob(blob, paramSet, keyBytes, bcrypt.MLKEM_PUBLIC_MAGIC)
229229
if err != nil {
230-
return nil, nil, err
230+
panic(err)
231231
}
232232

233233
var hKey bcrypt.KEY_HANDLE
234234
err = bcrypt.ImportKeyPair(alg.handle, 0, utf16PtrFromString(bcrypt.MLKEM_PUBLIC_BLOB), &hKey, blob, 0)
235235
if err != nil {
236-
return nil, nil, err
236+
panic(err)
237237
}
238238
defer bcrypt.DestroyKey(hKey)
239239

@@ -244,10 +244,10 @@ func mlkemEncapsulate(paramSet string, keyBytes []byte, expectedCiphertextSize i
244244

245245
err = bcrypt.Encapsulate(hKey, sharedKey, &cbResult, ciphertext, &cbCiphertextResult, 0)
246246
if err != nil {
247-
return nil, nil, err
247+
panic(err)
248248
}
249249

250-
return sharedKey[:cbResult], ciphertext[:cbCiphertextResult], nil
250+
return sharedKey[:cbResult], ciphertext[:cbCiphertextResult]
251251
}
252252

253253
// DecapsulationKeyMLKEM768 is the secret key used to decapsulate a shared key
@@ -287,19 +287,14 @@ func (dk DecapsulationKeyMLKEM768) Bytes() []byte {
287287
//
288288
// The shared key must be kept secret.
289289
func (dk DecapsulationKeyMLKEM768) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
290-
sharedKey, err = mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ciphertext, ciphertextSizeMLKEM768)
291-
runtime.KeepAlive(dk)
292-
return
290+
return mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ciphertext, ciphertextSizeMLKEM768)
293291
}
294292

295293
// EncapsulationKey returns the public encapsulation key necessary to produce
296294
// ciphertexts.
297295
func (dk DecapsulationKeyMLKEM768) EncapsulationKey() EncapsulationKeyMLKEM768 {
298296
var ek EncapsulationKeyMLKEM768
299-
if err := mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ek[:]); err != nil {
300-
panic(err)
301-
}
302-
runtime.KeepAlive(dk)
297+
mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_768, dk[:], ek[:])
303298
return ek
304299
}
305300

@@ -329,13 +324,7 @@ func (ek EncapsulationKeyMLKEM768) Bytes() []byte {
329324
//
330325
// The shared key must be kept secret.
331326
func (ek EncapsulationKeyMLKEM768) Encapsulate() (sharedKey, ciphertext []byte) {
332-
var err error
333-
sharedKey, ciphertext, err = mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_768, ek[:], ciphertextSizeMLKEM768)
334-
if err != nil {
335-
panic(err)
336-
}
337-
runtime.KeepAlive(ek)
338-
return
327+
return mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_768, ek[:], ciphertextSizeMLKEM768)
339328
}
340329

341330
// DecapsulationKeyMLKEM1024 is the secret key used to decapsulate a shared key
@@ -375,19 +364,14 @@ func (dk DecapsulationKeyMLKEM1024) Bytes() []byte {
375364
//
376365
// The shared key must be kept secret.
377366
func (dk DecapsulationKeyMLKEM1024) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) {
378-
sharedKey, err = mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ciphertext, ciphertextSizeMLKEM1024)
379-
runtime.KeepAlive(dk)
380-
return
367+
return mlkemDecapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ciphertext, ciphertextSizeMLKEM1024)
381368
}
382369

383370
// EncapsulationKey returns the public encapsulation key necessary to produce
384371
// ciphertexts.
385372
func (dk DecapsulationKeyMLKEM1024) EncapsulationKey() EncapsulationKeyMLKEM1024 {
386373
var ek EncapsulationKeyMLKEM1024
387-
if err := mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ek[:]); err != nil {
388-
panic(err)
389-
}
390-
runtime.KeepAlive(dk)
374+
mlkemEncapsulationKey(bcrypt.MLKEM_PARAMETER_SET_1024, dk[:], ek[:])
391375
return ek
392376
}
393377

@@ -417,11 +401,5 @@ func (ek EncapsulationKeyMLKEM1024) Bytes() []byte {
417401
//
418402
// The shared key must be kept secret.
419403
func (ek EncapsulationKeyMLKEM1024) Encapsulate() (sharedKey, ciphertext []byte) {
420-
var err error
421-
sharedKey, ciphertext, err = mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, ek[:], ciphertextSizeMLKEM1024)
422-
if err != nil {
423-
panic(err)
424-
}
425-
runtime.KeepAlive(ek)
426-
return
404+
return mlkemEncapsulate(bcrypt.MLKEM_PARAMETER_SET_1024, ek[:], ciphertextSizeMLKEM1024)
427405
}

cng/mlkem_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func TestMLKEMRoundTrip(t *testing.T) {
3030
if !cng.SupportsMLKEM() {
3131
t.Skip("ML-KEM not supported on this platform")
3232
}
33+
t.Parallel()
3334
t.Run("768", func(t *testing.T) {
3435
testRoundTrip(t, cng.GenerateKeyMLKEM768, cng.NewEncapsulationKeyMLKEM768, cng.NewDecapsulationKeyMLKEM768)
3536
})
@@ -42,6 +43,7 @@ func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
4243
t *testing.T, generateKey func() (D, error),
4344
newEncapsulationKey func([]byte) (E, error),
4445
newDecapsulationKey func([]byte) (D, error)) {
46+
t.Parallel()
4547
dk, err := generateKey()
4648
if err != nil {
4749
t.Fatal(err)
@@ -103,6 +105,7 @@ func TestMLKEMBadLengths(t *testing.T) {
103105
if !cng.SupportsMLKEM() {
104106
t.Skip("ML-KEM not supported on this platform")
105107
}
108+
t.Parallel()
106109
t.Run("768", func(t *testing.T) {
107110
testBadLengths(t, cng.GenerateKeyMLKEM768, cng.NewEncapsulationKeyMLKEM768, cng.NewDecapsulationKeyMLKEM768)
108111
})
@@ -115,6 +118,7 @@ func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
115118
t *testing.T, generateKey func() (D, error),
116119
newEncapsulationKey func([]byte) (E, error),
117120
newDecapsulationKey func([]byte) (D, error)) {
121+
t.Parallel()
118122
dk, err := generateKey()
119123
dkBytes := dk.Bytes()
120124
if err != nil {
@@ -269,6 +273,7 @@ func BenchmarkMLKEMRoundTrip(b *testing.B) {
269273

270274
// Test that the constants match the ML-KEM specification (NIST FIPS 203).
271275
func TestMLKEMConstantSizes(t *testing.T) {
276+
t.Parallel()
272277
if cng.SharedKeySizeMLKEM != mlkem.SharedKeySize {
273278
t.Errorf("SharedKeySize mismatch: got %d, want %d", cng.SharedKeySizeMLKEM, mlkem.SharedKeySize)
274279
}
@@ -299,8 +304,10 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
299304
if !cng.SupportsMLKEM() {
300305
t.Skip("ML-KEM not supported on this platform")
301306
}
307+
t.Parallel()
302308

303309
t.Run("768_CNG_to_Stdlib", func(t *testing.T) {
310+
t.Parallel()
304311
// Generate key with CNG
305312
cngDK, err := cng.GenerateKeyMLKEM768()
306313
if err != nil {
@@ -330,6 +337,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
330337
})
331338

332339
t.Run("768_Stdlib_to_CNG", func(t *testing.T) {
340+
t.Parallel()
333341
// Generate key with stdlib
334342
stdlibDK, err := mlkem.GenerateKey768()
335343
if err != nil {
@@ -359,6 +367,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
359367
})
360368

361369
t.Run("768_Bidirectional", func(t *testing.T) {
370+
t.Parallel()
362371
// Generate keys with both implementations
363372
cngDK, err := cng.GenerateKeyMLKEM768()
364373
if err != nil {
@@ -414,6 +423,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
414423
})
415424

416425
t.Run("1024_CNG_to_Stdlib", func(t *testing.T) {
426+
t.Parallel()
417427
// Generate key with CNG
418428
cngDK, err := cng.GenerateKeyMLKEM1024()
419429
if err != nil {
@@ -443,6 +453,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
443453
})
444454

445455
t.Run("1024_Stdlib_to_CNG", func(t *testing.T) {
456+
t.Parallel()
446457
// Generate key with stdlib
447458
stdlibDK, err := mlkem.GenerateKey1024()
448459
if err != nil {
@@ -472,6 +483,7 @@ func TestMLKEMInteropWithStdlib(t *testing.T) {
472483
})
473484

474485
t.Run("1024_Bidirectional", func(t *testing.T) {
486+
t.Parallel()
475487
// Generate keys with both implementations
476488
cngDK, err := cng.GenerateKeyMLKEM1024()
477489
if err != nil {

0 commit comments

Comments
 (0)