2121
2222#include "../mldsa/fips202/keccakf1600.c"
2323#include "../mldsa/fips202/fips202.c"
24+ #include "../mldsa/fips202/fips202x4.c"
2425
2526#undef FIPS202_NAMESPACE
2627#undef mld_keccakf1600_xor_bytes
4041#undef mld_shake128_finalize
4142#undef mld_shake128_absorb
4243#undef mld_shake128_init
43- #undef mld_keccakf1600x4_permute
44- #undef mld_keccakf1600_permute
45- #undef mld_keccakf1600x4_xor_bytes
46- #undef mld_keccakf1600x4_extract_bytes
47- #undef mld_keccakf1600_xor_bytes
48- #undef mld_keccakf1600_extract_bytes
44+ #undef mld_shake256x4_release
45+ #undef mld_shake256x4_init
46+ #undef mld_shake256x4_squeezeblocks
47+ #undef mld_shake256x4_absorb_once
48+ #undef mld_shake128x4_release
49+ #undef mld_shake128x4_init
50+ #undef mld_shake128x4_squeezeblocks
51+ #undef mld_shake128x4_absorb_once
52+ #undef mld_shake256ctx
53+ #undef mld_shake128ctx
4954
5055#undef mld_memset
5156#undef mld_memcpy
5257#undef mld_zeroize
5358
5459#undef MLD_FIPS202_KECCAKF1600_H
5560#undef MLD_FIPS202_FIPS202_H
61+ #undef MLD_FIPS202_FIPS202X4_H
5662#undef MLD_COMMON_H
5763
5864/* Under test */
5965#include "../mldsa/fips202/keccakf1600.h"
6066#include "../mldsa/fips202/fips202.h"
67+ #include "../mldsa/fips202/fips202x4.h"
6168
6269#define CHECK (x ) \
6370 do \
@@ -206,6 +213,42 @@ static int shake256_ctx_equals(const test_priv_shake256ctx *ref,
206213#endif
207214}
208215
216+ static int shake_x4_ctx_equals (const uint64_t * ref_ctx ,
217+ const uint64_t * dut_ctx )
218+ {
219+ #if defined(MLD_USE_FIPS202_X4_XOR_NATIVE )
220+ uint64_t dut_inter [MLD_KECCAK_WAY ][MLD_KECCAK_LANES ];
221+ unpack_x4_native (dut_inter , dut_ctx );
222+ #endif
223+
224+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
225+ const uint64_t * ref_lane = ref_ctx + (size_t )ch * MLD_KECCAK_LANES ;
226+ #if defined(MLD_USE_FIPS202_X4_XOR_NATIVE )
227+ uint64_t dut_norm [MLD_KECCAK_LANES ];
228+ deinterleave_state (dut_norm , dut_inter [ch ]);
229+ if (memcmp (ref_lane , dut_norm , sizeof dut_norm ) != 0 ) return 0 ;
230+ #else
231+ const uint64_t * dut_lane = dut_ctx + (size_t )ch * MLD_KECCAK_LANES ;
232+ if (memcmp (ref_lane , dut_lane , sizeof (uint64_t ) * MLD_KECCAK_LANES ) != 0 )
233+ return 0 ;
234+ #endif
235+ }
236+
237+ return 1 ;
238+ }
239+
240+ static int shake128x4_ctx_equals (const test_priv_shake128x4ctx * ref ,
241+ const mld_shake128x4ctx * dut )
242+ {
243+ return shake_x4_ctx_equals (ref -> ctx , dut -> ctx );
244+ }
245+
246+ static int shake256x4_ctx_equals (const test_priv_shake256x4ctx * ref ,
247+ const mld_shake256x4ctx * dut )
248+ {
249+ return shake_x4_ctx_equals (ref -> ctx , dut -> ctx );
250+ }
251+
209252static int test_shake128_api (void )
210253{
211254 int fails = 0 ;
@@ -315,6 +358,99 @@ static int test_shake256_function(void)
315358 return fails ;
316359}
317360
361+ static int test_shake128x4_api (void )
362+ {
363+ int fails = 0 ;
364+ static const size_t absorb_lens [] = {0 , 1 , SHAKE128_RATE - 1 , SHAKE128_RATE ,
365+ SHAKE128_RATE + 7 , SHAKE128_RATE * 2 + 5 };
366+ static const size_t block_counts [] = {0 , 1 , 2 , 3 };
367+ uint8_t in [MLD_KECCAK_WAY ][SHAKE128_RATE * 3 ];
368+ uint8_t out_ref [MLD_KECCAK_WAY ][SHAKE128_RATE * 3 ];
369+ uint8_t out_dut [MLD_KECCAK_WAY ][SHAKE128_RATE * 3 ];
370+
371+ for (size_t ai = 0 ; ai < sizeof (absorb_lens ) / sizeof (absorb_lens [0 ]); ai ++ ) {
372+ size_t inlen = absorb_lens [ai ];
373+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
374+ if (inlen > 0 ) randombytes (in [ch ], inlen );
375+ }
376+
377+ test_priv_shake128x4ctx ref_ctx ;
378+ mld_shake128x4ctx dut_ctx ;
379+ test_priv_shake128x4_absorb_once (& ref_ctx , in [0 ], in [1 ], in [2 ], in [3 ], inlen );
380+ mld_shake128x4_absorb_once (& dut_ctx , in [0 ], in [1 ], in [2 ], in [3 ], inlen );
381+ CHECK (shake128x4_ctx_equals (& ref_ctx , & dut_ctx ));
382+
383+ for (size_t bi = 0 ; bi < sizeof (block_counts ) / sizeof (block_counts [0 ]); bi ++ ) {
384+ size_t nblocks = block_counts [bi ];
385+ test_priv_shake128x4ctx ref_copy ;
386+ mld_shake128x4ctx dut_copy ;
387+ memcpy (& ref_copy , & ref_ctx , sizeof (ref_ctx ));
388+ memcpy (& dut_copy , & dut_ctx , sizeof (dut_ctx ));
389+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
390+ memset (out_ref [ch ], 0 , nblocks * SHAKE128_RATE );
391+ memset (out_dut [ch ], 0 , nblocks * SHAKE128_RATE );
392+ }
393+ test_priv_shake128x4_squeezeblocks (out_ref [0 ], out_ref [1 ], out_ref [2 ],
394+ out_ref [3 ], nblocks , & ref_copy );
395+ mld_shake128x4_squeezeblocks (out_dut [0 ], out_dut [1 ], out_dut [2 ],
396+ out_dut [3 ], nblocks , & dut_copy );
397+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
398+ CHECK (memcmp (out_ref [ch ], out_dut [ch ], nblocks * SHAKE128_RATE ) == 0 );
399+ }
400+ CHECK (shake128x4_ctx_equals (& ref_copy , & dut_copy ));
401+ }
402+ }
403+
404+ return fails ;
405+ }
406+
407+ static int test_shake256x4_api (void )
408+ {
409+ int fails = 0 ;
410+ static const size_t absorb_lens [] = {0 , 1 , SHAKE256_RATE - 1 , SHAKE256_RATE ,
411+ SHAKE256_RATE + 9 , SHAKE256_RATE * 2 + 3 };
412+ static const size_t block_counts [] = {0 , 1 , 2 , 3 };
413+ uint8_t in [MLD_KECCAK_WAY ][SHAKE256_RATE * 3 ];
414+ uint8_t out_ref [MLD_KECCAK_WAY ][SHAKE256_RATE * 3 ];
415+ uint8_t out_dut [MLD_KECCAK_WAY ][SHAKE256_RATE * 3 ];
416+
417+ for (size_t ai = 0 ; ai < sizeof (absorb_lens ) / sizeof (absorb_lens [0 ]); ai ++ ) {
418+ size_t inlen = absorb_lens [ai ];
419+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
420+ if (inlen > 0 ) randombytes (in [ch ], inlen );
421+ }
422+
423+ test_priv_shake256x4ctx ref_ctx ;
424+ mld_shake256x4ctx dut_ctx ;
425+ test_priv_shake256x4_absorb_once (& ref_ctx , in [0 ], in [1 ], in [2 ], in [3 ], inlen );
426+ mld_shake256x4_absorb_once (& dut_ctx , in [0 ], in [1 ], in [2 ], in [3 ], inlen );
427+ CHECK (shake256x4_ctx_equals (& ref_ctx , & dut_ctx ));
428+
429+ for (size_t bi = 0 ; bi < sizeof (block_counts ) / sizeof (block_counts [0 ]); bi ++ ) {
430+ size_t nblocks = block_counts [bi ];
431+ test_priv_shake256x4ctx ref_copy ;
432+ mld_shake256x4ctx dut_copy ;
433+ memcpy (& ref_copy , & ref_ctx , sizeof (ref_ctx ));
434+ memcpy (& dut_copy , & dut_ctx , sizeof (dut_ctx ));
435+
436+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
437+ memset (out_ref [ch ], 0 , nblocks * SHAKE256_RATE );
438+ memset (out_dut [ch ], 0 , nblocks * SHAKE256_RATE );
439+ }
440+ test_priv_shake256x4_squeezeblocks (out_ref [0 ], out_ref [1 ], out_ref [2 ],
441+ out_ref [3 ], nblocks , & ref_copy );
442+ mld_shake256x4_squeezeblocks (out_dut [0 ], out_dut [1 ], out_dut [2 ],
443+ out_dut [3 ], nblocks , & dut_copy );
444+ for (int ch = 0 ; ch < MLD_KECCAK_WAY ; ch ++ ) {
445+ CHECK (memcmp (out_ref [ch ], out_dut [ch ], nblocks * SHAKE256_RATE ) == 0 );
446+ }
447+ CHECK (shake256x4_ctx_equals (& ref_copy , & dut_copy ));
448+ }
449+ }
450+
451+ return fails ;
452+ }
453+
318454#if defined(MLD_USE_FIPS202_X4_NATIVE )
319455static int test_x4_xor_bytes (void )
320456{
@@ -581,10 +717,6 @@ int main(void)
581717
582718 randombytes_reset ();
583719
584- fails += test_shake128_api ();
585- fails += test_shake256_api ();
586- fails += test_shake256_function ();
587-
588720#if defined(MLD_USE_FIPS202_X1_NATIVE )
589721 fails += test_xor_bytes ();
590722 fails += test_extract_bytes ();
@@ -595,6 +727,16 @@ int main(void)
595727 fails += test_x4_extract_bytes ();
596728 fails += test_x4_permute ();
597729#endif
730+ #if defined(MLD_USE_FIPS202_X1_NATIVE )
731+ fails += test_shake128_api ();
732+ fails += test_shake256_api ();
733+ fails += test_shake256_function ();
734+ #endif
735+ #if defined(MLD_USE_FIPS202_X4_NATIVE )
736+ fails += test_shake128x4_api ();
737+ fails += test_shake256x4_api ();
738+ #endif
739+
598740
599741 if (fails ) {
600742 fprintf (stderr , "unit tests: FAILED (%d failure%s)\n" , fails , fails == 1 ?"" :"s" );
0 commit comments