Skip to content

Commit f2df58e

Browse files
committed
AVX2: Add native implementation of polyz_unpack
This adds the AVX2 intrinsics implementation of polyz_unpack from https://github.com/pq-crystals/dilithium/blob/master/avx2/poly.c. Signed-off-by: jammychiou1 <[email protected]>
1 parent d6f0aff commit f2df58e

File tree

7 files changed

+274
-4
lines changed

7 files changed

+274
-4
lines changed

BIBLIOGRAPHY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ source code and documentation.
156156
- [mldsa/native/x86_64/src/poly_decompose_88_avx2.c](mldsa/native/x86_64/src/poly_decompose_88_avx2.c)
157157
- [mldsa/native/x86_64/src/poly_use_hint_32_avx2.c](mldsa/native/x86_64/src/poly_use_hint_32_avx2.c)
158158
- [mldsa/native/x86_64/src/poly_use_hint_88_avx2.c](mldsa/native/x86_64/src/poly_use_hint_88_avx2.c)
159+
- [mldsa/native/x86_64/src/polyz_unpack_17_avx2.c](mldsa/native/x86_64/src/polyz_unpack_17_avx2.c)
160+
- [mldsa/native/x86_64/src/polyz_unpack_19_avx2.c](mldsa/native/x86_64/src/polyz_unpack_19_avx2.c)
159161
- [mldsa/native/x86_64/src/rej_uniform_avx2.c](mldsa/native/x86_64/src/rej_uniform_avx2.c)
160162
- [mldsa/native/x86_64/src/rej_uniform_eta2_avx2.c](mldsa/native/x86_64/src/rej_uniform_eta2_avx2.c)
161163
- [mldsa/native/x86_64/src/rej_uniform_eta4_avx2.c](mldsa/native/x86_64/src/rej_uniform_eta4_avx2.c)

mldsa/native/api.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,32 @@ static MLD_INLINE void mld_poly_use_hint_88_native(int32_t *b, const int32_t *a,
270270
static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B);
271271
#endif /* MLD_USE_NATIVE_POLY_CHKNORM */
272272

273+
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_17)
274+
/*************************************************
275+
* Name: mld_polyz_unpack_17_native
276+
*
277+
* Description: Native implementation of polyz_unpack for GAMMA1 = 2^17.
278+
* Unpack polynomial z with coefficients
279+
* in [-(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1].
280+
*
281+
* Arguments: - int32_t *r: pointer to output polynomial
282+
* - const uint8_t *a: byte array with bit-packed polynomial
283+
**************************************************/
284+
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a);
285+
#endif /* MLD_USE_NATIVE_POLYZ_UNPACK_17 */
286+
287+
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_19)
288+
/*************************************************
289+
* Name: mld_polyz_unpack_19_native
290+
*
291+
* Description: Native implementation of polyz_unpack for GAMMA1 = 2^19.
292+
* Unpack polynomial z with coefficients
293+
* in [-(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1].
294+
*
295+
* Arguments: - int32_t *r: pointer to output polynomial
296+
* - const uint8_t *a: byte array with bit-packed polynomial
297+
**************************************************/
298+
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a);
299+
#endif /* MLD_USE_NATIVE_POLYZ_UNPACK_19 */
300+
273301
#endif /* !MLD_NATIVE_API_H */

mldsa/native/x86_64/meta.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#define MLD_USE_NATIVE_POLY_USE_HINT_32
2424
#define MLD_USE_NATIVE_POLY_USE_HINT_88
2525
#define MLD_USE_NATIVE_POLY_CHKNORM
26+
#define MLD_USE_NATIVE_POLYZ_UNPACK_17
27+
#define MLD_USE_NATIVE_POLYZ_UNPACK_19
2628

2729
#if !defined(__ASSEMBLER__)
2830
#include <string.h>
@@ -139,6 +141,16 @@ static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B)
139141
return mld_poly_chknorm_avx2((const __m256i *)a, B);
140142
}
141143

144+
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a)
145+
{
146+
mld_polyz_unpack_17_avx2((__m256i *)r, a);
147+
}
148+
149+
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
150+
{
151+
mld_polyz_unpack_19_avx2((__m256i *)r, a);
152+
}
153+
142154
#endif /* !__ASSEMBLER__ */
143155

144156
#endif /* !MLD_NATIVE_X86_64_META_H */

mldsa/native/x86_64/src/arith_native_x86_64.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,10 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, const __m256i *h);
7272
#define mld_poly_chknorm_avx2 MLD_NAMESPACE(mld_poly_chknorm_avx2)
7373
uint32_t mld_poly_chknorm_avx2(const __m256i *a, int32_t B);
7474

