Skip to content

Commit ae85594

Browse files
committed
Add runtime dispatch (mld_ntt_native)
- Change mld_ntt_native() return type from void to int - Add runtime capability checking with fallback support - Implement dispatch logic in mld_poly_ntt() to try native first, fallback to C - Add MLD_NATIVE_FUNC_SUCCESS/FALLBACK return codes - Add mld_sys_check_capability() for system capability detection - Add test configuration for AVX2, static ON/OFF, add to CI test. Signed-off-by: willieyz <[email protected]>
1 parent 9206ac0 commit ae85594

22 files changed

+1955
-15
lines changed

.github/actions/config-variations/action.yml

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,42 @@ runs:
123123
acvp: true
124124
opt: ${{ inputs.opt }}
125125
examples: false # Some examples use a custom config themselves
126+
- name: "Custom native capability functions (static ON)"
127+
if: ${{ inputs.tests == 'all' || contains(inputs.tests, 'native-cap-ON') }}
128+
uses: ./.github/actions/multi-functest
129+
with:
130+
gh_token: ${{ inputs.gh_token }}
131+
compile_mode: native
132+
cflags: "-std=c11 -D_GNU_SOURCE -DMLD_CONFIG_FILE=\\\\\\\"../../test/custom_native_capability_config_1.h\\\\\\\" -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
133+
ldflags: "-fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
134+
func: true
135+
kat: true
136+
acvp: true
137+
opt: ${{ inputs.opt }}
138+
examples: false # Some examples use a custom config themselves
139+
- name: "Custom native capability functions (static OFF)"
140+
if: ${{ inputs.tests == 'all' || contains(inputs.tests, 'native-cap-OFF') }}
141+
uses: ./.github/actions/multi-functest
142+
with:
143+
gh_token: ${{ inputs.gh_token }}
144+
compile_mode: native
145+
cflags: "-std=c11 -D_GNU_SOURCE -DMLD_CONFIG_FILE=\\\\\\\"../../test/custom_native_capability_config_0.h\\\\\\\" -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
146+
ldflags: "-fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
147+
func: true
148+
kat: true
149+
acvp: true
150+
opt: ${{ inputs.opt }}
151+
examples: false # Some examples use a custom config themselves
152+
- name: "Custom native capability functions (CPUID AVX2 detection)"
153+
if: ${{ (inputs.tests == 'all' || contains(inputs.tests, 'native-cap-CPUID_AVX2')) && runner.os == 'Linux' && runner.arch == 'X64' }}
154+
uses: ./.github/actions/multi-functest
155+
with:
156+
gh_token: ${{ inputs.gh_token }}
157+
compile_mode: native
158+
cflags: "-std=c11 -mavx2 -mbmi2 -mpopcnt -D_GNU_SOURCE -DMLD_CONFIG_FILE=\\\\\\\"../../test/custom_native_capability_config_CPUID_AVX2.h\\\\\\\" -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
159+
ldflags: "-fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all"
160+
func: true
161+
kat: true
162+
acvp: true
163+
opt: ${{ inputs.opt }}
164+
examples: false # Some examples use a custom config themselves

BIBLIOGRAPHY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ source code and documentation.
2626
- [test/break_pct_config.h](test/break_pct_config.h)
2727
- [test/custom_memcpy_config.h](test/custom_memcpy_config.h)
2828
- [test/custom_memset_config.h](test/custom_memset_config.h)
29+
- [test/custom_native_capability_config_0.h](test/custom_native_capability_config_0.h)
30+
- [test/custom_native_capability_config_1.h](test/custom_native_capability_config_1.h)
31+
- [test/custom_native_capability_config_CPUID_AVX2.h](test/custom_native_capability_config_CPUID_AVX2.h)
2932
- [test/custom_randombytes_config.h](test/custom_randombytes_config.h)
3033
- [test/custom_stdlib_config.h](test/custom_stdlib_config.h)
3134
- [test/custom_zeroize_config.h](test/custom_zeroize_config.h)
@@ -68,6 +71,9 @@ source code and documentation.
6871
- [test/break_pct_config.h](test/break_pct_config.h)
6972
- [test/custom_memcpy_config.h](test/custom_memcpy_config.h)
7073
- [test/custom_memset_config.h](test/custom_memset_config.h)
74+
- [test/custom_native_capability_config_0.h](test/custom_native_capability_config_0.h)
75+
- [test/custom_native_capability_config_1.h](test/custom_native_capability_config_1.h)
76+
- [test/custom_native_capability_config_CPUID_AVX2.h](test/custom_native_capability_config_CPUID_AVX2.h)
7177
- [test/custom_randombytes_config.h](test/custom_randombytes_config.h)
7278
- [test/custom_stdlib_config.h](test/custom_stdlib_config.h)
7379
- [test/custom_zeroize_config.h](test/custom_zeroize_config.h)

