Skip to content

Commit 5c3b570

Browse files
committed
improve experimental montgomery 2^k-ary pow
1 parent 7161f04 commit 5c3b570

File tree

57 files changed

+92315
-28409
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+92315
-28409
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_pow_2kary/experimental_montgomery_pow_2kary.h

Lines changed: 536 additions & 334 deletions
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
2+
// -------
3+
// This file is intended to be included multiple times, using different macro definitions
4+
// -------
5+
6+
// The macros to set:
7+
8+
// I was not using this
9+
// HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLESIZE_INIT
10+
// Set this as blank, or to HURCHALLA_REQUEST_UNROLL_LOOP
11+
12+
// I was using this
13+
// HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE
14+
// Set this as blank, or to HURCHALLA_REQUEST_UNROLL_LOOP
15+
16+
// For 128-bit I was NOT unrolling on NUM_TABLES in table init, but I was for 64-bit.
17+
// HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_NUM_TABLES_INIT
18+
// Set this as blank, or to HURCHALLA_REQUEST_UNROLL_LOOP
19+
20+
// I was using this
21+
// HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS
22+
// Set this as blank, or to HURCHALLA_REQUEST_UNROLL_LOOP
23+
24+
// I was not using this
25+
// HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_NUM_TABLES_MAINLOOP
26+
// Set this as blank, or to HURCHALLA_REQUEST_UNROLL_LOOP
27+
28+
29+
30+
// For now, I'm not going to utilize these macros, but I could:
31+
// I was using this
32+
// HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS_ENDING
33+
// Set this as blank, or to HURCHALLA_REQUEST_UNROLL_LOOP
34+
35+
36+
37+
38+
39+
std::array<std::array<std::array<V, ARRAY_SIZE>, TABLESIZE>, NUM_TABLES> table;
40+
41+
V mont_one = mf.getUnityValue();
42+
43+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q) {
44+
table[0][0][q] = mont_one;
45+
table[0][1][q] = x[q];
46+
}
47+
if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE >= 4) {
48+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
49+
table[0][2][q] = mf.square(x[q]);
50+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
51+
table[0][3][q] = mf.template multiply<PTAG>(table[0][2][q], x[q]);
52+
}
53+
if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE > 4) {
54+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLESIZE_INIT for (size_t i=4; i<TABLESIZE; i+=2) {
55+
size_t j = i/2;
56+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
57+
table[0][i][q] = mf.template square<LowuopsTag>(table[0][j][q]);
58+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
59+
table[0][i+1][q] = mf.template multiply<LowuopsTag>(table[0][j+1][q], table[0][j][q]);
60+
}
61+
}
62+
63+
64+
U n_orig = n;
65+
(void)n_orig; // silence potential unused var warnings
66+
int shift;
67+
size_t tmp;
68+
if (n > MASKBIG) {
69+
HPBC_CLOCKWORK_ASSERT2(n > 0);
70+
int leading_zeros = count_leading_zeros(n);
71+
int numbits = ut_numeric_limits<decltype(n)>::digits - leading_zeros;
72+
HPBC_CLOCKWORK_ASSERT2(numbits > NUMBITS_MASKBIG);
73+
shift = numbits - NUMBITS_MASKBIG;
74+
HPBC_CLOCKWORK_ASSERT2(shift > 0);
75+
tmp = static_cast<size_t>(branchless_shift_right(n, shift));
76+
// this preps n ahead of time for the main loop
77+
n = branchless_shift_left(n, leading_zeros + NUMBITS_MASKBIG);
78+
}
79+
else {
80+
shift = 0;
81+
tmp = static_cast<size_t>(n);
82+
}
83+
HPBC_CLOCKWORK_ASSERT2(shift >= 0);
84+
85+
HPBC_CLOCKWORK_ASSERT2(tmp <= MASKBIG);
86+
87+
88+
std::array<V, ARRAY_SIZE> result;
89+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q) {
90+
result[q] = table[0][tmp & MASK][q];
91+
}
92+
93+
94+
// constexpr int digitsRU = hurchalla::ut_numeric_limits<typename MFE_LU::RU>::digits;
95+
96+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_NUM_TABLES_INIT for (size_t k=1; k < NUM_TABLES; ++k) {
97+
if HURCHALLA_CPP17_CONSTEXPR (UseEarlyExitInInit) {
98+
// this part could be removed - it provides fast return when n_orig is small.
99+
size_t limit_in_progress = static_cast<size_t>(1) << (k * TABLE_BITS);
100+
if (n_orig < limit_in_progress)
101+
return result;
102+
}
103+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q) {
104+
table[k][0][q] = mont_one;
105+
table[k][1][q] = mf.square(table[k-1][TABLESIZE/2][q]);
106+
}
107+
if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE >= 4) {
108+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
109+
table[k][2][q] = mf.square(table[k][1][q]);
110+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
111+
table[k][3][q] = mf.template multiply<PTAG>(table[k][2][q], table[k][1][q]);
112+
}
113+
if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE > 4) {
114+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLESIZE_INIT for (size_t i=4; i<TABLESIZE; i+=2) {
115+
size_t j = i/2;
116+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
117+
table[k][i][q] = mf.template square<LowuopsTag>(table[k][j][q]);
118+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
119+
table[k][i+1][q] = mf.template multiply<LowuopsTag>(table[k][j+1][q], table[k][j][q]);
120+
}
121+
}
122+
123+
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
124+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
125+
result[q] = mf.template multiply<LowuopsTag>(table[k][index][q], result[q]);
126+
}
127+
int bits_remaining = shift;
128+
129+
130+
while (bits_remaining >= NUMBITS_MASKBIG) {
131+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
132+
SV sv[ARRAY_SIZE];
133+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
134+
sv[q] = MFE_LU::getSquaringValue(mf, result[q]);
135+
if HURCHALLA_CPP17_CONSTEXPR (USE_SLIDING_WINDOW_OPTIMIZATION) {
136+
while (bits_remaining > NUMBITS_MASKBIG &&
137+
(static_cast<size_t>(n >> high_word_shift) &
138+
(static_cast<size_t>(1) << (digits_smaller - 1))) == 0) {
139+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
140+
sv[q] = MFE_LU::squareSV(mf, sv[q]);
141+
n = static_cast<U>(n << 1);
142+
--bits_remaining;
143+
}
144+
}
145+
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= NUMBITS_MASKBIG);
146+
147+
tmp = static_cast<size_t>(n >> high_word_shift) >> small_shift;
148+
n = static_cast<U>(n << NUMBITS_MASKBIG);
149+
bits_remaining -= NUMBITS_MASKBIG;
150+
151+
V val1[ARRAY_SIZE];
152+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
153+
val1[q] = table[0][tmp & MASK][q];
154+
155+
static_assert(TABLE_BITS >= 1, "");
156+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS for (size_t i=0; i<TABLE_BITS - 1; ++i) {
157+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
158+
sv[q] = MFE_LU::squareSV(mf, sv[q]);
159+
}
160+
161+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_NUM_TABLES_MAINLOOP for (size_t k=1; k<NUM_TABLES; ++k) {
162+
tmp = tmp >> TABLE_BITS;
163+
size_t index = tmp & MASK;
164+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
165+
val1[q] = mf.template multiply<LowuopsTag>(val1[q], table[k][index][q]);
166+
167+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS for (size_t i=0; i<TABLE_BITS; ++i) {
168+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
169+
sv[q] = MFE_LU::squareSV(mf, sv[q]);
170+
}
171+
}
172+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
173+
result[q] = MFE_LU::squareToMontgomeryValue(mf, sv[q]);
174+
175+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
176+
result[q] = mf.template multiply<PTAG>(result[q], val1[q]);
177+
}
178+
else {
179+
if HURCHALLA_CPP17_CONSTEXPR (USE_SLIDING_WINDOW_OPTIMIZATION) {
180+
while (bits_remaining > NUMBITS_MASKBIG &&
181+
(static_cast<size_t>(n >> high_word_shift) &
182+
(static_cast<size_t>(1) << (digits_smaller - 1))) == 0) {
183+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
184+
result[q] = mf.template square<PTAG>(result[q]);
185+
n = static_cast<U>(n << 1);
186+
--bits_remaining;
187+
}
188+
}
189+
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= NUMBITS_MASKBIG);
190+
191+
tmp = static_cast<size_t>(n >> high_word_shift) >> small_shift;
192+
n = static_cast<U>(n << NUMBITS_MASKBIG);
193+
bits_remaining -= NUMBITS_MASKBIG;
194+
195+
V val1[ARRAY_SIZE];
196+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
197+
val1[q] = table[0][tmp & MASK][q];
198+
199+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS for (size_t i=0; i<TABLE_BITS; ++i) {
200+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
201+
result[q] = mf.template square<PTAG>(result[q]);
202+
}
203+
204+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_NUM_TABLES_MAINLOOP for (size_t k=1; k<NUM_TABLES; ++k) {
205+
tmp = tmp >> TABLE_BITS;
206+
size_t index = tmp & MASK;
207+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
208+
val1[q] = mf.template multiply<LowuopsTag>(val1[q], table[k][index][q]);
209+
210+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS for (size_t i=0; i<TABLE_BITS; ++i)
211+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
212+
result[q] = mf.template square<PTAG>(result[q]);
213+
}
214+
215+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
216+
result[q] = mf.template multiply<PTAG>(result[q], val1[q]);
217+
}
218+
}
219+
if (bits_remaining == 0)
220+
return result;
221+
222+
HPBC_CLOCKWORK_ASSERT2(0 < bits_remaining && bits_remaining < NUMBITS_MASKBIG);
223+
224+
tmp = static_cast<size_t>(n >> high_word_shift) >> (digits_smaller - bits_remaining);
225+
HPBC_CLOCKWORK_ASSERT2(tmp <= MASKBIG);
226+
227+
V val1[ARRAY_SIZE];
228+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
229+
val1[q] = table[0][tmp & MASK][q];
230+
231+
232+
if HURCHALLA_CPP17_CONSTEXPR (NUM_TABLES <= 2) {
233+
// here we only handle NUM_TABLES <= 2 because when we have larger
234+
// numbers of tables we optimize for that below
235+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t k=1; k<NUM_TABLES; ++k) {
236+
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
237+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
238+
val1[q] = mf.template multiply<PTAG>(val1[q], table[k][index][q]);
239+
}
240+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
241+
SV sv[ARRAY_SIZE];
242+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
243+
sv[q] = MFE_LU::getSquaringValue(mf, result[q]);
244+
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= 1);
245+
for (int i=0; i<bits_remaining-1; ++i) {
246+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
247+
sv[q] = MFE_LU::squareSV(mf, sv[q]);
248+
}
249+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
250+
result[q] = MFE_LU::squareToMontgomeryValue(mf, sv[q]);
251+
}
252+
else {
253+
for (int i=0; i<bits_remaining; ++i) {
254+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
255+
result[q] = mf.template square<PTAG>(result[q]);
256+
}
257+
}
258+
}
259+
else {
260+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
261+
SV sv[ARRAY_SIZE];
262+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
263+
sv[q] = MFE_LU::getSquaringValue(mf, result[q]);
264+
int i=0;
265+
for (size_t k=1; i + static_cast<int>(TABLE_BITS) < bits_remaining;
266+
i += static_cast<int>(TABLE_BITS), ++k) {
267+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS for (size_t h=0; h<TABLE_BITS; ++h) {
268+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
269+
sv[q] = MFE_LU::squareSV(mf, sv[q]);
270+
}
271+
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
272+
HPBC_CLOCKWORK_ASSERT2(k < NUM_TABLES);
273+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
274+
val1[q] = mf.template multiply<PTAG>(val1[q], table[k][index][q]);
275+
}
276+
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= 1);
277+
HPBC_CLOCKWORK_ASSERT2(i < bits_remaining);
278+
for (; i<bits_remaining-1; ++i) {
279+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
280+
sv[q] = MFE_LU::squareSV(mf, sv[q]);
281+
}
282+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
283+
result[q] = MFE_LU::squareToMontgomeryValue(mf, sv[q]);
284+
}
285+
else {
286+
int i=0;
287+
for (size_t k=1; i + static_cast<int>(TABLE_BITS) < bits_remaining;
288+
i += static_cast<int>(TABLE_BITS), ++k) {
289+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_TABLE_BITS for (size_t h=0; h<TABLE_BITS; ++h) {
290+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
291+
result[q] = mf.template square<PTAG>(result[q]);
292+
}
293+
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
294+
HPBC_CLOCKWORK_ASSERT2(k < NUM_TABLES);
295+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
296+
val1[q] = mf.template multiply<PTAG>(val1[q], table[k][index][q]);
297+
}
298+
for (; i<bits_remaining; ++i) {
299+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
300+
result[q] = mf.template square<PTAG>(result[q]);
301+
}
302+
}
303+
}
304+
305+
HURCHALLA_REQUEST_UNROLL_LOOP_2KARY_ARRAY_SIZE for (size_t q=0; q<ARRAY_SIZE; ++q)
306+
result[q] = mf.template multiply<PTAG>(result[q], val1[q]);
307+
return result;
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# filter_lines.py
2+
3+
import sys
4+
5+
def filter_lines(input_filename, search_string, output_filename):
    """
    Copy every line of input_filename that contains search_string into
    output_filename.

    Parameters:
        input_filename: path of the text file to read (decoded as UTF-8).
        search_string: substring to look for. Matching is a plain,
            case-sensitive ``in`` test against each line, so an empty
            string matches every line.
        output_filename: path of the text file to create or overwrite
            (encoded as UTF-8) with the matching lines.

    Raises:
        ValueError: if both paths refer to the same existing file.
            Opening the output in 'w' mode would truncate the input
            before a single line had been read, destroying the data.
        OSError: if either file cannot be opened.
    """
    import os  # local import: only needed for the same-file guard below

    # The output is opened in truncating mode, so filtering a file onto
    # itself would wipe it. Refuse instead of silently losing the input.
    if os.path.exists(output_filename) and os.path.samefile(
            input_filename, output_filename):
        raise ValueError(
            f"Input and output must be different files: '{input_filename}'")

    with open(input_filename, 'r', encoding='utf-8') as infile, \
         open(output_filename, 'w', encoding='utf-8') as outfile:
        # Stream line by line so the whole input is never held in memory.
        outfile.writelines(line for line in infile if search_string in line)
