Skip to content

Commit abf8281

Browse files
authored
Merge pull request #510 from pq-code-package/polyz-unpack-asm
Add native implementation of polyz_unpack
2 parents d6f0aff + 40fd58c commit abf8281

File tree

12 files changed

+538
-4
lines changed

12 files changed

+538
-4
lines changed

BIBLIOGRAPHY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ source code and documentation.
156156
- [mldsa/native/x86_64/src/poly_decompose_88_avx2.c](mldsa/native/x86_64/src/poly_decompose_88_avx2.c)
157157
- [mldsa/native/x86_64/src/poly_use_hint_32_avx2.c](mldsa/native/x86_64/src/poly_use_hint_32_avx2.c)
158158
- [mldsa/native/x86_64/src/poly_use_hint_88_avx2.c](mldsa/native/x86_64/src/poly_use_hint_88_avx2.c)
159+
- [mldsa/native/x86_64/src/polyz_unpack_17_avx2.c](mldsa/native/x86_64/src/polyz_unpack_17_avx2.c)
160+
- [mldsa/native/x86_64/src/polyz_unpack_19_avx2.c](mldsa/native/x86_64/src/polyz_unpack_19_avx2.c)
159161
- [mldsa/native/x86_64/src/rej_uniform_avx2.c](mldsa/native/x86_64/src/rej_uniform_avx2.c)
160162
- [mldsa/native/x86_64/src/rej_uniform_eta2_avx2.c](mldsa/native/x86_64/src/rej_uniform_eta2_avx2.c)
161163
- [mldsa/native/x86_64/src/rej_uniform_eta4_avx2.c](mldsa/native/x86_64/src/rej_uniform_eta4_avx2.c)

mldsa/native/aarch64/meta.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#define MLD_USE_NATIVE_POLY_USE_HINT_32
2020
#define MLD_USE_NATIVE_POLY_USE_HINT_88
2121
#define MLD_USE_NATIVE_POLY_CHKNORM
22+
#define MLD_USE_NATIVE_POLYZ_UNPACK_17
23+
#define MLD_USE_NATIVE_POLYZ_UNPACK_19
2224

2325
/* Identifier for this backend so that source and assembly files
2426
* in the build can be appropriately guarded. */
@@ -133,5 +135,17 @@ static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B)
133135
return mld_poly_chknorm_asm(a, B);
134136
}
135137

138+
/* Native polyz_unpack for GAMMA1 = 2^17 on AArch64.
 * Thin wrapper: forwards the output pointer r and the packed input buf to
 * the assembly routine, passing the tbl index table
 * mld_polyz_unpack_17_indices as the third argument. */
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r,
139+
const uint8_t *buf)
140+
{
141+
mld_polyz_unpack_17_asm(r, buf, mld_polyz_unpack_17_indices);
142+
}
143+
144+
/* Native polyz_unpack for GAMMA1 = 2^19 on AArch64.
 * Thin wrapper: forwards the output pointer r and the packed input buf to
 * the assembly routine, passing the tbl index table
 * mld_polyz_unpack_19_indices as the third argument. */
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r,
145+
const uint8_t *buf)
146+
{
147+
mld_polyz_unpack_19_asm(r, buf, mld_polyz_unpack_19_indices);
148+
}
149+
136150
#endif /* !__ASSEMBLER__ */
137151
#endif /* !MLD_NATIVE_AARCH64_META_H */

mldsa/native/aarch64/src/arith_native_aarch64.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ extern const uint8_t mld_rej_uniform_table[];
2929
#define mld_rej_uniform_eta_table MLD_NAMESPACE(rej_uniform_eta_table)
3030
extern const uint8_t mld_rej_uniform_eta_table[];
3131

32+
#define mld_polyz_unpack_17_indices MLD_NAMESPACE(polyz_unpack_17_indices)
33+
extern const uint8_t mld_polyz_unpack_17_indices[];
34+
#define mld_polyz_unpack_19_indices MLD_NAMESPACE(polyz_unpack_19_indices)
35+
extern const uint8_t mld_polyz_unpack_19_indices[];
36+
3237