dev/aarch64_clean/meta.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@
3232

3333

3434
#if !defined(__ASSEMBLER__)
35+
#include "../api.h"
3536
#include "src/arith_native_aarch64.h"
3637

37-
static MLD_INLINE void mld_ntt_native(int32_t data[MLDSA_N])
38+
static MLD_INLINE int mld_ntt_native(int32_t data[MLDSA_N])
3839
{
3940
mld_ntt_asm(data, mld_aarch64_ntt_zetas_layer123456,
4041
mld_aarch64_ntt_zetas_layer78);
42+
return MLD_NATIVE_FUNC_SUCCESS;
4143
}
4244

4345
static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])

dev/x86_64/meta.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,23 @@
3333
#if !defined(__ASSEMBLER__)
3434
#include <string.h>
3535
#include "../../common.h"
36+
#include "../api.h"
3637
#include "src/arith_native_x86_64.h"
3738

3839
static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t data[MLDSA_N])
3940
{
4041
mld_nttunpack_avx2((__m256i *)(data));
4142
}
4243

43-
static MLD_INLINE void mld_ntt_native(int32_t data[MLDSA_N])
44+
static MLD_INLINE int mld_ntt_native(int32_t data[MLDSA_N])
4445
{
46+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
47+
{
48+
return MLD_NATIVE_FUNC_FALLBACK;
49+
}
50+
4551
mld_ntt_avx2((__m256i *)data, mld_qdata.vec);
52+
return MLD_NATIVE_FUNC_SUCCESS;
4653
}
4754
static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
4855
{

mldsa/mldsa_native.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,10 @@
392392
#undef MLD_RESTRICT
393393
#undef MLD_SYS_AARCH64
394394
#undef MLD_SYS_AARCH64_EB
395+
#undef MLD_SYS_APPLE
395396
#undef MLD_SYS_BIG_ENDIAN
396397
#undef MLD_SYS_H
398+
#undef MLD_SYS_LINUX
397399
#undef MLD_SYS_LITTLE_ENDIAN
398400
#undef MLD_SYS_PPC64LE
399401
#undef MLD_SYS_RISCV64
@@ -508,6 +510,8 @@
508510
#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_ARITH)
509511
/* mldsa/src/native/api.h */
510512
#undef MLD_NATIVE_API_H
513+
#undef MLD_NATIVE_FUNC_FALLBACK
514+
#undef MLD_NATIVE_FUNC_SUCCESS
511515
/* mldsa/src/native/meta.h */
512516
#undef MLD_NATIVE_META_H
513517
#if defined(MLD_SYS_AARCH64)

mldsa/src/config.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,36 @@
298298
#endif
299299
*/
300300

301+
302+
/******************************************************************************
303+
* Name: MLD_CONFIG_CUSTOM_CAPABILITY_FUNC
304+
*
305+
* Description: mldsa-native backends may rely on specific hardware features.
306+
* Those backends will only be included in an mldsa-native build
307+
* if support for the respective features is enabled at
308+
* compile-time. However, when building for a heteroneous set
309+
* of CPUs to run the resulting binary/library on, feature
310+
* detection at _runtime_ is needed to decided whether a backend
311+
* can be used or not.
312+
*
313+
* Set this option and define `mld_sys_check_capability` if you
314+
* want to use a custom method to dispatch between implementations.
315+
*
316+
* If this option is not set, mldsa-native uses compile-time
317+
* feature detection only to decide which backend to use.
318+
*
319+
* If you compile mldsa-native on a system with different
320+
* capabilities than the system that the resulting binary/library
321+
* will be run on, you must use this option.
322+
*
323+
*****************************************************************************/
324+
/* #define MLD_CONFIG_CUSTOM_CAPABILITY_FUNC
325+
static MLD_INLINE int mld_sys_check_capability(mld_sys_cap cap)
326+
{
327+
... your implementation ...
328+
}
329+
*/
330+
301331
/******************************************************************************
302332
* Name: MLD_CONFIG_KEYGEN_PCT
303333
*

mldsa/src/native/aarch64/meta.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@
3232

3333

3434
#if !defined(__ASSEMBLER__)
35+
#include "../api.h"
3536
#include "src/arith_native_aarch64.h"
3637

37-
static MLD_INLINE void mld_ntt_native(int32_t data[MLDSA_N])
38+
static MLD_INLINE int mld_ntt_native(int32_t data[MLDSA_N])
3839
{
3940
mld_ntt_asm(data, mld_aarch64_ntt_zetas_layer123456,
4041
mld_aarch64_ntt_zetas_layer78);
42+
return MLD_NATIVE_FUNC_SUCCESS;
4143
}
4244

4345
static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])

mldsa/src/native/api.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
#include "../cbmc.h"
2323
#include "../common.h"
2424

25+
/* Backends must return MLD_NATIVE_FUNC_SUCCESS upon success. */
26+
#define MLD_NATIVE_FUNC_SUCCESS (0)
27+
/* Backends may return MLD_NATIVE_FUNC_FALLBACK to signal to the frontend that
28+
* the target/parameters are unsupported; typically, this would be because of
29+
* dependencies on CPU features not detected on the host CPU. In this case,
30+
* the frontend falls back to the default C implementation. */
31+
#define MLD_NATIVE_FUNC_FALLBACK (-1)
32+
2533
/*
2634
* This is the C<->native interface allowing for the drop-in of
2735
* native code for performance critical arithmetic components of ML-DSA.
@@ -52,7 +60,7 @@
5260
*
5361
* Arguments: - int32_t p[MLDSA_N]: pointer to in/output polynomial
5462
**************************************************/
55-
static MLD_INLINE void mld_ntt_native(int32_t p[MLDSA_N]);
63+
static MLD_INLINE int mld_ntt_native(int32_t p[MLDSA_N]);
5664
#endif /* MLD_USE_NATIVE_NTT */
5765

5866

mldsa/src/native/x86_64/meta.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,23 @@
3333
#if !defined(__ASSEMBLER__)
3434
#include <string.h>
3535
#include "../../common.h"
36+
#include "../api.h"
3637
#include "src/arith_native_x86_64.h"
3738

3839
static MLD_INLINE void mld_poly_permute_bitrev_to_custom(int32_t data[MLDSA_N])
3940
{
4041
mld_nttunpack_avx2((__m256i *)(data));
4142
}
4243

43-
static MLD_INLINE void mld_ntt_native(int32_t data[MLDSA_N])
44+
static MLD_INLINE int mld_ntt_native(int32_t data[MLDSA_N])
4445
{
46+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
47+
{
48+
return MLD_NATIVE_FUNC_FALLBACK;
49+
}
50+
4551
mld_ntt_avx2((__m256i *)data, mld_qdata.vec);
52+
return MLD_NATIVE_FUNC_SUCCESS;
4653
}
4754
static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
4855
{

mldsa/src/poly.c

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,23 +142,24 @@ void mld_poly_shiftl(mld_poly *a)
142142
mld_assert_bound(a->coeffs, MLDSA_N, 0, MLDSA_Q);
143143
}
144144

145-
#if !defined(MLD_USE_NATIVE_NTT)
146145
MLD_INTERNAL_API
147146
void mld_poly_ntt(mld_poly *a)
148147
{
149148
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLDSA_Q);
149+
#if defined(MLD_USE_NATIVE_NTT)
150+
{
151+
int ret;
152+
ret = mld_ntt_native(a->coeffs);
153+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
154+
{
155+
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLD_NTT_BOUND);
156+
return;
157+
}
158+
}
159+
#endif /* MLD_USE_NATIVE_NTT */
150160
mld_ntt(a->coeffs);
151161
mld_assert_abs_bound(a->coeffs, MLDSA_N, MLD_NTT_BOUND);
152162
}
153-
#else /* !MLD_USE_NATIVE_NTT */
154-
MLD_INTERNAL_API
155-
void mld_poly_ntt(mld_poly *p)
156-
{
157-
mld_assert_abs_bound(p->coeffs, MLDSA_N, MLDSA_Q);
158-
mld_ntt_native(p->coeffs);
159-
mld_assert_abs_bound(p->coeffs, MLDSA_N, MLD_NTT_BOUND);
160-
}
161-
#endif /* MLD_USE_NATIVE_NTT */
162163

163164
#if !defined(MLD_USE_NATIVE_INTT)
164165
MLD_INTERNAL_API

0 commit comments

Comments
 (0)