75+
#define mld_polyz_unpack_17_avx2 MLD_NAMESPACE(mld_polyz_unpack_17_avx2)
76+
void mld_polyz_unpack_17_avx2(__m256i *r, const uint8_t *a);
77+
78+
#define mld_polyz_unpack_19_avx2 MLD_NAMESPACE(mld_polyz_unpack_19_avx2)
79+
void mld_polyz_unpack_19_avx2(__m256i *r, const uint8_t *a);
80+
7581
#endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
/* References
7+
* ==========
8+
*
9+
* - [REF_AVX2]
10+
* CRYSTALS-Dilithium optimized AVX2 implementation
11+
* Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
12+
* https://github.com/pq-crystals/dilithium/tree/master/avx2
13+
*/
14+
15+
/*
16+
* This file is derived from the public domain
17+
* AVX2 Dilithium implementation @[REF_AVX2].
18+
*/
19+
20+
#include "../../../common.h"
21+
22+
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
23+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
24+
25+
#include <immintrin.h>
26+
#include <stdint.h>
27+
#include "arith_native_x86_64.h"
28+
29+
void mld_polyz_unpack_17_avx2(__m256i *r, const uint8_t *a)
30+
{
31+
unsigned int i;
32+
__m256i f;
33+
const __m256i shufbidx =
34+
_mm256_set_epi8(-1, 9, 8, 7, -1, 7, 6, 5, -1, 5, 4, 3, -1, 3, 2, 1, -1, 8,
35+
7, 6, -1, 6, 5, 4, -1, 4, 3, 2, -1, 2, 1, 0);
36+
const __m256i srlvdidx = _mm256_set_epi32(6, 4, 2, 0, 6, 4, 2, 0);
37+
const __m256i mask = _mm256_set1_epi32(0x3FFFF);
38+
const __m256i gamma1 = _mm256_set1_epi32(MLDSA_GAMMA1);
39+
40+
for (i = 0; i < MLDSA_N / 8; i++)
41+
{
42+
f = _mm256_loadu_si256((__m256i *)&a[18 * i]);
43+
44+
/* Permute 64-bit lanes
45+
* 0x94 = 10010100b rearranges 64-bit lanes as: [3,2,1,0] -> [2,1,1,0]
46+
*
47+
* ╔═══════════════════════════════════════════════════════════════════════╗
48+
* ║ Original Layout ║
49+
* ╚═══════════════════════════════════════════════════════════════════════╝
50+
* ┌─────────────────┬─────────────────┬─────────────────┬─────────────────┐
51+
* │ Lane 0 │ Lane 1 │ Lane 2 │ Lane 3 │
52+
* │ bytes 0..7 │ bytes 8..15 │ bytes 16..23 │ bytes 24..31 │
53+
* └─────────────────┴─────────────────┴─────────────────┴─────────────────┘
54+
*
55+
* ╔═══════════════════════════════════════════════════════════════════════╗
56+
* ║ Layout after permute ║
57+
* ║ Byte indices in high half shifted down by 8 positions ║
58+
* ╚═══════════════════════════════════════════════════════════════════════╝
59+
* ┌───────────────┬─────────────────┐ ┌─────────────────┬─────────────────┐
60+
* │ Lane 0 │ Lane 1 │ │ Lane 2 │ Lane 3 │
61+
* │ bytes 0..7 │ bytes 8..15 │ │ bytes 8..15 │ bytes 16..23 │
62+
* └───────────────┴─────────────────┘ └─────────────────┴─────────────────┘
63+
* Lower 128-bit lane (bytes 0-15) Upper 128-bit lane (bytes 16-31)
64+
*/
65+
f = _mm256_permute4x64_epi64(f, 0x94);
66+
67+
/* Shuffling 8-bit lanes
68+
*
69+
* ┌─ Indices 0-8 into low 128-bit half of permuted vector ────────────────┐
70+
* │ Shuffle: [-1, 8, 7, 6, -1, 6, 5, 4, -1, 4, 3, 2, -1, 2, 1, 0] │
71+
* │ Result: [0, byte8, byte7, byte6, ..., 0, byte2, byte1, byte0] │
72+
* └───────────────────────────────────────────────────────────────────────┘
73+
*
74+
* ┌─ Indices 1-9 into high 128-bit half of permuted vector ───────────────┐
75+
* │ Shuffle: [-1, 9, 8, 7, -1, 7, 6, 5, -1, 5, 4, 3, -1, 3, 2, 1] │
76+
* │ Result: [0, byte17, byte16, byte15, ..., 0, byte11, byte10, byte9] │
77+
* └───────────────────────────────────────────────────────────────────────┘
78+
*/
79+
f = _mm256_shuffle_epi8(f, shufbidx);
80+
81+
/* Keep only 18 out of 24 bits in each 32-bit lane */
82+
/* Bits 0..23 16..39 32..55 48..71
83+
* 72..95 88..111 104..127 120..143 */
84+
f = _mm256_srlv_epi32(f, srlvdidx);
85+
/* Bits 0..23 18..39 36..55 54..71
86+
* 72..95 90..111 108..127 126..143 */
87+
f = _mm256_and_si256(f, mask);
88+
/* Bits 0..17 18..35 36..53 54..71
89+
* 72..89 90..107 108..125 126..143 */
90+
91+
/* Map [0, 1, ..., 2^18-1] to [2^17, 2^17-1, ..., -2^17+1] */
92+
f = _mm256_sub_epi32(gamma1, f);
93+
94+
_mm256_store_si256(&r[i], f);
95+
}
96+
}
97+
98+
#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
99+
*/
100+
101+
MLD_EMPTY_CU(avx2_polyz_unpack)
102+
103+
#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && \
104+
!MLD_CONFIG_MULTILEVEL_NO_SHARED) */
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
/* References
7+
* ==========
8+
*
9+
* - [REF_AVX2]
10+
* CRYSTALS-Dilithium optimized AVX2 implementation
11+
* Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
12+
* https://github.com/pq-crystals/dilithium/tree/master/avx2
13+
*/
14+
15+
/*
16+
* This file is derived from the public domain
17+
* AVX2 Dilithium implementation @[REF_AVX2].
18+
*/
19+
20+
#include "../../../common.h"
21+
22+
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
23+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
24+
25+
#include <immintrin.h>
26+
#include <stdint.h>
27+
#include "arith_native_x86_64.h"
28+
29+
void mld_polyz_unpack_19_avx2(__m256i *r, const uint8_t *a)
30+
{
31+
unsigned int i;
32+
__m256i f;
33+
const __m256i shufbidx =
34+
_mm256_set_epi8(-1, 11, 10, 9, -1, 9, 8, 7, -1, 6, 5, 4, -1, 4, 3, 2, -1,
35+
9, 8, 7, -1, 7, 6, 5, -1, 4, 3, 2, -1, 2, 1, 0);
36+
/* Equivalent to _mm256_set_epi32(4, 0, 4, 0, 4, 0, 4, 0) */
37+
const __m256i srlvdidx = _mm256_set1_epi64x((uint64_t)4 << 32);
38+
const __m256i mask = _mm256_set1_epi32(0xFFFFF);
39+
const __m256i gamma1 = _mm256_set1_epi32(MLDSA_GAMMA1);
40+
41+
for (i = 0; i < MLDSA_N / 8; i++)
42+
{
43+
f = _mm256_loadu_si256((__m256i *)&a[20 * i]);
44+
45+
/* Permute 64-bit lanes
46+
* 0x94 = 10010100b rearranges 64-bit lanes as: [3,2,1,0] -> [2,1,1,0]
47+
*
48+
* ╔═══════════════════════════════════════════════════════════════════════╗
49+
* ║ Original Layout ║
50+
* ╚═══════════════════════════════════════════════════════════════════════╝
51+
* ┌─────────────────┬─────────────────┬─────────────────┬─────────────────┐
52+
* │ Lane 0 │ Lane 1 │ Lane 2 │ Lane 3 │
53+
* │ bytes 0..7 │ bytes 8..15 │ bytes 16..23 │ bytes 24..31 │
54+
* └─────────────────┴─────────────────┴─────────────────┴─────────────────┘
55+
*
56+
* ╔═══════════════════════════════════════════════════════════════════════╗
57+
* ║ Layout after permute ║
58+
* ║ Byte indices in high half shifted down by 8 positions ║
59+
* ╚═══════════════════════════════════════════════════════════════════════╝
60+
* ┌───────────────┬─────────────────┐ ┌─────────────────┬─────────────────┐
61+
* │ Lane 0 │ Lane 1 │ │ Lane 2 │ Lane 3 │
62+
* │ bytes 0..7 │ bytes 8..15 │ │ bytes 8..15 │ bytes 16..23 │
63+
* └───────────────┴─────────────────┘ └─────────────────┴─────────────────┘
64+
* Lower 128-bit lane (bytes 0-15) Upper 128-bit lane (bytes 16-31)
65+
*/
66+
f = _mm256_permute4x64_epi64(f, 0x94);
67+
68+
/* Shuffling 8-bit lanes
69+
*
70+
* ┌─ Indices 0-9 into low 128-bit half of permuted vector ────────────────┐
71+
* │ Shuffle: [-1, 9, 8, 7, -1, 7, 6, 5, -1, 4, 3, 2, -1, 2, 1, 0] │
72+
* │ Result: [0, byte9, byte8, byte7, ..., 0, byte2, byte1, byte0] │
73+
* └───────────────────────────────────────────────────────────────────────┘
74+
*
75+
* ┌─ Indices 2-11 into high 128-bit half of permuted vector ──────────────┐
76+
* │ Shuffle: [-1, 11, 10, 9, -1, 9, 8, 7, -1, 6, 5, 4, -1, 4, 3, 2]        │
77+
* │ Result: [0, byte19, byte18, byte17, ..., 0, byte12, byte11, byte10] │
78+
* └───────────────────────────────────────────────────────────────────────┘
79+
*/
80+
f = _mm256_shuffle_epi8(f, shufbidx);
81+
82+
/* Keep only 20 out of 24 bits in each 32-bit lane */
83+
/* Bits 0..23 16..39 40..63 56..79
84+
* 80..103 96..119 120..143 136..159 */
85+
f = _mm256_srlv_epi32(f, srlvdidx);
86+
/* Bits 0..23 20..39 40..63 60..79
87+
* 80..103 100..119 120..143 140..159 */
88+
f = _mm256_and_si256(f, mask);
89+
/* Bits 0..19 20..39 40..59 60..79
90+
* 80..99 100..119 120..139 140..159 */
91+
92+
/* Map [0, 1, ..., 2^20-1] to [2^19, 2^19-1, ..., -2^19+1] */
93+
f = _mm256_sub_epi32(gamma1, f);
94+
95+
_mm256_store_si256(&r[i], f);
96+
}
97+
}
98+
99+
#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
100+
*/
101+
102+
MLD_EMPTY_CU(avx2_polyz_unpack)
103+
104+
#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && \
105+
!MLD_CONFIG_MULTILEVEL_NO_SHARED) */

