Skip to content

Commit cce203f

Browse files
hanno-beckermkannwischer
authored andcommitted
Bundle multiple allocations into workspaces
This commit introduces 'workspace structures', which are are function-local structures wrapping all objects and buffers utilized by the function into a single struct. The benefit of the workspace struct is a simpler allocation boilerplate: Instead of allocating multiple structures and separately handling their success/failure, we now only need to allocate a single object per function. Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
1 parent b288906 commit cce203f

File tree

2 files changed

+175
-163
lines changed

2 files changed

+175
-163
lines changed

mlkem/src/indcpa.c

Lines changed: 113 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -379,32 +379,36 @@ int mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
379379
uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES],
380380
const uint8_t coins[MLKEM_SYMBYTES])
381381
{
382+
typedef struct
383+
{
384+
MLK_ALIGN uint8_t buf[2 * MLKEM_SYMBYTES];
385+
MLK_ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1];
386+
mlk_polymat a;
387+
mlk_polyvec e;
388+
mlk_polyvec pkpv;
389+
mlk_polyvec skpv;
390+
mlk_polyvec_mulcache skpv_cache;
391+
} workspace;
392+
382393
int ret = 0;
383394
const uint8_t *publicseed;
384395
const uint8_t *noiseseed;
385-
MLK_ALLOC(buf, uint8_t, 2 * MLKEM_SYMBYTES);
386-
MLK_ALLOC(coins_with_domain_separator, uint8_t, MLKEM_SYMBYTES + 1);
387-
MLK_ALLOC(a, mlk_polymat, 1);
388-
MLK_ALLOC(e, mlk_polyvec, 1);
389-
MLK_ALLOC(pkpv, mlk_polyvec, 1);
390-
MLK_ALLOC(skpv, mlk_polyvec, 1);
391-
MLK_ALLOC(skpv_cache, mlk_polyvec_mulcache, 1);
392-
393-
if (buf == NULL || coins_with_domain_separator == NULL || a == NULL ||
394-
e == NULL || pkpv == NULL || skpv == NULL || skpv_cache == NULL)
396+
MLK_ALLOC(ws, workspace, 1);
397+
398+
if (ws == NULL)
395399
{
396400
ret = MLK_ERR_OUT_OF_MEMORY;
397401
goto cleanup;
398402
}
399403

400-
publicseed = buf;
401-
noiseseed = buf + MLKEM_SYMBYTES;
404+
publicseed = ws->buf;
405+
noiseseed = ws->buf + MLKEM_SYMBYTES;
402406

403407
/* Concatenate coins with MLKEM_K for domain separation of security levels */
404-
mlk_memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES);
405-
coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K;
408+
mlk_memcpy(ws->coins_with_domain_separator, coins, MLKEM_SYMBYTES);
409+
ws->coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K;
406410

407-
mlk_hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1);
411+
mlk_hash_g(ws->buf, ws->coins_with_domain_separator, MLKEM_SYMBYTES + 1);
408412

409413
/*
410414
* Declassify the public seed.
@@ -414,54 +418,49 @@ int mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
414418
*/
415419
MLK_CT_TESTING_DECLASSIFY(publicseed, MLKEM_SYMBYTES);
416420

417-
mlk_gen_matrix(a, publicseed, 0 /* no transpose */);
421+
mlk_gen_matrix(&ws->a, publicseed, 0 /* no transpose */);
418422

