diff --git a/.github/actions/multi-functest/action.yml b/.github/actions/multi-functest/action.yml index 252931918..72b73370a 100644 --- a/.github/actions/multi-functest/action.yml +++ b/.github/actions/multi-functest/action.yml @@ -161,7 +161,7 @@ runs: custom_shell: ${{ inputs.custom_shell }} cflags: "${{ inputs.cflags }} -DMLK_FORCE_RISCV64" cross_prefix: riscv64-unknown-linux-gnu- - exec_wrapper: qemu-riscv64 + exec_wrapper: "qemu-riscv64 -cpu rv64,v=true,vlen=256" # This needs to be changed once we test other VLs opt: ${{ inputs.opt }} func: ${{ inputs.func }} kat: ${{ inputs.kat }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 92c25ff06..e2d15991f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -134,16 +134,15 @@ jobs: runs-on: ${{ matrix.target.runner }} steps: - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - - name: build + test + - name: build + test (no-opt) uses: ./.github/actions/multi-functest with: nix-shell: ${{ matrix.target.nix_shell }} nix-cache: ${{ matrix.target.mode == 'native' && 'false' || 'true' }} gh_token: ${{ secrets.GITHUB_TOKEN }} compile_mode: ${{ matrix.target.mode }} - # There is no native code yet on PPC64LE, R-V or AArch64_be, so no point running opt tests - opt: ${{ (matrix.target.arch != 'ppc64le' && matrix.target.arch != 'riscv64' && matrix.target.arch != 'riscv32' && matrix.target.arch != 'aarch64_be') && 'all' || 'no_opt' }} - - name: build + test (+debug+memsan+ubsan) + opt: 'no_opt' + - name: build + test (+debug+memsan+ubsan, native) uses: ./.github/actions/multi-functest if: ${{ matrix.target.mode == 'native' }} with: @@ -151,6 +150,17 @@ jobs: compile_mode: native cflags: "-DMLKEM_DEBUG -fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all" check_namespace: 'false' + - name: build + test (+debug, cross, opt) + uses: ./.github/actions/multi-functest + # There is no native code yet on PPC64LE, riscv32 or AArch64_be, so no point running opt tests + if: ${{ matrix.target.mode != 'native' && (matrix.target.arch != 'ppc64le' && matrix.target.arch != 'riscv32' && matrix.target.arch != 'aarch64_be') }} + with: + nix-shell: ${{ matrix.target.nix_shell }} + nix-cache: ${{ matrix.target.mode == 'native' && 'false' || 'true' }} + gh_token: ${{ secrets.GITHUB_TOKEN }} + compile_mode: ${{ matrix.target.mode }} + cflags: "-DMLKEM_DEBUG" + opt: 'opt' backend_tests: name: AArch64 FIPS202 backends (${{ matrix.backend }}) strategy: diff --git a/examples/monolithic_build_multilevel_native/Makefile b/examples/monolithic_build_multilevel_native/Makefile index 34f4abb30..ad81d202c 100644 --- a/examples/monolithic_build_multilevel_native/Makefile +++ b/examples/monolithic_build_multilevel_native/Makefile @@ -77,6 +77,7 @@ else ifneq ($(findstring aarch64_be, $(CROSS_PREFIX)),) else ifneq ($(findstring aarch64, $(CROSS_PREFIX)),) CFLAGS += -DMLK_FORCE_AARCH64 else ifneq ($(findstring riscv64, $(CROSS_PREFIX)),) + CFLAGS += -march=rv64gcv_zvl256b CFLAGS += -DMLK_FORCE_RISCV64 else ifneq ($(findstring riscv32, $(CROSS_PREFIX)),) CFLAGS += -DMLK_FORCE_RISCV32 diff --git a/examples/monolithic_build_native/Makefile b/examples/monolithic_build_native/Makefile index e27909053..9ac7ddd33 100644 --- a/examples/monolithic_build_native/Makefile +++ b/examples/monolithic_build_native/Makefile @@ -60,6 +60,41 @@ LIB1024=libmlkem1024.a MLK_OBJS=$(BUILD_DIR)/mlkem_native.c.o $(BUILD_DIR)/mlkem_native.S.o +# Automatically detect system architecture and set preprocessor etc accordingly +HOST_PLATFORM 
:= $(shell uname -s)-$(shell uname -m) + +# linux x86_64 +ifeq ($(HOST_PLATFORM),Linux-x86_64) + CFLAGS += -z noexecstack +endif + +# Native compilation +ifeq ($(CROSS_PREFIX),) +ifeq ($(HOST_PLATFORM),Linux-x86_64) + CFLAGS += -mavx2 -mbmi2 -mpopcnt -maes + CFLAGS += -DMLK_FORCE_X86_64 +else ifeq ($(HOST_PLATFORM),Linux-aarch64) + CFLAGS += -DMLK_FORCE_AARCH64 +else ifeq ($(HOST_PLATFORM),Darwin-arm64) + CFLAGS += -DMLK_FORCE_AARCH64 +endif +# Cross compilation +else ifneq ($(findstring x86_64, $(CROSS_PREFIX)),) + CFLAGS += -mavx2 -mbmi2 -mpopcnt -maes + CFLAGS += -DMLK_FORCE_X86_64 +else ifneq ($(findstring aarch64_be, $(CROSS_PREFIX)),) + CFLAGS += -DMLK_FORCE_AARCH64_EB +else ifneq ($(findstring aarch64, $(CROSS_PREFIX)),) + CFLAGS += -DMLK_FORCE_AARCH64 +else ifneq ($(findstring riscv64, $(CROSS_PREFIX)),) + CFLAGS += -march=rv64gcv_zvl256b + CFLAGS += -DMLK_FORCE_RISCV64 +else ifneq ($(findstring riscv32, $(CROSS_PREFIX)),) + CFLAGS += -DMLK_FORCE_RISCV32 +else ifneq ($(findstring powerpc64le, $(CROSS_PREFIX)),) + CFLAGS += -DMLK_FORCE_PPC64LE +endif + CFLAGS := \ -Wall \ -Wextra \ diff --git a/examples/multilevel_build_native/Makefile b/examples/multilevel_build_native/Makefile index 2217a2193..12380d9ca 100644 --- a/examples/multilevel_build_native/Makefile +++ b/examples/multilevel_build_native/Makefile @@ -44,6 +44,7 @@ else ifneq ($(findstring aarch64_be, $(CROSS_PREFIX)),) else ifneq ($(findstring aarch64, $(CROSS_PREFIX)),) CFLAGS += -DMLK_FORCE_AARCH64 else ifneq ($(findstring riscv64, $(CROSS_PREFIX)),) + CFLAGS += -march=rv64gcv_zvl256b CFLAGS += -DMLK_FORCE_RISCV64 else ifneq ($(findstring riscv32, $(CROSS_PREFIX)),) CFLAGS += -DMLK_FORCE_RISCV32 diff --git a/mlkem/mlkem_native.S b/mlkem/mlkem_native.S index fcdf182eb..6422be33c 100644 --- a/mlkem/mlkem_native.S +++ b/mlkem/mlkem_native.S @@ -85,6 +85,8 @@ #include "mlkem/src/native/x86_64/src/rej_uniform_asm.S" #include "mlkem/src/native/x86_64/src/tomont.S" #endif /* MLK_SYS_X86_64 */ +#if defined(MLK_SYS_RISCV64) +#endif #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */ #if defined(MLK_CONFIG_USE_NATIVE_BACKEND_FIPS202) @@ -344,6 +346,7 @@ #undef MLK_SYS_PPC64LE #undef MLK_SYS_RISCV32 #undef MLK_SYS_RISCV64 +#undef MLK_SYS_RISCV64_V256 #undef MLK_SYS_WINDOWS #undef MLK_SYS_X86_64 #undef MLK_SYS_X86_64_AVX2 @@ -552,5 +555,54 @@ #undef MLK_NATIVE_X86_64_SRC_CONSTS_H #undef mlk_qdata #endif /* MLK_SYS_X86_64 */ +#if defined(MLK_SYS_RISCV64) +/* + * Undefine macros from native code (Arith, RISC-V 64) + */ +/* mlkem/src/native/riscv64/meta.h */ +#undef MLK_ARITH_BACKEND_RISCV64 +#undef MLK_NATIVE_RISCV64_META_H +#undef MLK_USE_NATIVE_INTT +#undef MLK_USE_NATIVE_NTT +#undef MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#undef MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE +#undef MLK_USE_NATIVE_POLY_REDUCE +#undef MLK_USE_NATIVE_POLY_TOMONT +#undef MLK_USE_NATIVE_REJ_UNIFORM +/* mlkem/src/native/riscv64/src/arith_native_riscv64.h */ +#undef MLK_NATIVE_RISCV64_SRC_ARITH_NATIVE_RISCV64_H +#undef mlk_rv64v_poly_add +#undef mlk_rv64v_poly_basemul_mont_add_k2 +#undef mlk_rv64v_poly_basemul_mont_add_k3 +#undef mlk_rv64v_poly_basemul_mont_add_k4 +#undef mlk_rv64v_poly_invntt_tomont +#undef mlk_rv64v_poly_ntt +#undef mlk_rv64v_poly_reduce +#undef mlk_rv64v_poly_sub +#undef mlk_rv64v_poly_tomont +#undef mlk_rv64v_rej_uniform +/* mlkem/src/native/riscv64/src/rv64v_debug.h */ +#undef MLK_NATIVE_RISCV64_SRC_RV64V_DEBUG_H +#undef mlk_assert_abs_bound_int16m1 +#undef mlk_assert_abs_bound_int16m2 +#undef mlk_assert_bound_int16m1 +#undef 
mlk_assert_bound_int16m2 +#undef mlk_debug_check_bounds_int16m1 +#undef mlk_debug_check_bounds_int16m2 +/* mlkem/src/native/riscv64/src/rv64v_settings.h */ +#undef MLK_NATIVE_RISCV64_SRC_RV64V_SETTINGS_H +#undef MLK_RVV_E16M1_VL +#undef MLK_RVV_MONT_NR +#undef MLK_RVV_MONT_R1 +#undef MLK_RVV_MONT_R2 +#undef MLK_RVV_QI +#undef MLK_RVV_VLEN +#undef mlk_assert_abs_bound_int16m1 +#undef mlk_assert_abs_bound_int16m2 +#undef mlk_assert_bound_int16m1 +#undef mlk_assert_bound_int16m2 +#undef mlk_debug_check_bounds_int16m1 +#undef mlk_debug_check_bounds_int16m2 +#endif /* MLK_SYS_RISCV64 */ #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */ #endif /* !MLK_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS */ diff --git a/mlkem/mlkem_native.c b/mlkem/mlkem_native.c index d846f9f55..ec9d7c831 100644 --- a/mlkem/mlkem_native.c +++ b/mlkem/mlkem_native.c @@ -84,6 +84,10 @@ #include "src/native/x86_64/src/consts.c" #include "src/native/x86_64/src/rej_uniform_table.c" #endif +#if defined(MLK_SYS_RISCV64) +#include "src/native/riscv64/src/rv64v_debug.c" +#include "src/native/riscv64/src/rv64v_poly.c" +#endif #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */ #if defined(MLK_CONFIG_USE_NATIVE_BACKEND_FIPS202) @@ -331,6 +335,7 @@ #undef MLK_SYS_PPC64LE #undef MLK_SYS_RISCV32 #undef MLK_SYS_RISCV64 +#undef MLK_SYS_RISCV64_V256 #undef MLK_SYS_WINDOWS #undef MLK_SYS_X86_64 #undef MLK_SYS_X86_64_AVX2 @@ -539,5 +544,54 @@ #undef MLK_NATIVE_X86_64_SRC_CONSTS_H #undef mlk_qdata #endif /* MLK_SYS_X86_64 */ +#if defined(MLK_SYS_RISCV64) +/* + * Undefine macros from native code (Arith, RISC-V 64) + */ +/* mlkem/src/native/riscv64/meta.h */ +#undef MLK_ARITH_BACKEND_RISCV64 +#undef MLK_NATIVE_RISCV64_META_H +#undef MLK_USE_NATIVE_INTT +#undef MLK_USE_NATIVE_NTT +#undef MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#undef MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE +#undef MLK_USE_NATIVE_POLY_REDUCE +#undef MLK_USE_NATIVE_POLY_TOMONT +#undef MLK_USE_NATIVE_REJ_UNIFORM +/* mlkem/src/native/riscv64/src/arith_native_riscv64.h */ +#undef MLK_NATIVE_RISCV64_SRC_ARITH_NATIVE_RISCV64_H +#undef mlk_rv64v_poly_add +#undef mlk_rv64v_poly_basemul_mont_add_k2 +#undef mlk_rv64v_poly_basemul_mont_add_k3 +#undef mlk_rv64v_poly_basemul_mont_add_k4 +#undef mlk_rv64v_poly_invntt_tomont +#undef mlk_rv64v_poly_ntt +#undef mlk_rv64v_poly_reduce +#undef mlk_rv64v_poly_sub +#undef mlk_rv64v_poly_tomont +#undef mlk_rv64v_rej_uniform +/* mlkem/src/native/riscv64/src/rv64v_debug.h */ +#undef MLK_NATIVE_RISCV64_SRC_RV64V_DEBUG_H +#undef mlk_assert_abs_bound_int16m1 +#undef mlk_assert_abs_bound_int16m2 +#undef mlk_assert_bound_int16m1 +#undef mlk_assert_bound_int16m2 +#undef mlk_debug_check_bounds_int16m1 +#undef mlk_debug_check_bounds_int16m2 +/* mlkem/src/native/riscv64/src/rv64v_settings.h */ +#undef MLK_NATIVE_RISCV64_SRC_RV64V_SETTINGS_H +#undef MLK_RVV_E16M1_VL +#undef MLK_RVV_MONT_NR +#undef MLK_RVV_MONT_R1 +#undef MLK_RVV_MONT_R2 +#undef MLK_RVV_QI +#undef MLK_RVV_VLEN +#undef mlk_assert_abs_bound_int16m1 +#undef mlk_assert_abs_bound_int16m2 +#undef mlk_assert_bound_int16m1 +#undef mlk_assert_bound_int16m2 +#undef mlk_debug_check_bounds_int16m1 +#undef mlk_debug_check_bounds_int16m2 +#endif /* MLK_SYS_RISCV64 */ #endif /* MLK_CONFIG_USE_NATIVE_BACKEND_ARITH */ #endif /* !MLK_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS */ diff --git a/mlkem/src/native/meta.h b/mlkem/src/native/meta.h index f2b9b848b..a976eb2ef 100644 --- a/mlkem/src/native/meta.h +++ b/mlkem/src/native/meta.h @@ -18,4 +18,8 @@ #include "x86_64/meta.h" #endif +#if defined(MLK_SYS_RISCV64_V256) 
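+/* Experimental RVV backend. MLK_SYS_RISCV64_V256 is set in sys.h only when
+ * RVV intrinsics are available and __riscv_v_min_vlen >= 256. */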
+#include "riscv64/meta.h" +#endif + #endif /* !MLK_NATIVE_META_H */ diff --git a/mlkem/src/native/riscv64/README.md b/mlkem/src/native/riscv64/README.md new file mode 100644 index 000000000..df28916bf --- /dev/null +++ b/mlkem/src/native/riscv64/README.md @@ -0,0 +1,23 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +# RISC-V Vector Extension Backend + +Notes on RISC-V Vector support. + +**WARNING** This is highly experimental. Currently vlen=256 only. + +## Implementation Status + +This implementation is inferior to the AArch64 and AVX2 backends in the following ways: + +- **Verification**: No formal verification has been performed on this backend +- **Testing**: Uses the same functional tests as AVX2 and AArch64 backends, but lacks extensive real-hardware testing +- **Audit**: This is new code that has not yet received thorough review or widespread use in production environments + +## Requirements + +- RISC-V 64-bit architecture +- Vector extension (RVV) version 1.0 +- Minimum vector length (VLEN) of 256 bits +- Standard "gc" extensions (integer and compressed instructions) + diff --git a/mlkem/src/native/riscv64/meta.h b/mlkem/src/native/riscv64/meta.h new file mode 100644 index 000000000..7b621ae46 --- /dev/null +++ b/mlkem/src/native/riscv64/meta.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +#ifndef MLK_NATIVE_RISCV64_META_H +#define MLK_NATIVE_RISCV64_META_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLK_ARITH_BACKEND_RISCV64 + +/* Set of primitives that this backend replaces */ +#define MLK_USE_NATIVE_NTT +#define MLK_USE_NATIVE_INTT +#define MLK_USE_NATIVE_POLY_TOMONT +#define MLK_USE_NATIVE_REJ_UNIFORM +#define MLK_USE_NATIVE_POLY_REDUCE +#define MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED + +#include "../../common.h" + +#if !defined(__ASSEMBLER__) +#include "../api.h" +#include "src/arith_native_riscv64.h" + +static MLK_INLINE int mlk_ntt_native(int16_t data[MLKEM_N]) +{ + mlk_rv64v_poly_ntt(data); + return MLK_NATIVE_FUNC_SUCCESS; +} + +static MLK_INLINE int mlk_intt_native(int16_t data[MLKEM_N]) +{ + mlk_rv64v_poly_invntt_tomont(data); + return MLK_NATIVE_FUNC_SUCCESS; +} + +static MLK_INLINE int mlk_poly_tomont_native(int16_t data[MLKEM_N]) +{ + mlk_rv64v_poly_tomont(data); + return MLK_NATIVE_FUNC_SUCCESS; +} + +static MLK_INLINE int mlk_rej_uniform_native(int16_t *r, unsigned len, + const uint8_t *buf, + unsigned buflen) +{ + return mlk_rv64v_rej_uniform(r, len, buf, buflen); +} + +static MLK_INLINE int mlk_poly_reduce_native(int16_t data[MLKEM_N]) +{ + mlk_rv64v_poly_reduce(data); + return MLK_NATIVE_FUNC_SUCCESS; +} + +static MLK_INLINE int mlk_poly_mulcache_compute_native(int16_t x[MLKEM_N / 2], + const int16_t y[MLKEM_N]) +{ + (void)x; /* not using the cache atm */ + (void)y; + return MLK_NATIVE_FUNC_SUCCESS; +} + +#if defined(MLK_CONFIG_MULTILEVEL_WITH_SHARED) || MLKEM_K == 2 +static MLK_INLINE int mlk_polyvec_basemul_acc_montgomery_cached_k2_native( + int16_t r[MLKEM_N], const int16_t a[2 * MLKEM_N], + const int16_t b[2 * MLKEM_N], const int16_t b_cache[2 * (MLKEM_N / 2)]) +{ + (void)b_cache; + mlk_rv64v_poly_basemul_mont_add_k2(r, a, b); + return MLK_NATIVE_FUNC_SUCCESS; +} +#endif /* MLK_CONFIG_MULTILEVEL_WITH_SHARED || MLKEM_K == 2 */ + +#if defined(MLK_CONFIG_MULTILEVEL_WITH_SHARED) || MLKEM_K == 3 +static MLK_INLINE int 
mlk_polyvec_basemul_acc_montgomery_cached_k3_native( + int16_t r[MLKEM_N], const int16_t a[3 * MLKEM_N], + const int16_t b[3 * MLKEM_N], const int16_t b_cache[3 * (MLKEM_N / 2)]) +{ + (void)b_cache; + mlk_rv64v_poly_basemul_mont_add_k3(r, a, b); + return MLK_NATIVE_FUNC_SUCCESS; +} +#endif /* MLK_CONFIG_MULTILEVEL_WITH_SHARED || MLKEM_K == 3 */ + +#if defined(MLK_CONFIG_MULTILEVEL_WITH_SHARED) || MLKEM_K == 4 +static MLK_INLINE int mlk_polyvec_basemul_acc_montgomery_cached_k4_native( + int16_t r[MLKEM_N], const int16_t a[4 * MLKEM_N], + const int16_t b[4 * MLKEM_N], const int16_t b_cache[4 * (MLKEM_N / 2)]) +{ + (void)b_cache; + mlk_rv64v_poly_basemul_mont_add_k4(r, a, b); + return MLK_NATIVE_FUNC_SUCCESS; +} +#endif /* MLK_CONFIG_MULTILEVEL_WITH_SHARED || MLKEM_K == 4 */ + +#endif /* !__ASSEMBLER__ */ + +#endif /* !MLK_NATIVE_RISCV64_META_H */ diff --git a/mlkem/src/native/riscv64/src/arith_native_riscv64.h b/mlkem/src/native/riscv64/src/arith_native_riscv64.h new file mode 100644 index 000000000..a7749df5a --- /dev/null +++ b/mlkem/src/native/riscv64/src/arith_native_riscv64.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ +#ifndef MLK_NATIVE_RISCV64_SRC_ARITH_NATIVE_RISCV64_H +#define MLK_NATIVE_RISCV64_SRC_ARITH_NATIVE_RISCV64_H + +#include +#include "../../../common.h" + +#define mlk_rv64v_poly_ntt MLK_NAMESPACE(ntt_riscv64) +void mlk_rv64v_poly_ntt(int16_t *); + +#define mlk_rv64v_poly_invntt_tomont MLK_NAMESPACE(intt_riscv64) +void mlk_rv64v_poly_invntt_tomont(int16_t *r); + +#define mlk_rv64v_poly_basemul_mont_add_k2 MLK_NAMESPACE(basemul_add_k2_riscv64) +void mlk_rv64v_poly_basemul_mont_add_k2(int16_t *r, const int16_t *a, + const int16_t *b); + +#define mlk_rv64v_poly_basemul_mont_add_k3 MLK_NAMESPACE(basemul_add_k3_riscv64) +void mlk_rv64v_poly_basemul_mont_add_k3(int16_t *r, const int16_t *a, + const int16_t *b); + +#define mlk_rv64v_poly_basemul_mont_add_k4 MLK_NAMESPACE(basemul_add_k4_riscv64) +void mlk_rv64v_poly_basemul_mont_add_k4(int16_t *r, const int16_t *a, + const int16_t *b); + +#define mlk_rv64v_poly_tomont MLK_NAMESPACE(tomont_riscv64) +void mlk_rv64v_poly_tomont(int16_t *r); + +#define mlk_rv64v_poly_reduce MLK_NAMESPACE(reduce_riscv64) +void mlk_rv64v_poly_reduce(int16_t *r); + +#define mlk_rv64v_poly_add MLK_NAMESPACE(poly_add_riscv64) +void mlk_rv64v_poly_add(int16_t *r, const int16_t *a, const int16_t *b); + +#define mlk_rv64v_poly_sub MLK_NAMESPACE(poly_sub_riscv64) +void mlk_rv64v_poly_sub(int16_t *r, const int16_t *a, const int16_t *b); + +#define mlk_rv64v_rej_uniform MLK_NAMESPACE(rj_uniform_riscv64) +unsigned int mlk_rv64v_rej_uniform(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); + +#endif /* !MLK_NATIVE_RISCV64_SRC_ARITH_NATIVE_RISCV64_H */ diff --git a/mlkem/src/native/riscv64/src/rv64v_debug.c b/mlkem/src/native/riscv64/src/rv64v_debug.c new file mode 100644 index 000000000..6a8d7e9aa --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_debug.c @@ -0,0 +1,81 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* NOTE: You can remove this file unless you compile with MLKEM_DEBUG. 
*/ + +#include "../../../common.h" + +#if defined(MLK_ARITH_BACKEND_RISCV64) && \ + !defined(MLK_CONFIG_MULTILEVEL_NO_SHARED) && defined(MLKEM_DEBUG) + +#include +#include +#include +#include "../../../debug.h" +#include "rv64v_settings.h" + +#define MLK_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +/************************************************* + * Name: mlk_debug_check_bounds_int16m1 + * + * Description: Check whether values in a vint16m1_t vector + * are within specified bounds. + * + * Implementation: Extract vector elements to a temporary array + * and reuse existing array bounds checking. + **************************************************/ +void mlk_debug_check_bounds_int16m1(const char *file, int line, vint16m1_t vec, + size_t vl, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + /* Allocate temporary array to store vector elements + * We use the maximum possible vector length to be safe */ + int16_t temp_array[MLK_RVV_E16M1_VL]; + + /* Store vector elements to temporary array for inspection */ + __riscv_vse16_v_i16m1(temp_array, vec, vl); + + /* Reuse existing array bounds checking function */ + mlk_debug_check_bounds(file, line, temp_array, (unsigned)vl, + lower_bound_exclusive, upper_bound_exclusive); +} + +/************************************************* + * Name: mlk_debug_check_bounds_int16m2 + * + * Description: Check whether values in a vint16m2_t vector + * are within specified bounds. + * + * Implementation: Extract vector elements to a temporary array + * and reuse existing array bounds checking. + **************************************************/ +void mlk_debug_check_bounds_int16m2(const char *file, int line, vint16m2_t vec, + size_t vl, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + /* Allocate temporary array to store vector elements + * m2 vectors hold 2x the elements of m1 vectors */ + int16_t temp_array[2 * MLK_RVV_E16M1_VL]; + + /* Store vector elements to temporary array for inspection */ + __riscv_vse16_v_i16m2(temp_array, vec, 2 * vl); + + /* Reuse existing array bounds checking function for all elements */ + mlk_debug_check_bounds(file, line, temp_array, (unsigned)(2 * vl), + lower_bound_exclusive, upper_bound_exclusive); +} + +#else /* MLK_ARITH_BACKEND_RISCV64 && !MLK_CONFIG_MULTILEVEL_NO_SHARED && \ + MLKEM_DEBUG */ + +MLK_EMPTY_CU(rv64v_debug) + +#endif /* !(MLK_ARITH_BACKEND_RISCV64 && !MLK_CONFIG_MULTILEVEL_NO_SHARED && \ + MLKEM_DEBUG) */ + +/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. + * Don't modify by hand -- this is auto-generated by scripts/autogen. */ +#undef MLK_DEBUG_ERROR_HEADER diff --git a/mlkem/src/native/riscv64/src/rv64v_debug.h b/mlkem/src/native/riscv64/src/rv64v_debug.h new file mode 100644 index 000000000..929c1a60f --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_debug.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ +#ifndef MLK_NATIVE_RISCV64_SRC_RV64V_DEBUG_H +#define MLK_NATIVE_RISCV64_SRC_RV64V_DEBUG_H + +#include +#include "../../../debug.h" + +/************************************************* + * RISC-V Vector Bounds Assertion Macros + * + * These macros provide runtime bounds checking for RISC-V vector types + * vint16m1_t and vint16m2_t, following the same pattern as the scalar + * bounds assertions in debug.h + * + * The macros are only active when MLKEM_DEBUG is defined, otherwise they + * compile to no-ops for zero runtime overhead in release builds. 
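+ *
+ * Illustrative usage (sketch only; assumes `p` is a valid int16_t pointer
+ * and `vl` the active vector length):
+ *
+ *   vint16m1_t vec = __riscv_vle16_v_i16m1(p, vl);
+ *   mlk_assert_abs_bound_int16m1(vec, vl, MLKEM_Q);  (checks |vec[i]| < Q)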
+ **************************************************/ + +#if defined(MLKEM_DEBUG) + +/************************************************* + * Name: mlk_debug_check_bounds_int16m1 + * + * Description: Check whether values in a vint16m1_t vector + * are within specified bounds. + * + * Arguments: - file: filename + * - line: line number + * - vec: RISC-V vector to be checked + * - vl: vector length (number of active elements) + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlk_debug_check_bounds_int16m1 \ + MLK_NAMESPACE(mlkem_debug_check_bounds_int16m1) +void mlk_debug_check_bounds_int16m1(const char *file, int line, vint16m1_t vec, + size_t vl, int lower_bound_exclusive, + int upper_bound_exclusive); + +/************************************************* + * Name: mlk_debug_check_bounds_int16m2 + * + * Description: Check whether values in a vint16m2_t vector + * are within specified bounds by splitting into m1 vectors. + * + * Arguments: - file: filename + * - line: line number + * - vec: RISC-V vector to be checked + * - vl: vector length (number of active elements per m1 half) + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlk_debug_check_bounds_int16m2 \ + MLK_NAMESPACE(mlkem_debug_check_bounds_int16m2) +void mlk_debug_check_bounds_int16m2(const char *file, int line, vint16m2_t vec, + size_t vl, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check bounds in vint16m1_t vector + * vec: RISC-V vector of type vint16m1_t + * vl: Vector length (number of active elements) + * value_lb: Inclusive lower value bound + * value_ub: Exclusive upper value bound */ +#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \ + mlk_debug_check_bounds_int16m1(__FILE__, __LINE__, (vec), (vl), \ + (value_lb) - 1, (value_ub)) + +/* Check absolute bounds in vint16m1_t vector + * vec: RISC-V vector of type vint16m1_t + * vl: Vector length (number of active elements) + * value_abs_bd: Exclusive absolute upper bound */ +#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \ + mlk_assert_bound_int16m1((vec), (vl), (-(value_abs_bd) + 1), (value_abs_bd)) + +/* Check bounds in vint16m2_t vector + * vec: RISC-V vector of type vint16m2_t + * vl: Vector length (number of active elements per m1 half) + * value_lb: Inclusive lower value bound + * value_ub: Exclusive upper value bound */ +#define mlk_assert_bound_int16m2(vec, vl, value_lb, value_ub) \ + mlk_debug_check_bounds_int16m2(__FILE__, __LINE__, (vec), (vl), \ + (value_lb) - 1, (value_ub)) + +/* Check absolute bounds in vint16m2_t vector + * vec: RISC-V vector of type vint16m2_t + * vl: Vector length (number of active elements per m1 half) + * value_abs_bd: Exclusive absolute upper bound */ +#define mlk_assert_abs_bound_int16m2(vec, vl, value_abs_bd) \ + mlk_assert_bound_int16m2((vec), (vl), (-(value_abs_bd) + 1), (value_abs_bd)) + +#elif defined(CBMC) + +/* For CBMC, we would need to implement vector bounds checking using CBMC + * primitives This is complex and would require extracting vector elements, so + * for now we provide empty implementations that could be extended later */ +#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#define mlk_assert_bound_int16m2(vec, vl, 
value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m2(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#else /* !MLKEM_DEBUG && CBMC */ + +/* When debugging is disabled, all assertions become no-ops */ +#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#define mlk_assert_bound_int16m2(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m2(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#endif /* !MLKEM_DEBUG && !CBMC */ + +#endif /* !MLK_NATIVE_RISCV64_SRC_RV64V_DEBUG_H */ diff --git a/mlkem/src/native/riscv64/src/rv64v_izetas.inc b/mlkem/src/native/riscv64/src/rv64v_izetas.inc new file mode 100644 index 000000000..52ab11333 --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_izetas.inc @@ -0,0 +1,28 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * in the mlkem-native repository. + * Do not modify it directly. + */ + +#include +#include "arith_native_riscv64.h" + +const int16_t izeta[] = { + -1044, 758, 1571, 205, 1275, -677, 1065, -448, -1628, -1522, 1460, + -958, -991, -996, 308, 108, 1517, 359, -411, 1542, 725, 1508, + -961, 398, -478, 870, 854, 1510, -794, 1278, 1530, 1185, -202, + -287, -608, -732, 951, 247, 1421, -107, 1659, 1187, -220, 874, + 1335, -1218, 136, 1215, -1422, -1493, -1017, 681, -830, 271, 90, + 853, -384, 1465, 1285, -1322, -610, -603, -1097, -817, -1468, 1474, + 130, 1602, -1469, -126, 1162, 1618, 75, 156, -329, -418, -349, + 872, -644, 1590, 1202, -962, -1458, 829, 666, 320, 8, -516, + -1119, 602, -1483, 777, 147, -1159, -778, 246, -182, -1577, -383, + -264, 1544, 282, -1491, 1293, -1653, -1574, 460, 291, 235, -177, + -587, -422, -622, 171, 1325, -573, -1015, 552, -652, -1223, -105, + -1550, -871, 1251, -843, -555, -430, 1103, +}; diff --git a/mlkem/src/native/riscv64/src/rv64v_poly.c b/mlkem/src/native/riscv64/src/rv64v_poly.c new file mode 100644 index 000000000..ae1f07b78 --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_poly.c @@ -0,0 +1,786 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* === Kyber NTT using RISC-V Vector intrinstics */ + +#include "../../../common.h" + +#if defined(MLK_ARITH_BACKEND_RISCV64) && \ + !defined(MLK_CONFIG_MULTILEVEL_NO_SHARED) + +#include +#include "arith_native_riscv64.h" +#include "rv64v_settings.h" + +/* vector configuration */ +#ifndef MLK_RVV_VLEN +#define MLK_RVV_VLEN 256 +#endif + +/* vl value for a 16-bit wide type */ +#define MLK_RVV_E16M1_VL (MLK_RVV_VLEN / 16) + +/* Montgomery reduction constants */ +/* check-magic: 3327 == unsigned_mod(-pow(MLKEM_Q,-1,2^16), 2^16) */ +#define MLK_RVV_QI 3327 + +/* check-magic: 2285 == unsigned_mod(2^16, MLKEM_Q) */ +#define MLK_RVV_MONT_R1 2285 + +/* check-magic: 1353 == pow(2, 32, MLKEM_Q) */ +#define MLK_RVV_MONT_R2 1353 + +/* check-magic: 1441 == pow(2,32-7,MLKEM_Q) */ +#define MLK_RVV_MONT_NR 1441 + +static inline vint16m1_t fq_redc(vint16m1_t rh, vint16m1_t rl, size_t vl) +{ + vint16m1_t t; + vbool16_t c; + + t = __riscv_vmul_vx_i16m1(rl, MLK_RVV_QI, vl); /* t = l * -Q^-1 */ + t = __riscv_vmulh_vx_i16m1(t, MLKEM_Q, vl); /* t = (t*Q) / R */ + c = __riscv_vmsne_vx_i16m1_b16(rl, 0, vl); /* c = (l != 0) */ + t = __riscv_vadc_vvm_i16m1(t, rh, c, 
vl); /* t += h + c */ + + return t; +} + +/* Narrowing reduction */ + +static inline vint16m1_t fq_redc2(vint32m2_t z, size_t vl) +{ + vint16m1_t t; + + t = __riscv_vmul_vx_i16m1(__riscv_vncvt_x_x_w_i16m1(z, vl), MLK_RVV_QI, + vl); /* t = l * -Q^-1 */ + z = __riscv_vadd_vv_i32m2(z, __riscv_vwmul_vx_i32m2(t, MLKEM_Q, vl), + vl); /* x = (x + (t*Q)) */ + t = __riscv_vnsra_wx_i16m1(z, 16, vl); + + return t; +} + +/* Narrowing Barrett (per original Kyber) */ + +static inline vint16m1_t fq_barrett(vint16m1_t a, size_t vl) +{ + vint16m1_t t; + const int16_t v = ((1 << 26) + MLKEM_Q / 2) / MLKEM_Q; + + t = __riscv_vmulh_vx_i16m1(a, v, vl); + t = __riscv_vadd_vx_i16m1(t, 1 << (25 - 16), vl); + t = __riscv_vsra_vx_i16m1(t, 26 - 16, vl); + t = __riscv_vmul_vx_i16m1(t, MLKEM_Q, vl); + t = __riscv_vsub_vv_i16m1(a, t, vl); + + mlk_assert_abs_bound_int16m1(t, vl, MLKEM_Q_HALF); + return t; +} + +/* Conditionally add Q (if negative) */ + +static inline vint16m1_t fq_cadd(vint16m1_t rx, size_t vl) +{ + vbool16_t bn; + + bn = __riscv_vmslt_vx_i16m1_b16(rx, 0, vl); /* if x < 0: */ + rx = __riscv_vadd_vx_i16m1_mu(bn, rx, rx, MLKEM_Q, vl); /* x += Q */ + return rx; +} + +/* Conditionally subtract Q (if Q or above) */ + +static inline vint16m1_t fq_csub(vint16m1_t rx, size_t vl) +{ + vbool16_t bn; + + bn = __riscv_vmsge_vx_i16m1_b16(rx, MLKEM_Q, vl); /* if x >= Q: */ + rx = __riscv_vsub_vx_i16m1_mu(bn, rx, rx, MLKEM_Q, vl); /* x -= Q */ + return rx; +} + +/* Montgomery multiply: vector-vector */ + +static inline vint16m1_t fq_mul_vv(vint16m1_t rx, vint16m1_t ry, size_t vl) +{ + vint16m1_t rl, rh; + + rh = __riscv_vmulh_vv_i16m1(rx, ry, vl); /* h = (x * y) / R */ + rl = __riscv_vmul_vv_i16m1(rx, ry, vl); /* l = (x * y) % R */ + return fq_redc(rh, rl, vl); +} + +/* Montgomery multiply: vector-scalar */ + +static inline vint16m1_t fq_mul_vx(vint16m1_t rx, int16_t ry, size_t vl) +{ + vint16m1_t rl, rh; + + rh = __riscv_vmulh_vx_i16m1(rx, ry, vl); /* h = (x * y) / R */ + rl = __riscv_vmul_vx_i16m1(rx, ry, vl); /* l = (x * y) % R */ + return fq_redc(rh, rl, vl); +} + +/* full normalization */ + +static inline vint16m1_t fq_mulq_vx(vint16m1_t rx, int16_t ry, size_t vl) +{ + vint16m1_t result; + + result = fq_cadd(fq_mul_vx(rx, ry, vl), vl); + + mlk_assert_bound_int16m1(result, vl, 0, MLKEM_Q); + return result; +} + +/* create a permutation for swapping index bits a and b, a < b */ + +static vuint16m2_t bitswap_perm(unsigned a, unsigned b, size_t vl) +{ + const vuint16m2_t v2id = __riscv_vid_v_u16m2(vl); + + vuint16m2_t xa, xb; + xa = __riscv_vsrl_vx_u16m2(v2id, b - a, vl); + xa = __riscv_vxor_vv_u16m2(xa, v2id, vl); + xa = __riscv_vand_vx_u16m2(xa, (1 << a), vl); + xb = __riscv_vsll_vx_u16m2(xa, b - a, vl); + xa = __riscv_vxor_vv_u16m2(xa, xb, vl); + xa = __riscv_vxor_vv_u16m2(v2id, xa, vl); + return xa; +} + +/* NOTE: NTT is currently fixed for vlen==256 */ + +#if (MLK_RVV_VLEN == 256) + +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place; + * inputs assumed to be in normal order, output in + * bitreversed order + * + * Arguments: - uint16_t *r: pointer to in/output polynomial + **************************************************/ + +/* forward butterfly operation */ + +#define MLK_RVV_BFLY_FX(u0, u1, ut, uc, vl, layer) \ + { \ + mlk_assert_abs_bound(&uc, 1, MLKEM_Q_HALF); \ + \ + ut = fq_mul_vx(u1, uc, vl); \ + u1 = __riscv_vsub_vv_i16m1(u0, ut, vl); \ + u0 = __riscv_vadd_vv_i16m1(u0, ut, vl); \ + 
\ + mlk_assert_abs_bound_int16m1(u0, vl, (layer + 1) * MLKEM_Q); \ + mlk_assert_abs_bound_int16m1(u1, vl, (layer + 1) * MLKEM_Q); \ + mlk_assert_abs_bound_int16m1(ut, vl, MLKEM_Q); \ + } + +#define MLK_RVV_BFLY_FV(u0, u1, ut, uc, vl, layer) \ + { \ + mlk_assert_abs_bound_int16m1(uc, vl, MLKEM_Q_HALF); \ + \ + ut = fq_mul_vv(u1, uc, vl); \ + u1 = __riscv_vsub_vv_i16m1(u0, ut, vl); \ + u0 = __riscv_vadd_vv_i16m1(u0, ut, vl); \ + \ + mlk_assert_abs_bound_int16m1(u0, vl, (layer + 1) * MLKEM_Q); \ + mlk_assert_abs_bound_int16m1(u1, vl, (layer + 1) * MLKEM_Q); \ + mlk_assert_abs_bound_int16m1(ut, vl, MLKEM_Q); \ + } + +static vint16m2_t mlk_rv64v_ntt2(vint16m2_t vp, vint16m1_t cz) +{ + size_t vl = MLK_RVV_E16M1_VL; + size_t vl2 = 2 * vl; + + const vuint16m2_t v2p8 = bitswap_perm(3, 4, vl2); + const vuint16m2_t v2p4 = bitswap_perm(2, 4, vl2); + const vuint16m2_t v2p2 = bitswap_perm(1, 4, vl2); + + /* p1 = p8(p4(p2)) */ + const vuint16m2_t v2p1 = __riscv_vrgather_vv_u16m2( + __riscv_vrgather_vv_u16m2(v2p2, v2p4, vl2), v2p8, vl2); + + const vuint16m1_t vid = __riscv_vid_v_u16m1(vl); + const vuint16m1_t cs8 = + __riscv_vadd_vx_u16m1(__riscv_vsrl_vx_u16m1(vid, 3, vl), 2, vl); + const vuint16m1_t cs4 = + __riscv_vadd_vx_u16m1(__riscv_vsrl_vx_u16m1(vid, 2, vl), 2 + 2, vl); + const vuint16m1_t cs2 = + __riscv_vadd_vx_u16m1(__riscv_vsrl_vx_u16m1(vid, 1, vl), 2 + 2 + 4, vl); + + vint16m1_t vt, c0, t0, t1; + + /* swap 8 */ + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p8, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + c0 = __riscv_vrgather_vv_i16m1(cz, cs8, vl); + MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 5); + + /* swap 4 */ + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p4, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + c0 = __riscv_vrgather_vv_i16m1(cz, cs4, vl); + MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 6); + + /* swap 2 */ + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p2, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + c0 = __riscv_vrgather_vv_i16m1(cz, cs2, vl); + MLK_RVV_BFLY_FV(t0, t1, vt, c0, vl, 7); + + /* normalize */ + t0 = fq_mulq_vx(t0, MLK_RVV_MONT_R1, vl); + t1 = fq_mulq_vx(t1, MLK_RVV_MONT_R1, vl); + + /* reorganize */ + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p1, vl2); + + return vp; +} + +void mlk_rv64v_poly_ntt(int16_t *r) +{ +/* zetas can be compiled into vector constants; don't pass as a pointer */ +#include "rv64v_zetas.inc" + + size_t vl = MLK_RVV_E16M1_VL; + size_t vl2 = 2 * vl; + + vint16m1_t vt; + vint16m1_t v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf; + + const vint16m1_t z0 = __riscv_vle16_v_i16m1(&zeta[0x00], vl); + const vint16m1_t z2 = __riscv_vle16_v_i16m1(&zeta[0x10], vl); + const vint16m1_t z4 = __riscv_vle16_v_i16m1(&zeta[0x20], vl); + const vint16m1_t z6 = __riscv_vle16_v_i16m1(&zeta[0x30], vl); + const vint16m1_t z8 = __riscv_vle16_v_i16m1(&zeta[0x40], vl); + const vint16m1_t za = __riscv_vle16_v_i16m1(&zeta[0x50], vl); + const vint16m1_t zc = __riscv_vle16_v_i16m1(&zeta[0x60], vl); + const vint16m1_t ze = __riscv_vle16_v_i16m1(&zeta[0x70], vl); + + v0 = __riscv_vle16_v_i16m1(&r[0x00], vl); + v1 = __riscv_vle16_v_i16m1(&r[0x10], vl); + v2 = __riscv_vle16_v_i16m1(&r[0x20], vl); + v3 = __riscv_vle16_v_i16m1(&r[0x30], vl); + v4 = __riscv_vle16_v_i16m1(&r[0x40], vl); + v5 = 
__riscv_vle16_v_i16m1(&r[0x50], vl); + v6 = __riscv_vle16_v_i16m1(&r[0x60], vl); + v7 = __riscv_vle16_v_i16m1(&r[0x70], vl); + v8 = __riscv_vle16_v_i16m1(&r[0x80], vl); + v9 = __riscv_vle16_v_i16m1(&r[0x90], vl); + va = __riscv_vle16_v_i16m1(&r[0xa0], vl); + vb = __riscv_vle16_v_i16m1(&r[0xb0], vl); + vc = __riscv_vle16_v_i16m1(&r[0xc0], vl); + vd = __riscv_vle16_v_i16m1(&r[0xd0], vl); + ve = __riscv_vle16_v_i16m1(&r[0xe0], vl); + vf = __riscv_vle16_v_i16m1(&r[0xf0], vl); + + MLK_RVV_BFLY_FX(v0, v8, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v1, v9, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v2, va, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v3, vb, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v4, vc, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v5, vd, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v6, ve, vt, zeta[0x01], vl, 1); + MLK_RVV_BFLY_FX(v7, vf, vt, zeta[0x01], vl, 1); + + MLK_RVV_BFLY_FX(v0, v4, vt, zeta[0x10], vl, 2); + MLK_RVV_BFLY_FX(v1, v5, vt, zeta[0x10], vl, 2); + MLK_RVV_BFLY_FX(v2, v6, vt, zeta[0x10], vl, 2); + MLK_RVV_BFLY_FX(v3, v7, vt, zeta[0x10], vl, 2); + MLK_RVV_BFLY_FX(v8, vc, vt, zeta[0x11], vl, 2); + MLK_RVV_BFLY_FX(v9, vd, vt, zeta[0x11], vl, 2); + MLK_RVV_BFLY_FX(va, ve, vt, zeta[0x11], vl, 2); + MLK_RVV_BFLY_FX(vb, vf, vt, zeta[0x11], vl, 2); + + MLK_RVV_BFLY_FX(v0, v2, vt, zeta[0x20], vl, 3); + MLK_RVV_BFLY_FX(v1, v3, vt, zeta[0x20], vl, 3); + MLK_RVV_BFLY_FX(v4, v6, vt, zeta[0x21], vl, 3); + MLK_RVV_BFLY_FX(v5, v7, vt, zeta[0x21], vl, 3); + MLK_RVV_BFLY_FX(v8, va, vt, zeta[0x30], vl, 3); + MLK_RVV_BFLY_FX(v9, vb, vt, zeta[0x30], vl, 3); + MLK_RVV_BFLY_FX(vc, ve, vt, zeta[0x31], vl, 3); + MLK_RVV_BFLY_FX(vd, vf, vt, zeta[0x31], vl, 3); + + MLK_RVV_BFLY_FX(v0, v1, vt, zeta[0x40], vl, 4); + MLK_RVV_BFLY_FX(v2, v3, vt, zeta[0x41], vl, 4); + MLK_RVV_BFLY_FX(v4, v5, vt, zeta[0x50], vl, 4); + MLK_RVV_BFLY_FX(v6, v7, vt, zeta[0x51], vl, 4); + MLK_RVV_BFLY_FX(v8, v9, vt, zeta[0x60], vl, 4); + MLK_RVV_BFLY_FX(va, vb, vt, zeta[0x61], vl, 4); + MLK_RVV_BFLY_FX(vc, vd, vt, zeta[0x70], vl, 4); + MLK_RVV_BFLY_FX(ve, vf, vt, zeta[0x71], vl, 4); + + __riscv_vse16_v_i16m2( + &r[0x00], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v0, v1), z0), vl2); + __riscv_vse16_v_i16m2( + &r[0x20], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v2, v3), z2), vl2); + __riscv_vse16_v_i16m2( + &r[0x40], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v4, v5), z4), vl2); + __riscv_vse16_v_i16m2( + &r[0x60], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v6, v7), z6), vl2); + __riscv_vse16_v_i16m2( + &r[0x80], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(v8, v9), z8), vl2); + __riscv_vse16_v_i16m2( + &r[0xa0], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(va, vb), za), vl2); + __riscv_vse16_v_i16m2( + &r[0xc0], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(vc, vd), zc), vl2); + __riscv_vse16_v_i16m2( + &r[0xe0], mlk_rv64v_ntt2(__riscv_vcreate_v_i16m1_i16m2(ve, vf), ze), vl2); +} + +#undef MLK_RVV_BFLY_FX +#undef MLK_RVV_BFLY_FV + +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, + * output in normal order + * + * Arguments: - uint16_t *r: pointer to in/output polynomial + **************************************************/ + +/* reverse butterfly operation */ + +#define MLK_RVV_BFLY_RX(u0, u1, ut, uc, vl) \ + { \ + ut = __riscv_vsub_vv_i16m1(u0, u1, vl); \ + u0 = __riscv_vadd_vv_i16m1(u0, u1, vl); \ + u0 = fq_csub(u0, vl); \ 
+ u1 = fq_mul_vx(ut, uc, vl); \ + u1 = fq_cadd(u1, vl); \ + \ + mlk_assert_bound_int16m1(u0, vl, 0, MLKEM_Q); \ + mlk_assert_bound_int16m1(u1, vl, 0, MLKEM_Q); \ + } + +#define MLK_RVV_BFLY_RV(u0, u1, ut, uc, vl) \ + { \ + ut = __riscv_vsub_vv_i16m1(u0, u1, vl); \ + u0 = __riscv_vadd_vv_i16m1(u0, u1, vl); \ + u0 = fq_csub(u0, vl); \ + u1 = fq_mul_vv(ut, uc, vl); \ + u1 = fq_cadd(u1, vl); \ + \ + mlk_assert_bound_int16m1(u0, vl, 0, MLKEM_Q); \ + mlk_assert_bound_int16m1(u1, vl, 0, MLKEM_Q); \ + } + +static vint16m2_t mlk_rv64v_intt2(vint16m2_t vp, vint16m1_t cz) +{ + size_t vl = MLK_RVV_E16M1_VL; + size_t vl2 = 2 * vl; + + const vuint16m2_t v2p8 = bitswap_perm(3, 4, vl2); + const vuint16m2_t v2p4 = bitswap_perm(2, 4, vl2); + const vuint16m2_t v2p2 = bitswap_perm(1, 4, vl2); + + /* p0 = p2(p4(p8)) */ + const vuint16m2_t v2p0 = __riscv_vrgather_vv_u16m2( + __riscv_vrgather_vv_u16m2(v2p8, v2p4, vl2), v2p2, vl2); + + const vuint16m1_t vid = __riscv_vid_v_u16m1(vl); + const vuint16m1_t cs8 = + __riscv_vadd_vx_u16m1(__riscv_vsrl_vx_u16m1(vid, 3, vl), 2, vl); + const vuint16m1_t cs4 = + __riscv_vadd_vx_u16m1(__riscv_vsrl_vx_u16m1(vid, 2, vl), 2 + 2, vl); + const vuint16m1_t cs2 = + __riscv_vadd_vx_u16m1(__riscv_vsrl_vx_u16m1(vid, 1, vl), 2 + 2 + 4, vl); + + vint16m1_t t0, t1, c0, vt; + + /* initial permute */ + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p0, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + /* pre-scale and move to positive range [0, q-1] for inverse transform */ + t0 = fq_mulq_vx(t0, MLK_RVV_MONT_NR, vl); + t1 = fq_mulq_vx(t1, MLK_RVV_MONT_NR, vl); + + c0 = __riscv_vrgather_vv_i16m1(cz, cs2, vl); + MLK_RVV_BFLY_RV(t0, t1, vt, c0, vl); + + /* swap 2 */ + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p2, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + c0 = __riscv_vrgather_vv_i16m1(cz, cs4, vl); + MLK_RVV_BFLY_RV(t0, t1, vt, c0, vl); + + /* swap 4 */ + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p4, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + c0 = __riscv_vrgather_vv_i16m1(cz, cs8, vl); + MLK_RVV_BFLY_RV(t0, t1, vt, c0, vl); + + /* swap 8 */ + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + vp = __riscv_vrgatherei16_vv_i16m2(vp, v2p8, vl2); + t0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + t1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = __riscv_vcreate_v_i16m1_i16m2(t0, t1); + + return vp; +} + +void mlk_rv64v_poly_invntt_tomont(int16_t *r) +{ +/* zetas can be compiled into vector constants; don't pass as a pointer */ +#include "rv64v_izetas.inc" + + size_t vl = MLK_RVV_E16M1_VL; + size_t vl2 = 2 * vl; + + const vint16m1_t z0 = __riscv_vle16_v_i16m1(&izeta[0x00], vl); + const vint16m1_t z2 = __riscv_vle16_v_i16m1(&izeta[0x10], vl); + const vint16m1_t z4 = __riscv_vle16_v_i16m1(&izeta[0x20], vl); + const vint16m1_t z6 = __riscv_vle16_v_i16m1(&izeta[0x30], vl); + const vint16m1_t z8 = __riscv_vle16_v_i16m1(&izeta[0x40], vl); + const vint16m1_t za = __riscv_vle16_v_i16m1(&izeta[0x50], vl); + const vint16m1_t zc = __riscv_vle16_v_i16m1(&izeta[0x60], vl); + const vint16m1_t ze = __riscv_vle16_v_i16m1(&izeta[0x70], vl); + + vint16m1_t vt; + vint16m1_t v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, va, vb, vc, vd, ve, vf; + vint16m2_t vp; + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0x00], vl2), z0); + v0 = __riscv_vget_v_i16m2_i16m1(vp, 0); + v1 = __riscv_vget_v_i16m2_i16m1(vp, 1); + 
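+  /* The remaining seven 32-coefficient blocks go through the same in-register
+   * inverse layers; the cross-register butterflies follow once all sixteen
+   * m1 registers are populated. */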
+ vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0x20], vl2), z2); + v2 = __riscv_vget_v_i16m2_i16m1(vp, 0); + v3 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0x40], vl2), z4); + v4 = __riscv_vget_v_i16m2_i16m1(vp, 0); + v5 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0x60], vl2), z6); + v6 = __riscv_vget_v_i16m2_i16m1(vp, 0); + v7 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0x80], vl2), z8); + v8 = __riscv_vget_v_i16m2_i16m1(vp, 0); + v9 = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0xa0], vl2), za); + va = __riscv_vget_v_i16m2_i16m1(vp, 0); + vb = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0xc0], vl2), zc); + vc = __riscv_vget_v_i16m2_i16m1(vp, 0); + vd = __riscv_vget_v_i16m2_i16m1(vp, 1); + + vp = mlk_rv64v_intt2(__riscv_vle16_v_i16m2(&r[0xe0], vl2), ze); + ve = __riscv_vget_v_i16m2_i16m1(vp, 0); + vf = __riscv_vget_v_i16m2_i16m1(vp, 1); + + MLK_RVV_BFLY_RX(v0, v1, vt, izeta[0x40], vl); + MLK_RVV_BFLY_RX(v2, v3, vt, izeta[0x41], vl); + MLK_RVV_BFLY_RX(v4, v5, vt, izeta[0x50], vl); + MLK_RVV_BFLY_RX(v6, v7, vt, izeta[0x51], vl); + MLK_RVV_BFLY_RX(v8, v9, vt, izeta[0x60], vl); + MLK_RVV_BFLY_RX(va, vb, vt, izeta[0x61], vl); + MLK_RVV_BFLY_RX(vc, vd, vt, izeta[0x70], vl); + MLK_RVV_BFLY_RX(ve, vf, vt, izeta[0x71], vl); + + MLK_RVV_BFLY_RX(v0, v2, vt, izeta[0x20], vl); + MLK_RVV_BFLY_RX(v1, v3, vt, izeta[0x20], vl); + MLK_RVV_BFLY_RX(v4, v6, vt, izeta[0x21], vl); + MLK_RVV_BFLY_RX(v5, v7, vt, izeta[0x21], vl); + MLK_RVV_BFLY_RX(v8, va, vt, izeta[0x30], vl); + MLK_RVV_BFLY_RX(v9, vb, vt, izeta[0x30], vl); + MLK_RVV_BFLY_RX(vc, ve, vt, izeta[0x31], vl); + MLK_RVV_BFLY_RX(vd, vf, vt, izeta[0x31], vl); + + MLK_RVV_BFLY_RX(v0, v4, vt, izeta[0x10], vl); + MLK_RVV_BFLY_RX(v1, v5, vt, izeta[0x10], vl); + MLK_RVV_BFLY_RX(v2, v6, vt, izeta[0x10], vl); + MLK_RVV_BFLY_RX(v3, v7, vt, izeta[0x10], vl); + MLK_RVV_BFLY_RX(v8, vc, vt, izeta[0x11], vl); + MLK_RVV_BFLY_RX(v9, vd, vt, izeta[0x11], vl); + MLK_RVV_BFLY_RX(va, ve, vt, izeta[0x11], vl); + MLK_RVV_BFLY_RX(vb, vf, vt, izeta[0x11], vl); + + MLK_RVV_BFLY_RX(v0, v8, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v1, v9, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v2, va, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v3, vb, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v4, vc, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v5, vd, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v6, ve, vt, izeta[0x01], vl); + MLK_RVV_BFLY_RX(v7, vf, vt, izeta[0x01], vl); + + __riscv_vse16_v_i16m1(&r[0x00], v0, vl); + __riscv_vse16_v_i16m1(&r[0x10], v1, vl); + __riscv_vse16_v_i16m1(&r[0x20], v2, vl); + __riscv_vse16_v_i16m1(&r[0x30], v3, vl); + __riscv_vse16_v_i16m1(&r[0x40], v4, vl); + __riscv_vse16_v_i16m1(&r[0x50], v5, vl); + __riscv_vse16_v_i16m1(&r[0x60], v6, vl); + __riscv_vse16_v_i16m1(&r[0x70], v7, vl); + __riscv_vse16_v_i16m1(&r[0x80], v8, vl); + __riscv_vse16_v_i16m1(&r[0x90], v9, vl); + __riscv_vse16_v_i16m1(&r[0xa0], va, vl); + __riscv_vse16_v_i16m1(&r[0xb0], vb, vl); + __riscv_vse16_v_i16m1(&r[0xc0], vc, vl); + __riscv_vse16_v_i16m1(&r[0xd0], vd, vl); + __riscv_vse16_v_i16m1(&r[0xe0], ve, vl); + __riscv_vse16_v_i16m1(&r[0xf0], vf, vl); +} + +#undef MLK_RVV_BFLY_RX +#undef MLK_RVV_BFLY_RV + +/* Kyber's middle field GF(3329)[X]/(X^2) multiplication */ + +static inline void mlk_rv64v_poly_basemul_mont_add_k(int16_t *r, + const int16_t *a, + const int16_t *b, + unsigned kn) +{ +#include 
"rv64v_zetas_basemul.inc" + + size_t vl = MLK_RVV_E16M1_VL; + size_t i, j; + + const vuint16m1_t sw0 = __riscv_vxor_vx_u16m1(__riscv_vid_v_u16m1(vl), 1, vl); + const vbool16_t sb0 = __riscv_vmseq_vx_u16m1_b16( + __riscv_vand_vx_u16m1(__riscv_vid_v_u16m1(vl), 1, vl), 0, vl); + + vint16m1_t vt, vu; + vint32m2_t wa, wb, ws; + + for (i = 0; i < MLKEM_N; i += vl) + { + const vint16m1_t vz = __riscv_vle16_v_i16m1(&roots[i], vl); + + for (j = 0; j < kn; j += MLKEM_N) + { + vt = __riscv_vle16_v_i16m1(&a[i + j], vl); + vu = __riscv_vle16_v_i16m1(&b[i + j], vl); + + wa = __riscv_vwmul_vv_i32m2(vz, fq_mul_vv(vt, vu, vl), vl); + wb = __riscv_vwmul_vv_i32m2(vt, __riscv_vrgather_vv_i16m1(vu, sw0, vl), + vl); + + wa = + __riscv_vadd_vv_i32m2(wa, __riscv_vslidedown_vx_i32m2(wa, 1, vl), vl); + wb = __riscv_vadd_vv_i32m2(wb, __riscv_vslideup_vx_i32m2(wb, wb, 1, vl), + vl); + + wa = __riscv_vmerge_vvm_i32m2(wb, wa, sb0, vl); + + if (j == 0) + { + ws = wa; + } + else + { + ws = __riscv_vadd_vv_i32m2(ws, wa, vl); + } + } + /* the idea is to keep 32-bit intermediate result, reduce in the end */ + __riscv_vse16_v_i16m1(&r[i], fq_redc2(ws, vl), vl); + } +} + +void mlk_rv64v_poly_basemul_mont_add_k2(int16_t *r, const int16_t *a, + const int16_t *b) +{ + mlk_rv64v_poly_basemul_mont_add_k(r, a, b, 2 * MLKEM_N); +} + +void mlk_rv64v_poly_basemul_mont_add_k3(int16_t *r, const int16_t *a, + const int16_t *b) +{ + mlk_rv64v_poly_basemul_mont_add_k(r, a, b, 3 * MLKEM_N); +} + +void mlk_rv64v_poly_basemul_mont_add_k4(int16_t *r, const int16_t *a, + const int16_t *b) +{ + mlk_rv64v_poly_basemul_mont_add_k(r, a, b, 4 * MLKEM_N); +} + +#endif /* MLK_RVV_VLEN == 256 */ + +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - int16_t *r: pointer to input/output polynomial + **************************************************/ +void mlk_rv64v_poly_tomont(int16_t *r) +{ + size_t vl = MLK_RVV_E16M1_VL; + + for (size_t i = 0; i < MLKEM_N; i += vl) + { + vint16m1_t vec = __riscv_vle16_v_i16m1(&r[i], vl); + vec = fq_mul_vx(vec, MLK_RVV_MONT_R2, vl); + __riscv_vse16_v_i16m1(&r[i], vec, vl); + } +} + +/************************************************* + * Name: poly_reduce + * + * Description: Applies Barrett reduction to all coefficients of a polynomial + * for details of the Barrett reduction see + *comments in reduce.c + * + * Arguments: - int16_t *r: pointer to input/output polynomial + **************************************************/ +void mlk_rv64v_poly_reduce(int16_t *r) +{ + size_t vl = MLK_RVV_E16M1_VL; + vint16m1_t vt; + + for (size_t i = 0; i < MLKEM_N; i += vl) + { + vt = __riscv_vle16_v_i16m1(&r[i], vl); + vt = fq_barrett(vt, vl); + vt = fq_cadd(vt, vl); + __riscv_vse16_v_i16m1(&r[i], vt, vl); + } +} + +/************************************************* + * Name: poly_add + * + * Description: Add two polynomials; no modular reduction is performed + * + * Arguments: - int16_t *r: pointer to output polynomial + * - const int16_t *a: pointer to first input polynomial + * - const int16_t *b: pointer to second input polynomial + **************************************************/ +void mlk_rv64v_poly_add(int16_t *r, const int16_t *a, const int16_t *b) +{ + size_t vl = MLK_RVV_E16M1_VL; + + for (size_t i = 0; i < MLKEM_N; i += vl) + { + __riscv_vse16_v_i16m1( + &r[i], + __riscv_vadd_vv_i16m1(__riscv_vle16_v_i16m1(&a[i], vl), + __riscv_vle16_v_i16m1(&b[i], vl), vl), + 
vl); + } +} + +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - int16_t *r: pointer to output polynomial + * - const int16_t *a: pointer to first input polynomial + * - const int16_t *b: pointer to second input polynomial + **************************************************/ +void mlk_rv64v_poly_sub(int16_t *r, const int16_t *a, const int16_t *b) +{ + size_t vl = MLK_RVV_E16M1_VL; + + for (size_t i = 0; i < MLKEM_N; i += vl) + { + __riscv_vse16_v_i16m1( + &r[i], + __riscv_vsub_vv_i16m1(__riscv_vle16_v_i16m1(&a[i], vl), + __riscv_vle16_v_i16m1(&b[i], vl), vl), + vl); + } +} + +/* Run rejection sampling to get uniform random integers mod q */ + +unsigned int mlk_rv64v_rej_uniform(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + const size_t vl = MLK_RVV_E16M1_VL; + const size_t vl23 = (MLK_RVV_E16M1_VL * 24) / 32; + + const vuint16m1_t vid = __riscv_vid_v_u16m1(vl); + const vuint16m1_t srl12v = __riscv_vmul_vx_u16m1(vid, 12, vl); + const vuint16m1_t sel12v = __riscv_vsrl_vx_u16m1(srl12v, 4, vl); + const vuint16m1_t sll12v = __riscv_vsll_vx_u16m1(vid, 2, vl); + + size_t n, ctr, pos; + vuint16m1_t x, y; + vbool16_t lt; + + pos = 0; + ctr = 0; + + while (ctr < len && pos + (vl23 * 2) <= buflen) + { + x = __riscv_vle16_v_u16m1((uint16_t *)&buf[pos], vl23); + pos += vl23 * 2; + x = __riscv_vrgather_vv_u16m1(x, sel12v, vl); + x = __riscv_vor_vv_u16m1( + __riscv_vsrl_vv_u16m1(x, srl12v, vl), + __riscv_vsll_vv_u16m1(__riscv_vslidedown(x, 1, vl), sll12v, vl), vl); + x = __riscv_vand_vx_u16m1(x, 0xFFF, vl); + + lt = __riscv_vmsltu_vx_u16m1_b16(x, MLKEM_Q, vl); + y = __riscv_vcompress_vm_u16m1(x, lt, vl); + n = __riscv_vcpop_m_b16(lt, vl); + + if (ctr + n > len) + { + n = len - ctr; + } + __riscv_vse16_v_u16m1((uint16_t *)&r[ctr], y, n); + ctr += n; + } + + return ctr; +} + +#else /* MLK_ARITH_BACKEND_RISCV64 && !MLK_CONFIG_MULTILEVEL_NO_SHARED */ + +MLK_EMPTY_CU(rv64v_poly) + +#endif /* !(MLK_ARITH_BACKEND_RISCV64 && !MLK_CONFIG_MULTILEVEL_NO_SHARED) */ + +/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. + * Don't modify by hand -- this is auto-generated by scripts/autogen. */ +#undef MLK_RVV_BFLY_FX +#undef MLK_RVV_BFLY_FV +#undef MLK_RVV_BFLY_RX +#undef MLK_RVV_BFLY_RV +/* Some macros are kept because they are also defined in a header. 
*/ +/* Keep: MLK_RVV_VLEN (rv64v_settings.h) */ +/* Keep: MLK_RVV_E16M1_VL (rv64v_settings.h) */ +/* Keep: MLK_RVV_QI (rv64v_settings.h) */ +/* Keep: MLK_RVV_MONT_R1 (rv64v_settings.h) */ +/* Keep: MLK_RVV_MONT_R2 (rv64v_settings.h) */ +/* Keep: MLK_RVV_MONT_NR (rv64v_settings.h) */ diff --git a/mlkem/src/native/riscv64/src/rv64v_settings.h b/mlkem/src/native/riscv64/src/rv64v_settings.h new file mode 100644 index 000000000..868825c59 --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_settings.h @@ -0,0 +1,168 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ +#ifndef MLK_NATIVE_RISCV64_SRC_RV64V_SETTINGS_H +#define MLK_NATIVE_RISCV64_SRC_RV64V_SETTINGS_H + +#include +#include "../../../debug.h" + +/************************************************* + * RISC-V Vector Bounds Assertion Macros + * + * These macros provide runtime bounds checking for RISC-V vector types + * vint16m1_t and vint16m2_t, following the same pattern as the scalar + * bounds assertions in debug.h + * + * The macros are only active when MLKEM_DEBUG is defined, otherwise they + * compile to no-ops for zero runtime overhead in release builds. + **************************************************/ + +#if defined(MLKEM_DEBUG) + +/************************************************* + * Name: mlk_debug_check_bounds_int16m1 + * + * Description: Check whether values in a vint16m1_t vector + * are within specified bounds. + * + * Arguments: - file: filename + * - line: line number + * - vec: RISC-V vector to be checked + * - vl: vector length (number of active elements) + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlk_debug_check_bounds_int16m1 \ + MLK_NAMESPACE(mlkem_debug_check_bounds_int16m1) +void mlk_debug_check_bounds_int16m1(const char *file, int line, vint16m1_t vec, + size_t vl, int lower_bound_exclusive, + int upper_bound_exclusive); + +/************************************************* + * Name: mlk_debug_check_bounds_int16m2 + * + * Description: Check whether values in a vint16m2_t vector + * are within specified bounds by splitting into m1 vectors. 
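+ *              In total 2*vl elements are checked, viewing the register
+ *              group as two m1 halves.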
+ * + * Arguments: - file: filename + * - line: line number + * - vec: RISC-V vector to be checked + * - vl: vector length (number of active elements per m1 half) + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlk_debug_check_bounds_int16m2 \ + MLK_NAMESPACE(mlkem_debug_check_bounds_int16m2) +void mlk_debug_check_bounds_int16m2(const char *file, int line, vint16m2_t vec, + size_t vl, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check bounds in vint16m1_t vector + * vec: RISC-V vector of type vint16m1_t + * vl: Vector length (number of active elements) + * value_lb: Inclusive lower value bound + * value_ub: Exclusive upper value bound */ +#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \ + mlk_debug_check_bounds_int16m1(__FILE__, __LINE__, (vec), (vl), \ + (value_lb) - 1, (value_ub)) + +/* Check absolute bounds in vint16m1_t vector + * vec: RISC-V vector of type vint16m1_t + * vl: Vector length (number of active elements) + * value_abs_bd: Exclusive absolute upper bound */ +#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \ + mlk_assert_bound_int16m1((vec), (vl), (-(value_abs_bd) + 1), (value_abs_bd)) + +/* Check bounds in vint16m2_t vector + * vec: RISC-V vector of type vint16m2_t + * vl: Vector length (number of active elements per m1 half) + * value_lb: Inclusive lower value bound + * value_ub: Exclusive upper value bound */ +#define mlk_assert_bound_int16m2(vec, vl, value_lb, value_ub) \ + mlk_debug_check_bounds_int16m2(__FILE__, __LINE__, (vec), (vl), \ + (value_lb) - 1, (value_ub)) + +/* Check absolute bounds in vint16m2_t vector + * vec: RISC-V vector of type vint16m2_t + * vl: Vector length (number of active elements per m1 half) + * value_abs_bd: Exclusive absolute upper bound */ +#define mlk_assert_abs_bound_int16m2(vec, vl, value_abs_bd) \ + mlk_assert_bound_int16m2((vec), (vl), (-(value_abs_bd) + 1), (value_abs_bd)) + +#elif defined(CBMC) + +/* For CBMC, we would need to implement vector bounds checking using CBMC + * primitives This is complex and would require extracting vector elements, so + * for now we provide empty implementations that could be extended later */ +#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#define mlk_assert_bound_int16m2(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m2(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#else /* !MLKEM_DEBUG && CBMC */ + +/* When debugging is disabled, all assertions become no-ops */ +#define mlk_assert_bound_int16m1(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m1(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#define mlk_assert_bound_int16m2(vec, vl, value_lb, value_ub) \ + do \ + { \ + } while (0) + +#define mlk_assert_abs_bound_int16m2(vec, vl, value_abs_bd) \ + do \ + { \ + } while (0) + +#endif /* !MLKEM_DEBUG && !CBMC */ + +/* === vector configuration */ +#ifndef MLK_RVV_VLEN +#define MLK_RVV_VLEN 256 +#endif + + + +/* vl value for a 16-bit wide type */ +#define MLK_RVV_E16M1_VL (MLK_RVV_VLEN / 16) + +/* Montgomery reduction constants */ +/* n = 256; q = 3329; r = 2^16 */ +/* check-magic: 3327 == unsigned_mod(-pow(MLKEM_Q,-1,2^16), 2^16) */ +#define MLK_RVV_QI 3327 + +/* check-magic: 2285 == unsigned_mod(2^16, 
diff --git a/mlkem/src/native/riscv64/src/rv64v_zetas.inc b/mlkem/src/native/riscv64/src/rv64v_zetas.inc new file mode 100644 index 000000000..3ec50b08f --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_zetas.inc @@ -0,0 +1,28 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * in the mlkem-native repository. + * Do not modify it directly. + */ + +#include <stdint.h> +#include "arith_native_riscv64.h" + +const int16_t zeta[] = { + -1044, -758, 573, -1325, 1223, 652, -552, 1015, -1103, 430, 555, + 843, -1251, 871, 1550, 105, -359, -1517, 264, 383, -1293, 1491, + -282, -1544, 422, 587, 177, -235, -291, -460, 1574, 1653, 1493, + 1422, -829, 1458, 516, -8, -320, -666, -246, 778, 1159, -147, + -777, 1483, -602, 1119, 287, 202, -1602, -130, -1618, -1162, 126, + 1469, -1590, 644, -872, 349, 418, 329, -156, -75, -171, 622, + -681, 1017, -853, -90, -271, 830, 817, 1097, 603, 610, 1322, + -1285, -1465, 384, 1577, 182, 732, 608, 107, -1421, -247, -951, + -1215, -136, 1218, -1335, -874, 220, -1187, -1659, 962, -1202, -1542, + 411, -398, 961, -1508, -725, -1185, -1530, -1278, 794, -1510, -854, + -870, 478, -1474, 1468, -205, -1571, 448, -1065, 677, -1275, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/mlkem/src/native/riscv64/src/rv64v_zetas_basemul.inc b/mlkem/src/native/riscv64/src/rv64v_zetas_basemul.inc new file mode 100644 index 000000000..e161926f7 --- /dev/null +++ b/mlkem/src/native/riscv64/src/rv64v_zetas_basemul.inc @@ -0,0 +1,40 @@ +/* + * Copyright (c) The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * in the mlkem-native repository. + * Do not modify it directly.
+ */ + +#include <stdint.h> +#include "arith_native_riscv64.h" + +const int16_t roots[] = { + -1044, -1103, -1044, 1103, -1044, 430, -1044, -430, -1044, 555, -1044, + -555, -1044, 843, -1044, -843, -1044, -1251, -1044, 1251, -1044, 871, + -1044, -871, -1044, 1550, -1044, -1550, -1044, 105, -1044, -105, -1044, + 422, -1044, -422, -1044, 587, -1044, -587, -1044, 177, -1044, -177, + -1044, -235, -1044, 235, -1044, -291, -1044, 291, -1044, -460, -1044, + 460, -1044, 1574, -1044, -1574, -1044, 1653, -1044, -1653, -1044, -246, + -1044, 246, -1044, 778, -1044, -778, -1044, 1159, -1044, -1159, -1044, + -147, -1044, 147, -1044, -777, -1044, 777, -1044, 1483, -1044, -1483, + -1044, -602, -1044, 602, -1044, 1119, -1044, -1119, -1044, -1590, -1044, + 1590, -1044, 644, -1044, -644, -1044, -872, -1044, 872, -1044, 349, + -1044, -349, -1044, 418, -1044, -418, -1044, 329, -1044, -329, -1044, + -156, -1044, 156, -1044, -75, -1044, 75, -1044, 817, -1044, -817, + -1044, 1097, -1044, -1097, -1044, 603, -1044, -603, -1044, 610, -1044, + -610, -1044, 1322, -1044, -1322, -1044, -1285, -1044, 1285, -1044, -1465, + -1044, 1465, -1044, 384, -1044, -384, -1044, -1215, -1044, 1215, -1044, + -136, -1044, 136, -1044, 1218, -1044, -1218, -1044, -1335, -1044, 1335, + -1044, -874, -1044, 874, -1044, 220, -1044, -220, -1044, -1187, -1044, + 1187, -1044, -1659, -1044, 1659, -1044, -1185, -1044, 1185, -1044, -1530, + -1044, 1530, -1044, -1278, -1044, 1278, -1044, 794, -1044, -794, -1044, + -1510, -1044, 1510, -1044, -854, -1044, 854, -1044, -870, -1044, 870, + -1044, 478, -1044, -478, -1044, -108, -1044, 108, -1044, -308, -1044, + 308, -1044, 996, -1044, -996, -1044, 991, -1044, -991, -1044, 958, + -1044, -958, -1044, -1460, -1044, 1460, -1044, 1522, -1044, -1522, -1044, + 1628, -1044, -1628, +}; diff --git a/mlkem/src/sys.h b/mlkem/src/sys.h index 90b991ad4..1124a684b 100644 --- a/mlkem/src/sys.h +++ b/mlkem/src/sys.h @@ -48,6 +48,12 @@ #define MLK_SYS_RISCV64 #endif +#if defined(MLK_SYS_RISCV64) && defined(__riscv_vector) && \ + defined(__riscv_v_intrinsic) && \ + (defined(__riscv_v_min_vlen) && __riscv_v_min_vlen >= 256) +#define MLK_SYS_RISCV64_V256 +#endif + #if defined(__riscv) && defined(__riscv_xlen) && __riscv_xlen == 32 #define MLK_SYS_RISCV32 #endif
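The additions to scripts/autogen below compute these twiddles as bit-reversed powers of the root of unity, reduced to signed Montgomery representatives. As an illustration of that convention (a standalone Python sketch assuming the standard ML-KEM parameters q = 3329 and root of unity 17; signed() and bitrev7() are local stand-ins for autogen's signed_reduce/bitreverse, which are defined elsewhere in the script), the first entries of the zeta[] table above can be reproduced as follows:

    Q, ROOT = 3329, 17
    R = 2**16 % Q                  # Montgomery factor

    def signed(x):
        # signed representative in [-(q-1)/2, (q-1)/2]
        x %= Q
        return x - Q if x > Q // 2 else x

    def bitrev7(i):
        # 7-bit bit reversal used for twiddle indexing
        return int(format(i, "07b")[::-1], 2)

    assert signed(R) == -1044                             # zeta[0]: R itself
    assert signed(pow(ROOT, bitrev7(1), Q) * R) == -758   # zeta[1]: layer-0 twiddle
    assert signed(pow(ROOT, bitrev7(16), Q) * R) == 573   # zeta[2]: first layer-4 twiddle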
diff --git a/scripts/autogen b/scripts/autogen index e5895598e..5d5b9063e 100755 --- a/scripts/autogen +++ b/scripts/autogen @@ -1082,6 +1082,141 @@ def gen_avx2_mulcache_twiddles_file(dry_run=False): +def gen_riscv64_root_of_unity_for_block(layer, block, inv=False): + # We are computing a negacyclic NTT; the twiddles needed here are + # the second half of the twiddles for a cyclic NTT of twice the size. + # For ease of calculating the roots, layers are numbered 0 through 6 + # in this function. + log = bitreverse(pow(2, layer) + block, 7) + if inv is True: + log = -log + root = signed_reduce(pow(root_of_unity, log, modulus) * montgomery_factor) + return root + + +def gen_riscv64_zetas(): + zetas_0123 = [ + signed_reduce(montgomery_factor), + gen_riscv64_root_of_unity_for_block(0, 0), + ] + + zetas_0123 += [gen_riscv64_root_of_unity_for_block(1, i) for i in range(2)] + zetas_0123 += [gen_riscv64_root_of_unity_for_block(2, i) for i in range(4)] + zetas_0123 += [gen_riscv64_root_of_unity_for_block(3, i) for i in range(8)] + + for j in range(8): + yield from [zetas_0123[j * 2]] + yield from [zetas_0123[j * 2 + 1]] + for i in range(2): + yield from [gen_riscv64_root_of_unity_for_block(4, j * 2 + i)] + for i in range(4): + yield from [gen_riscv64_root_of_unity_for_block(5, j * 4 + i)] + for i in range(8): + yield from [gen_riscv64_root_of_unity_for_block(6, j * 8 + i)] + + +def gen_riscv64_izetas(): + zetas_0123 = [ + signed_reduce(montgomery_factor), + gen_riscv64_root_of_unity_for_block(0, 0, inv=True), + ] + + zetas_0123 += [ + gen_riscv64_root_of_unity_for_block(1, i, inv=True) for i in range(2) + ] + zetas_0123 += [ + gen_riscv64_root_of_unity_for_block(2, i, inv=True) for i in range(4) + ] + zetas_0123 += [ + gen_riscv64_root_of_unity_for_block(3, i, inv=True) for i in range(8) + ] + + for j in range(8): + yield from [zetas_0123[j * 2]] + yield from [zetas_0123[j * 2 + 1]] + for i in range(2): + yield from [gen_riscv64_root_of_unity_for_block(4, j * 2 + i, inv=True)] + for i in range(4): + yield from [gen_riscv64_root_of_unity_for_block(5, j * 4 + i, inv=True)] + for i in range(8): + yield from [gen_riscv64_root_of_unity_for_block(6, j * 8 + i, inv=True)] + + +def gen_riscv64_basemul_roots(): + R = signed_reduce(montgomery_factor) + for i in range(64): + yield from [ + R, + signed_reduce( + pow(root_of_unity, bitreverse(64 + i, 7), modulus) * montgomery_factor + ), + ] + yield from [ + R, + signed_reduce( + -pow(root_of_unity, bitreverse(64 + i, 7), modulus) * montgomery_factor + ), + ] + + +def gen_riscv64_zeta_files(dry_run=False): + """Generate all RISC-V 64 zeta files""" + + # Generate rv64v_zetas.inc + def gen_zetas(): + yield from gen_header() + yield "#include <stdint.h>" + yield '#include "arith_native_riscv64.h"' + yield "" + yield "const int16_t zeta[] = {" + yield from map(lambda t: str(t) + ",", gen_riscv64_zetas()) + yield "};" + yield "" + + update_file( + "mlkem/src/native/riscv64/src/rv64v_zetas.inc", + "\n".join(gen_zetas()), + dry_run=dry_run, + force_format=True, + ) + + # Generate rv64v_izetas.inc + def gen_izetas(): + yield from gen_header() + yield "#include <stdint.h>" + yield '#include "arith_native_riscv64.h"' + yield "" + yield "const int16_t izeta[] = {" + yield from map(lambda t: str(t) + ",", gen_riscv64_izetas()) + yield "};" + yield "" + + update_file( + "mlkem/src/native/riscv64/src/rv64v_izetas.inc", + "\n".join(gen_izetas()), + dry_run=dry_run, + force_format=True, + ) + + # Generate rv64v_zetas_basemul.inc + def gen_basemul(): + yield from gen_header() + yield "#include <stdint.h>" + yield '#include "arith_native_riscv64.h"' + yield "" + yield "const int16_t roots[] = {" + yield from map(lambda t: str(t) + ",", gen_riscv64_basemul_roots()) + yield "};" + yield "" + + update_file( + "mlkem/src/native/riscv64/src/rv64v_zetas_basemul.inc", + "\n".join(gen_basemul()), + dry_run=dry_run, + force_format=True, + ) + + def get_c_source_files(main_only=False, core_only=False, strip_mlkem=False): if main_only is True: return get_files("mlkem/src/**/*.c", strip_mlkem=strip_mlkem) @@
-1220,6 +1355,10 @@ def x86_64(c): return "/x86_64/" in c +def riscv64(c): + return "/riscv64/" in c + + def native_fips202(c): return native(c) and fips202(c) @@ -1252,9 +1391,16 @@ def native_arith_x86_64(c): return native_arith(c) and x86_64(c) +def native_arith_riscv64(c): + return native_arith(c) and riscv64(c) + + def native_arith_core(c): return ( - native_arith(c) and not native_arith_x86_64(c) and not native_arith_aarch64(c) + native_arith(c) + and not native_arith_x86_64(c) + and not native_arith_aarch64(c) + and not native_arith_riscv64(c) ) @@ -1353,6 +1499,11 @@ def gen_macro_undefs(extra_notes=None): filt=native_arith_x86_64, desc="native code (Arith, X86_64)" ) yield "#endif" + yield "#if defined(MLK_SYS_RISCV64)" + yield from gen_monolithic_undef_all_core( + filt=native_arith_riscv64, desc="native code (Arith, RISC-V 64)" + ) + yield "#endif" yield "#endif" yield "#endif" yield "" @@ -1431,6 +1582,10 @@ def gen_monolithic_source_file(dry_run=False): for c in filter(native_arith_x86_64, c_sources): yield f'#include "{c}"' yield "#endif" + yield "#if defined(MLK_SYS_RISCV64)" + for c in filter(native_arith_riscv64, c_sources): + yield f'#include "{c}"' + yield "#endif" yield "#endif" yield "" yield "#if defined(MLK_CONFIG_USE_NATIVE_BACKEND_FIPS202)" @@ -1515,6 +1670,10 @@ def gen_monolithic_asm_file(dry_run=False): for c in filter(native_arith_x86_64, asm_sources): yield f'#include "{c}"' yield "#endif" + yield "#if defined(MLK_SYS_RISCV64)" + for c in filter(native_arith_riscv64, asm_sources): + yield f'#include "{c}"' + yield "#endif" yield "#endif" yield "" yield "#if defined(MLK_CONFIG_USE_NATIVE_BACKEND_FIPS202)" @@ -2801,6 +2960,7 @@ def _main(): gen_avx2_zeta_file(args.dry_run) gen_avx2_rej_uniform_table(args.dry_run) gen_avx2_mulcache_twiddles_file(args.dry_run) + gen_riscv64_zeta_files(args.dry_run) high_level_status("Generated zeta and lookup tables") if platform.machine().lower() in ["arm64", "aarch64"]: diff --git a/test/mk/auto.mk b/test/mk/auto.mk index bcbf3ac1c..fb2a6b801 100644 --- a/test/mk/auto.mk +++ b/test/mk/auto.mk @@ -96,6 +96,7 @@ endif # aarch64_be # RISC-V 64-bit CFLAGS configuration ifeq ($(ARCH),riscv64) CFLAGS += -DMLK_FORCE_RISCV64 +CFLAGS += -march=rv64gcv_zvl256b endif # riscv64 # RISC-V 32-bit CFLAGS configuration diff --git a/test/mk/components.mk b/test/mk/components.mk index cdcc3eb5d..16ca78e52 100644 --- a/test/mk/components.mk +++ b/test/mk/components.mk @@ -7,7 +7,7 @@ endif SOURCES += $(wildcard mlkem/src/*.c) ifeq ($(OPT),1) - SOURCES += $(wildcard mlkem/src/native/aarch64/src/*.[csS]) $(wildcard mlkem/src/native/x86_64/src/*.[csS]) + SOURCES += $(wildcard mlkem/src/native/aarch64/src/*.[csS]) $(wildcard mlkem/src/native/x86_64/src/*.[csS]) $(wildcard mlkem/src/native/riscv64/src/*.[csS]) CFLAGS += -DMLK_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLK_CONFIG_USE_NATIVE_BACKEND_FIPS202 endif
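Returning to the generated rv64v_zetas_basemul.inc table above: gen_riscv64_basemul_roots interleaves the signed Montgomery factor with the positive and negative basemul twiddles, which is why roots[] alternates -1044 with paired values. A standalone Python sketch (same assumptions as the earlier sketch: q = 3329, root of unity 17) reproducing the first four entries:

    Q, ROOT = 3329, 17
    R = 2**16 % Q

    def signed(x):
        x %= Q
        return x - Q if x > Q // 2 else x

    def bitrev7(i):
        return int(format(i, "07b")[::-1], 2)

    # i = 0 in gen_riscv64_basemul_roots: (R, +zeta) then (R, -zeta), with zeta = ROOT^bitrev7(64) * R
    zeta0 = pow(ROOT, bitrev7(64), Q) * R
    assert [signed(R), signed(zeta0), signed(R), signed(-zeta0)] == [-1044, -1103, -1044, 1103]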