Skip to content

Commit 4f60463

Browse files
committed
Add unit tests for x4 keccak
1 parent 795de6c commit 4f60463

File tree

2 files changed

+154
-10
lines changed

2 files changed

+154
-10
lines changed

mldsa/fips202/fips202x4.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
#include "keccakf1600.h"
1717

1818
/* Context for non-incremental API */
19+
#define mld_shake128x4ctx FIPS202_NAMESPACE(shake128x4ctx)
1920
typedef struct
2021
{
2122
uint64_t ctx[MLD_KECCAK_LANES * MLD_KECCAK_WAY];
2223
} mld_shake128x4ctx;
2324

25+
#define mld_shake256x4ctx FIPS202_NAMESPACE(shake256x4ctx)
2426
typedef struct
2527
{
2628
uint64_t ctx[MLD_KECCAK_LANES * MLD_KECCAK_WAY];

test/test_unit.c

Lines changed: 152 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "../mldsa/fips202/keccakf1600.c"
2323
#include "../mldsa/fips202/fips202.c"
24+
#include "../mldsa/fips202/fips202x4.c"
2425

2526
#undef FIPS202_NAMESPACE
2627
#undef mld_keccakf1600_xor_bytes
@@ -40,24 +41,30 @@
4041
#undef mld_shake128_finalize
4142
#undef mld_shake128_absorb
4243
#undef mld_shake128_init
43-
#undef mld_keccakf1600x4_permute
44-
#undef mld_keccakf1600_permute
45-
#undef mld_keccakf1600x4_xor_bytes
46-
#undef mld_keccakf1600x4_extract_bytes
47-
#undef mld_keccakf1600_xor_bytes
48-
#undef mld_keccakf1600_extract_bytes
44+
#undef mld_shake256x4_release
45+
#undef mld_shake256x4_init
46+
#undef mld_shake256x4_squeezeblocks
47+
#undef mld_shake256x4_absorb_once
48+
#undef mld_shake128x4_release
49+
#undef mld_shake128x4_init
50+
#undef mld_shake128x4_squeezeblocks
51+
#undef mld_shake128x4_absorb_once
52+
#undef mld_shake256ctx
53+
#undef mld_shake128ctx
4954

5055
#undef mld_memset
5156
#undef mld_memcpy
5257
#undef mld_zeroize
5358

5459
#undef MLD_FIPS202_KECCAKF1600_H
5560
#undef MLD_FIPS202_FIPS202_H
61+
#undef MLD_FIPS202_FIPS202X4_H
5662
#undef MLD_COMMON_H
5763

5864
/* Under test */
5965
#include "../mldsa/fips202/keccakf1600.h"
6066
#include "../mldsa/fips202/fips202.h"
67+
#include "../mldsa/fips202/fips202x4.h"
6168

6269
#define CHECK(x) \
6370
do \
@@ -206,6 +213,42 @@ static int shake256_ctx_equals(const test_priv_shake256ctx *ref,
206213
#endif
207214
}
208215

216+
static int shake_x4_ctx_equals(const uint64_t *ref_ctx,
217+
const uint64_t *dut_ctx)
218+
{
219+
#if defined(MLD_USE_FIPS202_X4_XOR_NATIVE)
220+
uint64_t dut_inter[MLD_KECCAK_WAY][MLD_KECCAK_LANES];
221+
unpack_x4_native(dut_inter, dut_ctx);
222+
#endif
223+
224+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
225+
const uint64_t *ref_lane = ref_ctx + (size_t)ch * MLD_KECCAK_LANES;
226+
#if defined(MLD_USE_FIPS202_X4_XOR_NATIVE)
227+
uint64_t dut_norm[MLD_KECCAK_LANES];
228+
deinterleave_state(dut_norm, dut_inter[ch]);
229+
if (memcmp(ref_lane, dut_norm, sizeof dut_norm) != 0) return 0;
230+
#else
231+
const uint64_t *dut_lane = dut_ctx + (size_t)ch * MLD_KECCAK_LANES;
232+
if (memcmp(ref_lane, dut_lane, sizeof(uint64_t) * MLD_KECCAK_LANES) != 0)
233+
return 0;
234+
#endif
235+
}
236+
237+
return 1;
238+
}
239+
240+
static int shake128x4_ctx_equals(const test_priv_shake128x4ctx *ref,
241+
const mld_shake128x4ctx *dut)
242+
{
243+
return shake_x4_ctx_equals(ref->ctx, dut->ctx);
244+
}
245+
246+
static int shake256x4_ctx_equals(const test_priv_shake256x4ctx *ref,
247+
const mld_shake256x4ctx *dut)
248+
{
249+
return shake_x4_ctx_equals(ref->ctx, dut->ctx);
250+
}
251+
209252
static int test_shake128_api(void)
210253
{
211254
int fails = 0;
@@ -315,6 +358,99 @@ static int test_shake256_function(void)
315358
return fails;
316359
}
317360