419423
#if MLKEM_K == 2
420-
mlk_poly_getnoise_eta1_4x(&skpv->vec[0], &skpv->vec[1], &e->vec[0],
421-
&e->vec[1], noiseseed, 0, 1, 2, 3);
424+
mlk_poly_getnoise_eta1_4x(&ws->skpv.vec[0], &ws->skpv.vec[1], &ws->e.vec[0],
425+
&ws->e.vec[1], noiseseed, 0, 1, 2, 3);
422426
#elif MLKEM_K == 3
423427
/*
424428
* Only the first three output buffers are needed.
425429
* The laster parameter is a dummy that's overwritten later.
426430
*/
427-
mlk_poly_getnoise_eta1_4x(&skpv->vec[0], &skpv->vec[1], &skpv->vec[2],
428-
&pkpv->vec[0] /* irrelevant */, noiseseed, 0, 1, 2,
429-
0xFF /* irrelevant */);
431+
mlk_poly_getnoise_eta1_4x(&ws->skpv.vec[0], &ws->skpv.vec[1],
432+
&ws->skpv.vec[2], &ws->pkpv.vec[0] /* irrelevant */,
433+
noiseseed, 0, 1, 2, 0xFF /* irrelevant */);
430434
/* Same here */
431-
mlk_poly_getnoise_eta1_4x(&e->vec[0], &e->vec[1], &e->vec[2],
432-
&pkpv->vec[0] /* irrelevant */, noiseseed, 3, 4, 5,
433-
0xFF /* irrelevant */);
435+
mlk_poly_getnoise_eta1_4x(&ws->e.vec[0], &ws->e.vec[1], &ws->e.vec[2],
436+
&ws->pkpv.vec[0] /* irrelevant */, noiseseed, 3, 4,
437+
5, 0xFF /* irrelevant */);
434438
#elif MLKEM_K == 4
435-
mlk_poly_getnoise_eta1_4x(&skpv->vec[0], &skpv->vec[1], &skpv->vec[2],
436-
&skpv->vec[3], noiseseed, 0, 1, 2, 3);
437-
mlk_poly_getnoise_eta1_4x(&e->vec[0], &e->vec[1], &e->vec[2], &e->vec[3],
438-
noiseseed, 4, 5, 6, 7);
439+
mlk_poly_getnoise_eta1_4x(&ws->skpv.vec[0], &ws->skpv.vec[1],
440+
&ws->skpv.vec[2], &ws->skpv.vec[3], noiseseed, 0, 1,
441+
2, 3);
442+
mlk_poly_getnoise_eta1_4x(&ws->e.vec[0], &ws->e.vec[1], &ws->e.vec[2],
443+
&ws->e.vec[3], noiseseed, 4, 5, 6, 7);
439444
#endif /* MLKEM_K == 4 */
440445

441-
mlk_polyvec_ntt(skpv);
442-
mlk_polyvec_ntt(e);
446+
mlk_polyvec_ntt(&ws->skpv);
447+
mlk_polyvec_ntt(&ws->e);
443448

444-
mlk_polyvec_mulcache_compute(skpv_cache, skpv);
445-
mlk_matvec_mul(pkpv, a, skpv, skpv_cache);
446-
mlk_polyvec_tomont(pkpv);
449+
mlk_polyvec_mulcache_compute(&ws->skpv_cache, &ws->skpv);
450+
mlk_matvec_mul(&ws->pkpv, &ws->a, &ws->skpv, &ws->skpv_cache);
451+
mlk_polyvec_tomont(&ws->pkpv);
447452

448-
mlk_polyvec_add(pkpv, e);
449-
mlk_polyvec_reduce(pkpv);
450-
mlk_polyvec_reduce(skpv);
453+
mlk_polyvec_add(&ws->pkpv, &ws->e);
454+
mlk_polyvec_reduce(&ws->pkpv);
455+
mlk_polyvec_reduce(&ws->skpv);
451456

452-
mlk_pack_sk(sk, skpv);
453-
mlk_pack_pk(pk, pkpv, publicseed);
457+
mlk_pack_sk(sk, &ws->skpv);
458+
mlk_pack_pk(pk, &ws->pkpv, publicseed);
454459