3338
/*
3439
* Sampling 256 coefficients mod 15 using rejection sampling from 4 bits.
@@ -80,4 +85,12 @@ void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h);
8085
#define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm)
8186
uint32_t mld_poly_chknorm_asm(const int32_t *a, int32_t B);
8287

88+
#define mld_polyz_unpack_17_asm MLD_NAMESPACE(polyz_unpack_17_asm)
89+
void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf,
90+
const uint8_t *indices);
91+
92+
#define mld_polyz_unpack_19_asm MLD_NAMESPACE(polyz_unpack_19_asm)
93+
void mld_polyz_unpack_19_asm(int32_t *r, const uint8_t *buf,
94+
const uint8_t *indices);
95+
8396
#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* Copyright (c) The mlkem-native project authors
4+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
5+
*/
6+
7+
#include "../../../common.h"
8+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
9+
10+
// trim_map_17: in-place post-processing of one vector register \a holding
// four 32-bit lanes, each containing a 24-bit window of the packed input.
// Relies on the register aliases shifts, mask and gamma1 being set up by the
// caller before the macro is expanded.
.macro trim_map_17 a
11+
// Keep only 18 out of 24 bits in each 32-bit lane
12+
// Lane 0 1 2 3
13+
// Bits 0..23 16..39 32..55 48..71
14+
ushl \a\().4s, \a\().4s, shifts.4s
15+
// Bits 0..23 18..39 36..55 54..71
16+
and \a\().16b, \a\().16b, mask.16b
17+
// Bits 0..17 18..35 36..53 54..71
18+
19+
// Map [0, 1, ..., 2^18-1] to [2^17, 2^17-1, ..., -2^17+1]
20+
sub \a\().4s, gamma1.4s, \a\().4s
21+
.endm
22+
23+
/* Parameters */
24+
output .req x0
25+
buf .req x1
26+
indices .req x2
27+
28+
xtmp .req x3
29+
count .req x9
30+
31+
/* Constant register assignments */
32+
idx0 .req v24
33+
idx1 .req v25
34+
idx2 .req v26
35+
idx3 .req v27
36+
shifts .req v28
37+
mask .req v29 // 2^18 - 1
38+
gamma1 .req v30 // 2^17
39+
40+
.text
41+
.global MLD_ASM_NAMESPACE(polyz_unpack_17_asm)
42+
.balign 4
43+
// polyz_unpack_17_asm(output, buf, indices)
//
// Unpacks 256 coefficients of 18 bits each from buf into 32-bit lanes at
// output, mapping each raw value v to 2^17 - v.
//
// - output:  written with 16 iterations x 4 stores x 16 bytes = 1024 bytes
// - buf:     read 36 bytes per iteration, 16 iterations = 576 bytes;
//            no byte beyond buf[575] is accessed
// - indices: 64 bytes, four 16-byte tbl index vectors
MLD_ASM_FN_SYMBOL(polyz_unpack_17_asm)
44+
// Load indices
45+
ldr q24, [indices]
46+
ldr q25, [indices, #1*16]
47+
ldr q26, [indices, #2*16]
48+
ldr q27, [indices, #3*16]
49+
50+
// Load per-lane shifts 0, -2, -4, -6. (Negative means right shift.)
51+
// The shifts for the 4 32-bit lanes are sign-extended from the lowest
52+
// 8 bits, so it suffices to set up only byte 0, 4, 8, 12.
53+
movz xtmp, 0xfe, lsl 32
54+
mov shifts.d[0], xtmp
55+
movz xtmp, 0xfc
56+
movk xtmp, 0xfa, lsl 32
57+
mov shifts.d[1], xtmp
58+
59+
movi mask.4s, 0x3, msl 16
60+
61+
movi gamma1.4s, 0x2, lsl 16
62+
63+
mov count, #(64/4)
64+
65+
polyz_unpack_17_loop:
66+
ldr q1, [buf, #16]
67+
// Each iteration consumes 36 bytes, and idx3 references at most byte 19
// of the {v1, v2} table, i.e. buf[35]. Load only those 4 bytes (the scalar
// load zeroes the rest of v2): a full 16-byte load here would read 12 bytes
// past the end of the 576-byte input on the final iteration.
ldr s2, [buf, #32]
68+
ldr q0, [buf], #36
69+
70+
tbl v4.16b, {v0.16b}, idx0.16b
71+
tbl v5.16b, {v0.16b - v1.16b}, idx1.16b
72+
tbl v6.16b, {v1.16b}, idx2.16b
73+
tbl v7.16b, {v1.16b - v2.16b}, idx3.16b
74+
75+
trim_map_17 v4
76+
trim_map_17 v5
77+
trim_map_17 v6
78+
trim_map_17 v7
79+
80+
str q5, [output, #1*16]
81+
str q6, [output, #2*16]
82+
str q7, [output, #3*16]
83+
str q4, [output], #4*16
84+
85+
subs count, count, #1
86+
bne polyz_unpack_17_loop
87+
88+
ret
89+
90+
.unreq output
91+
.unreq buf
92+
.unreq indices
93+
.unreq xtmp
94+
.unreq count
95+
.unreq idx0
96+
.unreq idx1
97+
.unreq idx2
98+
.unreq idx3
99+
.unreq shifts
100+
.unreq mask
101+
.unreq gamma1
102+
103+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* Copyright (c) The mlkem-native project authors
4+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
5+
*/
6+
7+
#include "../../../common.h"
8+
#if defined(MLD_ARITH_BACKEND_AARCH64) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
9+
10+
// trim_map_19: in-place post-processing of one vector register \a holding
// four 32-bit lanes, each containing a 24-bit window of the packed input.
// Relies on the register aliases shifts, mask and gamma1 being set up by the
// caller before the macro is expanded.
.macro trim_map_19 a
11+
// Keep only 20 out of 24 bits in each 32-bit lane
12+
// Lane 0 1 2 3
13+
// Bits 0..23 16..39 40..63 56..79
14+
ushl \a\().4s, \a\().4s, shifts.4s
15+
// Bits 0..23 20..39 40..63 60..79
16+
and \a\().16b, \a\().16b, mask.16b
17+
// Bits 0..19 20..39 40..59 60..79
18+
19+
// Map [0, 1, ..., 2^20-1] to [2^19, 2^19-1, ..., -2^19+1]
20+
sub \a\().4s, gamma1.4s, \a\().4s
21+
.endm
22+
23+
/* Parameters */
24+
output .req x0
25+
buf .req x1
26+
indices .req x2
27+
28+
xtmp .req x3
29+
count .req x9
30+
31+
/* Constant register assignments */
32+
idx0 .req v24
33+
idx1 .req v25
34+
idx2 .req v26
35+
idx3 .req v27
36+
shifts .req v28
37+
mask .req v29 // 2^20 - 1
38+
gamma1 .req v30 // 2^19
39+
40+
.text
41+
.global MLD_ASM_NAMESPACE(polyz_unpack_19_asm)
42+
.balign 4
43+
// polyz_unpack_19_asm(output, buf, indices)
//
// Unpacks 256 coefficients of 20 bits each from buf into 32-bit lanes at
// output, mapping each raw value v to 2^19 - v.
//
// - output:  written with 16 iterations x 4 stores x 16 bytes = 1024 bytes
// - buf:     read 40 bytes per iteration, 16 iterations = 640 bytes;
//            no byte beyond buf[639] is accessed
// - indices: 64 bytes, four 16-byte tbl index vectors
MLD_ASM_FN_SYMBOL(polyz_unpack_19_asm)
44+
// Load indices
45+
ldr q24, [indices]
46+
ldr q25, [indices, #1*16]
47+
ldr q26, [indices, #2*16]
48+
ldr q27, [indices, #3*16]
49+
50+
// Load per-lane shifts 0, -4, 0, -4. (Negative means right shift.)
51+
// The shifts for the 4 32-bit lanes are sign-extended from the lowest
52+
// 8 bits, so it suffices to set up only byte 0, 4, 8, 12.
53+
movz xtmp, 0xfc, lsl 32
54+
dup shifts.2d, xtmp
55+
56+
movi mask.4s, 0xf, msl 16
57+
58+
movi gamma1.4s, 0x8, lsl 16
59+
60+
mov count, #(64/4)
61+
62+
polyz_unpack_19_loop:
63+
ldr q1, [buf, #16]
64+
// Each iteration consumes 40 bytes, and idx3 references at most byte 23
// of the {v1, v2} table, i.e. buf[39]. Load only those 8 bytes (the scalar
// load zeroes the rest of v2): a full 16-byte load here would read 8 bytes
// past the end of the 640-byte input on the final iteration.
ldr d2, [buf, #32]
65+
ldr q0, [buf], #40
66+
67+
tbl v4.16b, {v0.16b}, idx0.16b
68+
tbl v5.16b, {v0.16b - v1.16b}, idx1.16b
69+
tbl v6.16b, {v1.16b}, idx2.16b
70+
tbl v7.16b, {v1.16b - v2.16b}, idx3.16b
71+
72+
trim_map_19 v4
73+
trim_map_19 v5
74+
trim_map_19 v6
75+
trim_map_19 v7
76+
77+
str q5, [output, #1*16]
78+
str q6, [output, #2*16]
79+
str q7, [output, #3*16]
80+
str q4, [output], #4*16
81+
82+
subs count, count, #1
83+
bne polyz_unpack_19_loop
84+
85+
ret
86+
87+
.unreq output
88+
.unreq buf
89+
.unreq indices
90+
.unreq xtmp
91+
.unreq count
92+
.unreq idx0
93+
.unreq idx1
94+
.unreq idx2
95+
.unreq idx3
96+
.unreq shifts
97+
.unreq mask
98+
.unreq gamma1
99+
100+
#endif /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) The mldsa-native project authors
3+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
4+
*/
5+
6+
#include "../../../common.h"
7+
8+
#if defined(MLD_ARITH_BACKEND_AARCH64) && \
9+
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
10+
11+
#include <stdint.h>
12+
#include "arith_native_aarch64.h"
13+
14+
/* Table of indices used for tbl instructions in polyz_unpack_{17,19}. */
15+
16+
/* tbl byte indices for polyz_unpack_17: four rows of 16 bytes, one row per
 * index vector (idx0..idx3) loaded by the assembly routine. Each group of
 * four entries selects the 3 source bytes holding one 18-bit coefficient,
 * followed by 255. The value 255 is out of range for both the one- and
 * two-register tbl forms used, so tbl writes 0 to the top byte of each
 * 32-bit lane. It is written as 255 rather than -1 so the initializer does
 * not rely on an implicit negative-to-unsigned conversion. */
MLD_ALIGN const uint8_t mld_polyz_unpack_17_indices[] = {
17+
0, 1, 2, 255, 2, 3, 4, 255, 4, 5, 6, 255, 6, 7, 8, 255,
18+
9, 10, 11, 255, 11, 12, 13, 255, 13, 14, 15, 255, 15, 16, 17, 255,
19+
2, 3, 4, 255, 4, 5, 6, 255, 6, 7, 8, 255, 8, 9, 10, 255,
20+
11, 12, 13, 255, 13, 14, 15, 255, 15, 16, 17, 255, 17, 18, 19, 255,
21+
};
22+
23+
/* tbl byte indices for polyz_unpack_19: four rows of 16 bytes, one row per
 * index vector (idx0..idx3) loaded by the assembly routine. Each group of
 * four entries selects the 3 source bytes holding one 20-bit coefficient,
 * followed by 255. The value 255 is out of range for both the one- and
 * two-register tbl forms used, so tbl writes 0 to the top byte of each
 * 32-bit lane. It is written as 255 rather than -1 so the initializer does
 * not rely on an implicit negative-to-unsigned conversion. */
MLD_ALIGN const uint8_t mld_polyz_unpack_19_indices[] = {
24+
0, 1, 2, 255, 2, 3, 4, 255, 5, 6, 7, 255, 7, 8, 9, 255,
25+
10, 11, 12, 255, 12, 13, 14, 255, 15, 16, 17, 255, 17, 18, 19, 255,
26+
4, 5, 6, 255, 6, 7, 8, 255, 9, 10, 11, 255, 11, 12, 13, 255,
27+
14, 15, 16, 255, 16, 17, 18, 255, 19, 20, 21, 255, 21, 22, 23, 255,
28+
};
29+
30+
#else /* MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
31+
32+
MLD_EMPTY_CU(aarch64_polyz_unpack_table)
33+
34+
#endif /* !(MLD_ARITH_BACKEND_AARCH64 && !MLD_CONFIG_MULTILEVEL_NO_SHARED) */

mldsa/native/api.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,32 @@ static MLD_INLINE void mld_poly_use_hint_88_native(int32_t *b, const int32_t *a,
270270
static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B);
271271
#endif /* MLD_USE_NATIVE_POLY_CHKNORM */
272272

273+
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_17)
274+
/*************************************************
275+
* Name: mld_polyz_unpack_17_native
276+
*
277+
* Description: Native implementation of polyz_unpack for GAMMA1 = 2^17.
278+
* Unpack polynomial z with coefficients
279+
* in [-(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1].
280+
*
281+
* Arguments: - int32_t *r: pointer to output polynomial
282+
* - const uint8_t *a: byte array with bit-packed polynomial
283+
**************************************************/
284+
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a);
285+
#endif /* MLD_USE_NATIVE_POLYZ_UNPACK_17 */
286+
287+
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_19)
288+
/*************************************************
289+
* Name: mld_polyz_unpack_19_native
290+
*
291+
* Description: Native implementation of polyz_unpack for GAMMA1 = 2^19.
292+
* Unpack polynomial z with coefficients
293+
* in [-(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1].
294+
*
295+
* Arguments: - int32_t *r: pointer to output polynomial
296+
* - const uint8_t *a: byte array with bit-packed polynomial
297+
**************************************************/
298+
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a);
299+
#endif /* MLD_USE_NATIVE_POLYZ_UNPACK_19 */
300+
273301
#endif /* !MLD_NATIVE_API_H */

mldsa/native/x86_64/meta.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#define MLD_USE_NATIVE_POLY_USE_HINT_32
2424
#define MLD_USE_NATIVE_POLY_USE_HINT_88
2525
#define MLD_USE_NATIVE_POLY_CHKNORM
26+
#define MLD_USE_NATIVE_POLYZ_UNPACK_17
27+
#define MLD_USE_NATIVE_POLYZ_UNPACK_19
2628

2729
#if !defined(__ASSEMBLER__)
2830
#include <string.h>
@@ -139,6 +141,16 @@ static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B)
139141
return mld_poly_chknorm_avx2((const __m256i *)a, B);
140142
}
141143

144+
/* Native polyz_unpack for GAMMA1 = 2^17 on x86_64.
 * Thin wrapper: forwards to the AVX2 implementation, viewing the output
 * coefficient array r as an array of __m256i.
 * NOTE(review): the cast presumes r is suitably aligned for whatever
 * __m256i accesses the AVX2 routine performs - confirm against its source. */
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a)
145+
{
146+
mld_polyz_unpack_17_avx2((__m256i *)r, a);
147+
}
148+
149+
/* Native polyz_unpack for GAMMA1 = 2^19 on x86_64.
 * Thin wrapper: forwards to the AVX2 implementation, viewing the output
 * coefficient array r as an array of __m256i.
 * NOTE(review): the cast presumes r is suitably aligned for whatever
 * __m256i accesses the AVX2 routine performs - confirm against its source. */
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
150+
{
151+
mld_polyz_unpack_19_avx2((__m256i *)r, a);
152+
}
153+
142154
#endif /* !__ASSEMBLER__ */
143155

144156
#endif /* !MLD_NATIVE_X86_64_META_H */

mldsa/native/x86_64/src/arith_native_x86_64.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,10 @@ void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, const __m256i *h);
7272
#define mld_poly_chknorm_avx2 MLD_NAMESPACE(mld_poly_chknorm_avx2)
7373
uint32_t mld_poly_chknorm_avx2(const __m256i *a, int32_t B);
7474

75+
#define mld_polyz_unpack_17_avx2 MLD_NAMESPACE(mld_polyz_unpack_17_avx2)
76+
void mld_polyz_unpack_17_avx2(__m256i *r, const uint8_t *a);
77+
78+
#define mld_polyz_unpack_19_avx2 MLD_NAMESPACE(mld_polyz_unpack_19_avx2)
79+
void mld_polyz_unpack_19_avx2(__m256i *r, const uint8_t *a);
80+
7581
#endif /* !MLD_NATIVE_X86_64_SRC_ARITH_NATIVE_X86_64_H */

0 commit comments

Comments
 (0)