Skip to content

Commit 795de6c

Browse files
committed
Add unit tests for FIPS202 APIs
1 parent 664a006 commit 795de6c

File tree

2 files changed

+170
-2
lines changed

2 files changed

+170
-2
lines changed

mldsa/fips202/fips202.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717
#define SHA3_256_HASHBYTES 32
1818
#define SHA3_512_HASHBYTES 64
1919

20+
#ifndef FIPS202_NAMESPACE
2021
#define FIPS202_NAMESPACE(s) mldsa_fips202_ref_##s
22+
#endif
2123

24+
#define mld_shake128ctx FIPS202_NAMESPACE(shake128ctx)
2225
typedef struct
2326
{
2427
uint64_t s[MLD_KECCAK_LANES];
2528
unsigned int pos;
2629
} mld_shake128ctx;
2730

31+
#define mld_shake256ctx FIPS202_NAMESPACE(shake256ctx)
2832
typedef struct
2933
{
3034
uint64_t s[MLD_KECCAK_LANES];

test/test_unit.c

Lines changed: 166 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414

1515

1616
#define FIPS202_NAMESPACE(A) test_priv_##A
17-
#define MLD_FIPS202_FIPS202_H
1817
#define MLD_COMMON_H
18+
#define mld_memset memset
19+
#define mld_memcpy memcpy
20+
#define mld_zeroize(PTR, LEN) memset(PTR, 0, LEN)
1921

2022
#include "../mldsa/fips202/keccakf1600.c"
23+
#include "../mldsa/fips202/fips202.c"
2124

2225
#undef FIPS202_NAMESPACE
2326
#undef mld_keccakf1600_xor_bytes
@@ -26,13 +29,35 @@
2629
#undef mld_keccakf1600x4_extract_bytes
2730
#undef mld_keccakf1600_permute
2831
#undef mld_keccakf1600x4_permute
32+
#undef mld_shake256
33+
#undef mld_shake256_release
34+
#undef mld_shake256_squeeze
35+
#undef mld_shake256_finalize
36+
#undef mld_shake256_absorb
37+
#undef mld_shake256_init
38+
#undef mld_shake128_release
39+
#undef mld_shake128_squeeze
40+
#undef mld_shake128_finalize
41+
#undef mld_shake128_absorb
42+
#undef mld_shake128_init
43+
#undef mld_keccakf1600x4_permute
44+
#undef mld_keccakf1600_permute
45+
#undef mld_keccakf1600x4_xor_bytes
46+
#undef mld_keccakf1600x4_extract_bytes
47+
#undef mld_keccakf1600_xor_bytes
48+
#undef mld_keccakf1600_extract_bytes
49+
50+
#undef mld_memset
51+
#undef mld_memcpy
52+
#undef mld_zeroize
2953

3054
#undef MLD_FIPS202_KECCAKF1600_H
3155
#undef MLD_FIPS202_FIPS202_H
3256
#undef MLD_COMMON_H
3357

3458
/* Under test */
3559
#include "../mldsa/fips202/keccakf1600.h"
60+
#include "../mldsa/fips202/fips202.h"
3661

3762
#define CHECK(x) \
3863
do \
@@ -49,7 +74,7 @@
4974
#define STATE_BYTES (MLD_KECCAK_LANES * sizeof(uint64_t))
5075

5176
#if defined(MLD_USE_FIPS202_X1_XOR_NATIVE) || defined(MLD_USE_FIPS202_X4_XOR_NATIVE)
52-
/* Bit interleave/deinterleave helpers as specified */
77+
/* Bit interleave/deinterleave helpers */
5378
static uint64_t test_priv_bit_interleave(uint64_t in)
5479
{
5580
uint64_t e = in & 0x5555555555555555ULL;
@@ -155,6 +180,141 @@ static void unpack_x4_native(uint64_t ch_interleaved_out[MLD_KECCAK_WAY][MLD_KEC
155180
}
156181
#endif /* MLD_USE_FIPS202_X4_NATIVE && MLD_USE_FIPS202_X4_XOR_NATIVE */
157182

183+
static int shake128_ctx_equals(const test_priv_shake128ctx *ref,
184+
const mld_shake128ctx *dut)
185+
{
186+
if (ref->pos != dut->pos) return 0;
187+
#if defined(MLD_USE_FIPS202_X1_XOR_NATIVE)
188+
uint64_t norm[MLD_KECCAK_LANES];
189+
deinterleave_state(norm, dut->s);
190+
return memcmp(ref->s, norm, sizeof norm) == 0;
191+
#else
192+
return memcmp(ref->s, dut->s, sizeof ref->s) == 0;
193+
#endif
194+
}
195+
196+
static int shake256_ctx_equals(const test_priv_shake256ctx *ref,
197+
const mld_shake256ctx *dut)
198+
{
199+
if (ref->pos != dut->pos) return 0;
200+
#if defined(MLD_USE_FIPS202_X1_XOR_NATIVE)
201+
uint64_t norm[MLD_KECCAK_LANES];
202+
deinterleave_state(norm, dut->s);
203+
return memcmp(ref->s, norm, sizeof norm) == 0;
204+
#else
205+
return memcmp(ref->s, dut->s, sizeof ref->s) == 0;
206+
#endif
207+
}
208+
209+
static int test_shake128_api(void)
210+
{
211+
int fails = 0;
212+
static const size_t absorb_lens[] = {0, 1, SHAKE128_RATE - 1, SHAKE128_RATE,
213+
SHAKE128_RATE + 7, SHAKE128_RATE * 2 + 5};
214+
static const size_t squeeze_lens[] = {0, 1, 32, SHAKE128_RATE,
215+
SHAKE128_RATE + 11, SHAKE128_RATE * 3};
216+
uint8_t chunk[SHAKE128_RATE * 3];
217+
uint8_t out_ref[SHAKE128_RATE * 4];
218+
uint8_t out_dut[SHAKE128_RATE * 4];
219+
test_priv_shake128ctx ref_ctx;
220+
mld_shake128ctx dut_ctx;
221+
222+
test_priv_shake128_init(&ref_ctx);
223+
mld_shake128_init(&dut_ctx);
224+
CHECK(shake128_ctx_equals(&ref_ctx, &dut_ctx));
225+
226+
for (size_t i = 0; i < sizeof(absorb_lens) / sizeof(absorb_lens[0]); i++) {
227+
size_t len = absorb_lens[i];
228+
if (len > 0) randombytes(chunk, len);
229+
test_priv_shake128_absorb(&ref_ctx, chunk, len);
230+
mld_shake128_absorb(&dut_ctx, chunk, len);
231+
CHECK(shake128_ctx_equals(&ref_ctx, &dut_ctx));
232+
}
233+
234+
test_priv_shake128_finalize(&ref_ctx);
235+
mld_shake128_finalize(&dut_ctx);
236+
CHECK(shake128_ctx_equals(&ref_ctx, &dut_ctx));
237+
238+
for (size_t i = 0; i < sizeof(squeeze_lens) / sizeof(squeeze_lens[0]); i++) {
239+
size_t len = squeeze_lens[i];
240+
memset(out_ref, 0, len);
241+
memset(out_dut, 0, len);
242+
test_priv_shake128_squeeze(out_ref, len, &ref_ctx);
243+
mld_shake128_squeeze(out_dut, len, &dut_ctx);
244+
CHECK(memcmp(out_ref, out_dut, len) == 0);
245+
CHECK(shake128_ctx_equals(&ref_ctx, &dut_ctx));
246+
}
247+
248+
return fails;
249+
}
250+
251+
static int test_shake256_api(void)
252+
{
253+
int fails = 0;
254+
static const size_t absorb_lens[] = {0, 1, SHAKE256_RATE - 1, SHAKE256_RATE,
255+
SHAKE256_RATE + 9, SHAKE256_RATE * 2 + 3};
256+
static const size_t squeeze_lens[] = {0, 1, 48, SHAKE256_RATE,
257+
SHAKE256_RATE + 13, SHAKE256_RATE * 3};
258+
uint8_t chunk[SHAKE256_RATE * 3];
259+
uint8_t out_ref[SHAKE256_RATE * 4];
260+
uint8_t out_dut[SHAKE256_RATE * 4];
261+
test_priv_shake256ctx ref_ctx;
262+
mld_shake256ctx dut_ctx;
263+
264+
test_priv_shake256_init(&ref_ctx);
265+
mld_shake256_init(&dut_ctx);
266+
CHECK(shake256_ctx_equals(&ref_ctx, &dut_ctx));
267+
268+
for (size_t i = 0; i < sizeof(absorb_lens) / sizeof(absorb_lens[0]); i++) {
269+
size_t len = absorb_lens[i];
270+
if (len > 0) randombytes(chunk, len);
271+
test_priv_shake256_absorb(&ref_ctx, chunk, len);
272+
mld_shake256_absorb(&dut_ctx, chunk, len);
273+
CHECK(shake256_ctx_equals(&ref_ctx, &dut_ctx));
274+
}
275+
276+
test_priv_shake256_finalize(&ref_ctx);
277+
mld_shake256_finalize(&dut_ctx);
278+
CHECK(shake256_ctx_equals(&ref_ctx, &dut_ctx));
279+
280+
for (size_t i = 0; i < sizeof(squeeze_lens) / sizeof(squeeze_lens[0]); i++) {
281+
size_t len = squeeze_lens[i];
282+
memset(out_ref, 0, len);
283+
memset(out_dut, 0, len);
284+
test_priv_shake256_squeeze(out_ref, len, &ref_ctx);
285+
mld_shake256_squeeze(out_dut, len, &dut_ctx);
286+
CHECK(memcmp(out_ref, out_dut, len) == 0);
287+
CHECK(shake256_ctx_equals(&ref_ctx, &dut_ctx));
288+
}
289+
290+
return fails;
291+
}
292+
293+
static int test_shake256_function(void)
294+
{
295+
int fails = 0;
296+
static const size_t in_lens[] = {0, 1, 50, 200};
297+
static const size_t out_lens[] = {0, 1, 64, 256};
298+
uint8_t inbuf[200];
299+
uint8_t out_ref[256];
300+
uint8_t out_dut[256];
301+
302+
for (size_t i = 0; i < sizeof(in_lens) / sizeof(in_lens[0]); i++) {
303+
size_t inlen = in_lens[i];
304+
if (inlen > 0) randombytes(inbuf, inlen);
305+
for (size_t j = 0; j < sizeof(out_lens) / sizeof(out_lens[0]); j++) {
306+
size_t outlen = out_lens[j];
307+
memset(out_ref, 0, outlen);
308+
memset(out_dut, 0, outlen);
309+
test_priv_shake256(out_ref, outlen, inbuf, inlen);
310+
mld_shake256(out_dut, outlen, inbuf, inlen);
311+
CHECK(memcmp(out_ref, out_dut, outlen) == 0);
312+
}
313+
}
314+
315+
return fails;
316+
}
317+
158318
#if defined(MLD_USE_FIPS202_X4_NATIVE)
159319
static int test_x4_xor_bytes(void)
160320
{
@@ -421,6 +581,10 @@ int main(void)
421581

422582
randombytes_reset();
423583

584+
fails += test_shake128_api();
585+
fails += test_shake256_api();
586+
fails += test_shake256_function();
587+
424588
#if defined(MLD_USE_FIPS202_X1_NATIVE)
425589
fails += test_xor_bytes();
426590
fails += test_extract_bytes();

0 commit comments

Comments
 (0)