455460
cleanup:
456461
/* Specification: Partially implements
457462
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
458-
MLK_FREE(skpv_cache, mlk_polyvec_mulcache, 1);
459-
MLK_FREE(skpv, mlk_polyvec, 1);
460-
MLK_FREE(pkpv, mlk_polyvec, 1);
461-
MLK_FREE(e, mlk_polyvec, 1);
462-
MLK_FREE(a, mlk_polymat, 1);
463-
MLK_FREE(coins_with_domain_separator, uint8_t, MLKEM_SYMBYTES + 1);
464-
MLK_FREE(buf, uint8_t, 2 * MLKEM_SYMBYTES);
463+
MLK_FREE(ws, workspace, 1);
465464
return ret;
466465
}
467466

@@ -479,91 +478,87 @@ int mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
479478
const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
480479
const uint8_t coins[MLKEM_SYMBYTES])
481480
{
481+
typedef struct
482+
{
483+
MLK_ALIGN uint8_t seed[MLKEM_SYMBYTES];
484+
mlk_polymat at;
485+
mlk_polyvec sp;
486+
mlk_polyvec pkpv;
487+
mlk_polyvec ep;
488+
mlk_polyvec b;
489+
mlk_poly v;
490+
mlk_poly k;
491+
mlk_poly epp;
492+
mlk_polyvec_mulcache sp_cache;
493+
} workspace;
494+
482495
int ret = 0;
483-
MLK_ALLOC(seed, uint8_t, MLKEM_SYMBYTES);
484-
MLK_ALLOC(at, mlk_polymat, 1);
485-
MLK_ALLOC(sp, mlk_polyvec, 1);
486-
MLK_ALLOC(pkpv, mlk_polyvec, 1);
487-
MLK_ALLOC(ep, mlk_polyvec, 1);
488-
MLK_ALLOC(b, mlk_polyvec, 1);
489-
MLK_ALLOC(v, mlk_poly, 1);
490-
MLK_ALLOC(k, mlk_poly, 1);
491-
MLK_ALLOC(epp, mlk_poly, 1);
492-
MLK_ALLOC(sp_cache, mlk_polyvec_mulcache, 1);
493-
494-
if (seed == NULL || at == NULL || sp == NULL || pkpv == NULL || ep == NULL ||
495-
b == NULL || v == NULL || k == NULL || epp == NULL || sp_cache == NULL)
496+
MLK_ALLOC(ws, workspace, 1);
497+
498+
if (ws == NULL)
496499
{
497500
ret = MLK_ERR_OUT_OF_MEMORY;
498501
goto cleanup;
499502
}
500503

501-
mlk_unpack_pk(pkpv, seed, pk);
502-
mlk_poly_frommsg(k, m);
504+
mlk_unpack_pk(&ws->pkpv, ws->seed, pk);
505+
mlk_poly_frommsg(&ws->k, m);
503506

504507
/*
505508
* Declassify the public seed.
506509
* Required to use it in conditional-branches in rejection sampling.
507510
* This is needed because in re-encryption the publicseed originated from sk
508511
* which is marked undefined.
509512
*/
510-
MLK_CT_TESTING_DECLASSIFY(seed, MLKEM_SYMBYTES);
513+
MLK_CT_TESTING_DECLASSIFY(ws->seed, MLKEM_SYMBYTES);
511514

512-
mlk_gen_matrix(at, seed, 1 /* transpose */);
515+
mlk_gen_matrix(&ws->at, ws->seed, 1 /* transpose */);
513516