361+
static int test_shake128x4_api(void)
362+
{
363+
int fails = 0;
364+
static const size_t absorb_lens[] = {0, 1, SHAKE128_RATE - 1, SHAKE128_RATE,
365+
SHAKE128_RATE + 7, SHAKE128_RATE * 2 + 5};
366+
static const size_t block_counts[] = {0, 1, 2, 3};
367+
uint8_t in[MLD_KECCAK_WAY][SHAKE128_RATE * 3];
368+
uint8_t out_ref[MLD_KECCAK_WAY][SHAKE128_RATE * 3];
369+
uint8_t out_dut[MLD_KECCAK_WAY][SHAKE128_RATE * 3];
370+
371+
for (size_t ai = 0; ai < sizeof(absorb_lens) / sizeof(absorb_lens[0]); ai++) {
372+
size_t inlen = absorb_lens[ai];
373+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
374+
if (inlen > 0) randombytes(in[ch], inlen);
375+
}
376+
377+
test_priv_shake128x4ctx ref_ctx;
378+
mld_shake128x4ctx dut_ctx;
379+
test_priv_shake128x4_absorb_once(&ref_ctx, in[0], in[1], in[2], in[3], inlen);
380+
mld_shake128x4_absorb_once(&dut_ctx, in[0], in[1], in[2], in[3], inlen);
381+
CHECK(shake128x4_ctx_equals(&ref_ctx, &dut_ctx));
382+
383+
for (size_t bi = 0; bi < sizeof(block_counts) / sizeof(block_counts[0]); bi++) {
384+
size_t nblocks = block_counts[bi];
385+
test_priv_shake128x4ctx ref_copy;
386+
mld_shake128x4ctx dut_copy;
387+
memcpy(&ref_copy, &ref_ctx, sizeof(ref_ctx));
388+
memcpy(&dut_copy, &dut_ctx, sizeof(dut_ctx));
389+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
390+
memset(out_ref[ch], 0, nblocks * SHAKE128_RATE);
391+
memset(out_dut[ch], 0, nblocks * SHAKE128_RATE);
392+
}
393+
test_priv_shake128x4_squeezeblocks(out_ref[0], out_ref[1], out_ref[2],
394+
out_ref[3], nblocks, &ref_copy);
395+
mld_shake128x4_squeezeblocks(out_dut[0], out_dut[1], out_dut[2],
396+
out_dut[3], nblocks, &dut_copy);
397+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
398+
CHECK(memcmp(out_ref[ch], out_dut[ch], nblocks * SHAKE128_RATE) == 0);
399+
}
400+
CHECK(shake128x4_ctx_equals(&ref_copy, &dut_copy));
401+
}
402+
}
403+
404+
return fails;
405+
}
406+
407+
static int test_shake256x4_api(void)
408+
{
409+
int fails = 0;
410+
static const size_t absorb_lens[] = {0, 1, SHAKE256_RATE - 1, SHAKE256_RATE,
411+
SHAKE256_RATE + 9, SHAKE256_RATE * 2 + 3};
412+
static const size_t block_counts[] = {0, 1, 2, 3};
413+
uint8_t in[MLD_KECCAK_WAY][SHAKE256_RATE * 3];
414+
uint8_t out_ref[MLD_KECCAK_WAY][SHAKE256_RATE * 3];
415+
uint8_t out_dut[MLD_KECCAK_WAY][SHAKE256_RATE * 3];
416+
417+
for (size_t ai = 0; ai < sizeof(absorb_lens) / sizeof(absorb_lens[0]); ai++) {
418+
size_t inlen = absorb_lens[ai];
419+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
420+
if (inlen > 0) randombytes(in[ch], inlen);
421+
}
422+
423+
test_priv_shake256x4ctx ref_ctx;
424+
mld_shake256x4ctx dut_ctx;
425+
test_priv_shake256x4_absorb_once(&ref_ctx, in[0], in[1], in[2], in[3], inlen);
426+
mld_shake256x4_absorb_once(&dut_ctx, in[0], in[1], in[2], in[3], inlen);
427+
CHECK(shake256x4_ctx_equals(&ref_ctx, &dut_ctx));
428+
429+
for (size_t bi = 0; bi < sizeof(block_counts) / sizeof(block_counts[0]); bi++) {
430+
size_t nblocks = block_counts[bi];
431+
test_priv_shake256x4ctx ref_copy;
432+
mld_shake256x4ctx dut_copy;
433+
memcpy(&ref_copy, &ref_ctx, sizeof(ref_ctx));
434+
memcpy(&dut_copy, &dut_ctx, sizeof(dut_ctx));
435+
436+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
437+
memset(out_ref[ch], 0, nblocks * SHAKE256_RATE);
438+
memset(out_dut[ch], 0, nblocks * SHAKE256_RATE);
439+
}
440+
test_priv_shake256x4_squeezeblocks(out_ref[0], out_ref[1], out_ref[2],
441+
out_ref[3], nblocks, &ref_copy);
442+
mld_shake256x4_squeezeblocks(out_dut[0], out_dut[1], out_dut[2],
443+
out_dut[3], nblocks, &dut_copy);
444+
for (int ch = 0; ch < MLD_KECCAK_WAY; ch++) {
445+
CHECK(memcmp(out_ref[ch], out_dut[ch], nblocks * SHAKE256_RATE) == 0);
446+
}
447+
CHECK(shake256x4_ctx_equals(&ref_copy, &dut_copy));
448+
}
449+
}
450+
451+
return fails;
452+
}
453+
318454
#if defined(MLD_USE_FIPS202_X4_NATIVE)
319455
static int test_x4_xor_bytes(void)
320456
{
@@ -581,10 +717,6 @@ int main(void)
581717

582718
randombytes_reset();
583719

584-
fails += test_shake128_api();
585-
fails += test_shake256_api();
586-
fails += test_shake256_function();
587-
588720
#if defined(MLD_USE_FIPS202_X1_NATIVE)
589721
fails += test_xor_bytes();
590722
fails += test_extract_bytes();
@@ -595,6 +727,16 @@ int main(void)
595727
fails += test_x4_extract_bytes();
596728
fails += test_x4_permute();
597729
#endif
730+
#if defined(MLD_USE_FIPS202_X1_NATIVE)
731+
fails += test_shake128_api();
732+
fails += test_shake256_api();
733+
fails += test_shake256_function();
734+
#endif
735+
#if defined(MLD_USE_FIPS202_X4_NATIVE)
736+
fails += test_shake128x4_api();
737+
fails += test_shake256x4_api();
738+
#endif
739+
598740

599741
if (fails) {
600742
fprintf(stderr, "unit tests: FAILED (%d failure%s)\n", fails, fails==1?"":"s");

0 commit comments

Comments
 (0)