16+
17+
def main():
    """
    Command-line entry point.

    Expects three arguments after the script name: the input file, the
    string to search for, and the output file. With any other argument
    count it prints a usage line and exits with status 1. Otherwise it
    runs filter_lines() and prints a confirmation message.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python filter_lines.py <input_file> <search_string> <output_file>")
        sys.exit(1)

    source_path, needle, dest_path = cli_args

    filter_lines(source_path, needle, dest_path)
    print(f"Lines containing '{needle}' have been written to '{dest_path}'.")
29+
30+
if __name__ == "__main__":
31+
main()
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import sys
2+
3+
def main():
    """
    Print the first qualifying result line of the "OVERALL BEST:" section.

    Reads the file named by the single command-line argument and looks at
    the lines after the first line containing "OVERALL BEST:" and before
    the next line containing "Timings By Test Type:". The first line there
    that has exactly seven whitespace-separated fields, with an integer
    third field below 6, is printed (stripped) and the function returns.
    If no line qualifies, a message saying so is printed.

    Exits with status 1 (after printing a message) when the argument count
    is wrong, the file does not exist, or either marker is missing.
    """
    if len(sys.argv) != 2:
        print("Usage: python3 script.py <filename>")
        sys.exit(1)

    filename = sys.argv[1]

    try:
        with open(filename, 'r') as file:
            lines = file.readlines()
    except FileNotFoundError:
        print(f"Error: File '{filename}' not found.")
        sys.exit(1)

    # Find start and end markers. The end marker is searched only AFTER the
    # start marker: an earlier "Timings By Test Type:" line would otherwise
    # produce an empty section and a false "No line found" report.
    try:
        start_index = next(
            i for i, line in enumerate(lines) if "OVERALL BEST:" in line)
        end_index = next(
            i for i in range(start_index + 1, len(lines))
            if "Timings By Test Type:" in lines[i])
    except StopIteration:
        print("Error: Could not find required markers in the file.")
        sys.exit(1)

    # Process lines between the markers
    for line in lines[start_index + 1:end_index]:
        parts = line.strip().split()
        if len(parts) != 7:
            continue  # skip malformed lines
        try:
            third_field = int(parts[2])
        except ValueError:
            continue  # skip lines where the third field isn't an integer
        if third_field < 6:
            print(line.strip())
            return

    print("No line found where the third field is less than 6.")
39+
40+
if __name__ == "__main__":
41+
main()

0 commit comments

Comments
 (0)