514517
#if MLKEM_K == 2
515-
mlk_poly_getnoise_eta1122_4x(&sp->vec[0], &sp->vec[1], &ep->vec[0],
516-
&ep->vec[1], coins, 0, 1, 2, 3);
517-
mlk_poly_getnoise_eta2(epp, coins, 4);
518+
mlk_poly_getnoise_eta1122_4x(&ws->sp.vec[0], &ws->sp.vec[1], &ws->ep.vec[0],
519+
&ws->ep.vec[1], coins, 0, 1, 2, 3);
520+
mlk_poly_getnoise_eta2(&ws->epp, coins, 4);
518521
#elif MLKEM_K == 3
519522
/*
520523
* In this call, only the first three output buffers are needed.
521524
* The last parameter is a dummy that's overwritten later.
522525
*/
523-
mlk_poly_getnoise_eta1_4x(&sp->vec[0], &sp->vec[1], &sp->vec[2], &b->vec[0],
524-
coins, 0, 1, 2, 0xFF);
526+
mlk_poly_getnoise_eta1_4x(&ws->sp.vec[0], &ws->sp.vec[1], &ws->sp.vec[2],
527+
&ws->b.vec[0], coins, 0, 1, 2, 0xFF);
525528
/* The fourth output buffer in this call _is_ used. */
526-
mlk_poly_getnoise_eta2_4x(&ep->vec[0], &ep->vec[1], &ep->vec[2], epp, coins,
527-
3, 4, 5, 6);
529+
mlk_poly_getnoise_eta2_4x(&ws->ep.vec[0], &ws->ep.vec[1], &ws->ep.vec[2],
530+
&ws->epp, coins, 3, 4, 5, 6);
528531
#elif MLKEM_K == 4
529-
mlk_poly_getnoise_eta1_4x(&sp->vec[0], &sp->vec[1], &sp->vec[2], &sp->vec[3],
530-
coins, 0, 1, 2, 3);
531-
mlk_poly_getnoise_eta2_4x(&ep->vec[0], &ep->vec[1], &ep->vec[2], &ep->vec[3],
532-
coins, 4, 5, 6, 7);
533-
mlk_poly_getnoise_eta2(epp, coins, 8);
532+
mlk_poly_getnoise_eta1_4x(&ws->sp.vec[0], &ws->sp.vec[1], &ws->sp.vec[2],
533+
&ws->sp.vec[3], coins, 0, 1, 2, 3);
534+
mlk_poly_getnoise_eta2_4x(&ws->ep.vec[0], &ws->ep.vec[1], &ws->ep.vec[2],
535+
&ws->ep.vec[3], coins, 4, 5, 6, 7);
536+
mlk_poly_getnoise_eta2(&ws->epp, coins, 8);
534537
#endif /* MLKEM_K == 4 */
535538

536-
mlk_polyvec_ntt(sp);
539+
mlk_polyvec_ntt(&ws->sp);
537540

538-
mlk_polyvec_mulcache_compute(sp_cache, sp);
539-
mlk_matvec_mul(b, at, sp, sp_cache);
540-
mlk_polyvec_basemul_acc_montgomery_cached(v, pkpv, sp, sp_cache);
541+
mlk_polyvec_mulcache_compute(&ws->sp_cache, &ws->sp);
542+
mlk_matvec_mul(&ws->b, &ws->at, &ws->sp, &ws->sp_cache);
543+
mlk_polyvec_basemul_acc_montgomery_cached(&ws->v, &ws->pkpv, &ws->sp,
544+
&ws->sp_cache);
541545

542-
mlk_polyvec_invntt_tomont(b);
543-
mlk_poly_invntt_tomont(v);
546+
mlk_polyvec_invntt_tomont(&ws->b);
547+
mlk_poly_invntt_tomont(&ws->v);
544548

545-
mlk_polyvec_add(b, ep);
546-
mlk_poly_add(v, epp);
547-
mlk_poly_add(v, k);
549+
mlk_polyvec_add(&ws->b, &ws->ep);
550+
mlk_poly_add(&ws->v, &ws->epp);
551+
mlk_poly_add(&ws->v, &ws->k);
548552

549-
mlk_polyvec_reduce(b);
550-
mlk_poly_reduce(v);
553+
mlk_polyvec_reduce(&ws->b);
554+
mlk_poly_reduce(&ws->v);
551555

552-
mlk_pack_ciphertext(c, b, v);
556+
mlk_pack_ciphertext(c, &ws->b, &ws->v);
553557