mldsa/poly_kl.c

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,15 @@ void mld_polyz_pack(uint8_t *r, const mld_poly *a)
702702
MLD_INTERNAL_API
703703
void mld_polyz_unpack(mld_poly *r, const uint8_t *a)
704704
{
705+
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_17) && MLD_CONFIG_PARAMETER_SET == 44
706+
/* TODO: proof */
707+
mld_polyz_unpack_17_native(r->coeffs, a);
708+
#elif defined(MLD_USE_NATIVE_POLYZ_UNPACK_19) && \
709+
(MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)
710+
/* TODO: proof */
711+
mld_polyz_unpack_19_native(r->coeffs, a);
712+
#elif MLD_CONFIG_PARAMETER_SET == 44
705713
unsigned int i;
706-
707-
#if MLD_CONFIG_PARAMETER_SET == 44
708714
for (i = 0; i < MLDSA_N / 4; ++i)
709715
__loop__(
710716
invariant(i <= MLDSA_N/4)
@@ -735,7 +741,11 @@ void mld_polyz_unpack(mld_poly *r, const uint8_t *a)
735741
r->coeffs[4 * i + 2] = MLDSA_GAMMA1 - r->coeffs[4 * i + 2];
736742
r->coeffs[4 * i + 3] = MLDSA_GAMMA1 - r->coeffs[4 * i + 3];
737743
}
738-
#else /* MLD_CONFIG_PARAMETER_SET == 44 */
744+
#else /* !(MLD_USE_NATIVE_POLYZ_UNPACK_17 && MLD_CONFIG_PARAMETER_SET == 44) \
745+
&& !(MLD_USE_NATIVE_POLYZ_UNPACK_19 && (MLD_CONFIG_PARAMETER_SET == \
746+
65 || MLD_CONFIG_PARAMETER_SET == 87)) && MLD_CONFIG_PARAMETER_SET == \
747+
44 */
748+
unsigned int i;
739749
for (i = 0; i < MLDSA_N / 2; ++i)
740750
__loop__(
741751
invariant(i <= MLDSA_N/2)
@@ -755,7 +765,10 @@ void mld_polyz_unpack(mld_poly *r, const uint8_t *a)
755765
r->coeffs[2 * i + 0] = MLDSA_GAMMA1 - r->coeffs[2 * i + 0];
756766
r->coeffs[2 * i + 1] = MLDSA_GAMMA1 - r->coeffs[2 * i + 1];
757767
}
758-
#endif /* MLD_CONFIG_PARAMETER_SET != 44 */
768+
#endif /* !(MLD_USE_NATIVE_POLYZ_UNPACK_17 && MLD_CONFIG_PARAMETER_SET == 44) \
769+
&& !(MLD_USE_NATIVE_POLYZ_UNPACK_19 && (MLD_CONFIG_PARAMETER_SET == \
770+
65 || MLD_CONFIG_PARAMETER_SET == 87)) && MLD_CONFIG_PARAMETER_SET \
771+
!= 44 */
759772

760773
mld_assert_bound(r->coeffs, MLDSA_N, -(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1 + 1);
761774
}

0 commit comments

Comments
 (0)