554558
cleanup:
555559
/* Specification: Partially implements
556560
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
557-
MLK_FREE(sp_cache, mlk_polyvec_mulcache, 1);
558-
MLK_FREE(epp, mlk_poly, 1);
559-
MLK_FREE(k, mlk_poly, 1);
560-
MLK_FREE(v, mlk_poly, 1);
561-
MLK_FREE(b, mlk_polyvec, 1);
562-
MLK_FREE(ep, mlk_polyvec, 1);
563-
MLK_FREE(pkpv, mlk_polyvec, 1);
564-
MLK_FREE(sp, mlk_polyvec, 1);
565-
MLK_FREE(at, mlk_polymat, 1);
566-
MLK_FREE(seed, uint8_t, MLKEM_SYMBYTES);
561+
MLK_FREE(ws, workspace, 1);
567562
return ret;
568563
}
569564

@@ -575,40 +570,42 @@ int mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
575570
const uint8_t c[MLKEM_INDCPA_BYTES],
576571
const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES])
577572
{
573+
typedef struct
574+
{
575+
mlk_polyvec b;
576+
mlk_polyvec skpv;
577+
mlk_poly v;
578+
mlk_poly sb;
579+
mlk_polyvec_mulcache b_cache;
580+
} workspace;
581+
578582
int ret = 0;
579-
MLK_ALLOC(b, mlk_polyvec, 1);
580-
MLK_ALLOC(skpv, mlk_polyvec, 1);
581-
MLK_ALLOC(v, mlk_poly, 1);
582-
MLK_ALLOC(sb, mlk_poly, 1);
583-
MLK_ALLOC(b_cache, mlk_polyvec_mulcache, 1);
583+
MLK_ALLOC(ws, workspace, 1);
584584

585-
if (b == NULL || skpv == NULL || v == NULL || sb == NULL || b_cache == NULL)
585+
if (ws == NULL)
586586
{
587587
ret = MLK_ERR_OUT_OF_MEMORY;
588588
goto cleanup;
589589
}
590590

591-
mlk_unpack_ciphertext(b, v, c);
592-
mlk_unpack_sk(skpv, sk);
591+
mlk_unpack_ciphertext(&ws->b, &ws->v, c);
592+
mlk_unpack_sk(&ws->skpv, sk);
593593

594-
mlk_polyvec_ntt(b);
595-
mlk_polyvec_mulcache_compute(b_cache, b);
596-
mlk_polyvec_basemul_acc_montgomery_cached(sb, skpv, b, b_cache);
597-
mlk_poly_invntt_tomont(sb);
594+
mlk_polyvec_ntt(&ws->b);
595+
mlk_polyvec_mulcache_compute(&ws->b_cache, &ws->b);
596+
mlk_polyvec_basemul_acc_montgomery_cached(&ws->sb, &ws->skpv, &ws->b,
597+
&ws->b_cache);
598+
mlk_poly_invntt_tomont(&ws->sb);
598599

599-
mlk_poly_sub(v, sb);
600-
mlk_poly_reduce(v);
600+
mlk_poly_sub(&ws->v, &ws->sb);
601+
mlk_poly_reduce(&ws->v);
601602

602-
mlk_poly_tomsg(m, v);
603+
mlk_poly_tomsg(m, &ws->v);
603604

604605
cleanup:
605606
/* Specification: Partially implements
606607
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
607-
MLK_FREE(b_cache, mlk_polyvec_mulcache, 1);
608-
MLK_FREE(sb, mlk_poly, 1);
609-
MLK_FREE(v, mlk_poly, 1);
610-
MLK_FREE(skpv, mlk_polyvec, 1);
611-
MLK_FREE(b, mlk_polyvec, 1);
608+
MLK_FREE(ws, workspace, 1);
612609
return ret;
613610
}
614611

0 commit comments

Comments
 (0)