Skip to content

Commit dd3eaa7

Browse files
committed
ecmult: Refactor ecmult algo selection
1 parent 4721e07 commit dd3eaa7

File tree

6 files changed

+546
-824
lines changed

6 files changed

+546
-824
lines changed

src/bench_ecmult.c

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919

2020
#define POINTS 32768
2121

22+
/* Default memory limit (64 MB) */
23+
#define DEFAULT_MEM_LIMIT (64 * 1024 * 1024)
24+
/* Select bench algorithm automatically */
25+
#define BENCH_ALGO_AUTO (-1)
26+
2227
static void help(char **argv, int default_iters) {
2328
printf("Benchmark EC multiplication algorithms\n");
2429
printf("\n");
@@ -30,23 +35,25 @@ static void help(char **argv, int default_iters) {
3035
printf("function name. The letter 'g' indicates that one of the points is the generator.\n");
3136
printf("The benchmarks are divided by the number of points.\n");
3237
printf("\n");
33-
printf("default (ecmult_multi): picks pippenger_wnaf or strauss_wnaf depending on the\n");
34-
printf(" batch size\n");
35-
printf("pippenger_wnaf: for all batch sizes\n");
36-
printf("strauss_wnaf: for all batch sizes\n");
37-
printf("simple: multiply and sum each point individually\n");
38+
printf("default (auto): automatically select best algorithm\n");
39+
printf("pippenger_wnaf: for all batch sizes\n");
40+
printf("strauss_wnaf: for all batch sizes\n");
41+
printf("simple: multiply and sum each point individually\n");
42+
printf("\n");
3843
}
3944

4045
typedef struct {
4146
/* Setup once in advance */
4247
secp256k1_context* ctx;
43-
secp256k1_scratch_space* scratch;
4448
secp256k1_scalar* scalars;
4549
secp256k1_ge* pubkeys;
4650
secp256k1_gej* pubkeys_gej;
4751
secp256k1_scalar* seckeys;
4852
secp256k1_gej* expected_output;
49-
secp256k1_ecmult_multi_func ecmult_multi;
53+
54+
/* Algorithm selection */
55+
int forced_algo;
56+
size_t mem_limit;
5057

5158
/* Changes per benchmark */
5259
size_t count;
@@ -214,32 +221,54 @@ static void run_ecmult_bench(bench_data* data, int iters) {
214221
run_benchmark(str, bench_ecmult_1p_g, bench_ecmult_setup, bench_ecmult_1p_g_teardown, data, 10, 2*iters);
215222
}
216223

217-
static int bench_ecmult_multi_callback(secp256k1_scalar* sc, secp256k1_ge* ge, size_t idx, void* arg) {
218-
bench_data* data = (bench_data*)arg;
219-
if (data->includes_g) ++idx;
220-
if (idx == 0) {
221-
*sc = data->scalars[data->offset1];
222-
*ge = secp256k1_ge_const_g;
223-
} else {
224-
*sc = data->scalars[(data->offset1 + idx) % POINTS];
225-
*ge = data->pubkeys[(data->offset2 + idx - 1) % POINTS];
226-
}
227-
return 1;
228-
}
229-
230224
static void bench_ecmult_multi(void* arg, int iters) {
231225
bench_data* data = (bench_data*)arg;
232226

233227
int includes_g = data->includes_g;
234228
int iter;
235229
int count = data->count;
230+
size_t n_points = count - includes_g;
231+
secp256k1_ecmult_multi_algo algo;
232+
secp256k1_ge *points = NULL;
233+
secp256k1_scalar *scalars = NULL;
234+
size_t i;
236235
iters = iters / data->count;
237236

237+
if (n_points > 0) {
238+
points = (secp256k1_ge *)malloc(n_points * sizeof(secp256k1_ge));
239+
scalars = (secp256k1_scalar *)malloc(n_points * sizeof(secp256k1_scalar));
240+
CHECK(points != NULL);
241+
CHECK(scalars != NULL);
242+
}
243+
238244
for (iter = 0; iter < iters; ++iter) {
239-
data->ecmult_multi(&data->ctx->error_callback, data->scratch, &data->output[iter], data->includes_g ? &data->scalars[data->offset1] : NULL, bench_ecmult_multi_callback, arg, count - includes_g);
245+
const secp256k1_scalar *g_scalar_ptr = NULL;
246+
247+
if (includes_g) {
248+
g_scalar_ptr = &data->scalars[data->offset1];
249+
}
250+
251+
for (i = 0; i < n_points; ++i) {
252+
size_t idx = includes_g ? i + 1 : i;
253+
scalars[i] = data->scalars[(data->offset1 + idx) % POINTS];
254+
points[i] = data->pubkeys[(data->offset2 + i) % POINTS];
255+
}
256+
257+
if (data->forced_algo >= 0) {
258+
algo = data->forced_algo;
259+
} else {
260+
algo = secp256k1_ecmult_multi_select(data->mem_limit, n_points);
261+
}
262+
263+
CHECK(secp256k1_ecmult_multi_internal(&data->ctx->error_callback, algo, &data->output[iter],
264+
n_points, points, scalars, g_scalar_ptr));
265+
240266
data->offset1 = (data->offset1 + count) % POINTS;
241267
data->offset2 = (data->offset2 + count - 1) % POINTS;
242268
}
269+
270+
free(points);
271+
free(scalars);
243272
}
244273

245274
static void bench_ecmult_multi_setup(void* arg) {
@@ -309,12 +338,12 @@ static void run_ecmult_multi_bench(bench_data* data, size_t count, int includes_
309338
int main(int argc, char **argv) {
310339
bench_data data;
311340
int i, p;
312-
size_t scratch_size;
313341

314342
int default_iters = 10000;
315343
int iters = get_iters(default_iters);
316344

317-
data.ecmult_multi = secp256k1_ecmult_multi_var;
345+
data.forced_algo = BENCH_ALGO_AUTO;
346+
data.mem_limit = DEFAULT_MEM_LIMIT;
318347

319348
if (argc > 1) {
320349
if(have_flag(argc, argv, "-h")
@@ -324,12 +353,17 @@ int main(int argc, char **argv) {
324353
return EXIT_SUCCESS;
325354
} else if(have_flag(argc, argv, "pippenger_wnaf")) {
326355
printf("Using pippenger_wnaf:\n");
327-
data.ecmult_multi = secp256k1_ecmult_pippenger_batch_single;
356+
/* TODO: Make this a dynamic selection again */
357+
data.forced_algo = SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_4;
328358
} else if(have_flag(argc, argv, "strauss_wnaf")) {
329359
printf("Using strauss_wnaf:\n");
330-
data.ecmult_multi = secp256k1_ecmult_strauss_batch_single;
360+
data.forced_algo = SECP256K1_ECMULT_MULTI_ALGO_STRAUSS;
331361
} else if(have_flag(argc, argv, "simple")) {
332362
printf("Using simple algorithm:\n");
363+
data.forced_algo = SECP256K1_ECMULT_MULTI_ALGO_TRIVIAL;
364+
} else if(have_flag(argc, argv, "auto")) {
365+
printf("Using automatic algorithm selection:\n");
366+
data.forced_algo = BENCH_ALGO_AUTO;
333367
} else {
334368
fprintf(stderr, "%s: unrecognized argument '%s'.\n\n", argv[0], argv[1]);
335369
help(argv, default_iters);
@@ -338,12 +372,6 @@ int main(int argc, char **argv) {
338372
}
339373

340374
data.ctx = secp256k1_context_create(SECP256K1_CONTEXT_NONE);
341-
scratch_size = secp256k1_strauss_scratch_size(POINTS) + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT;
342-
if (!have_flag(argc, argv, "simple")) {
343-
data.scratch = secp256k1_scratch_space_create(data.ctx, scratch_size);
344-
} else {
345-
data.scratch = NULL;
346-
}
347375

348376
/* Allocate stuff */
349377
data.scalars = malloc(sizeof(secp256k1_scalar) * POINTS);
@@ -389,9 +417,6 @@ int main(int argc, char **argv) {
389417
printf("Skipping some benchmarks due to SECP256K1_BENCH_ITERS <= 2\n");
390418
}
391419

392-
if (data.scratch != NULL) {
393-
secp256k1_scratch_space_destroy(data.ctx, data.scratch);
394-
}
395420
secp256k1_context_destroy(data.ctx);
396421
free(data.scalars);
397422
free(data.pubkeys);

src/ecmult.h

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
#include "group.h"
1111
#include "scalar.h"
12-
#include "scratch.h"
1312

1413
#ifndef ECMULT_WINDOW_SIZE
1514
# define ECMULT_WINDOW_SIZE 15
@@ -43,19 +42,84 @@
4342
/** Double multiply: R = na*A + ng*G */
4443
static void secp256k1_ecmult(secp256k1_gej *r, const secp256k1_gej *a, const secp256k1_scalar *na, const secp256k1_scalar *ng);
4544

46-
typedef int (secp256k1_ecmult_multi_callback)(secp256k1_scalar *sc, secp256k1_ge *pt, size_t idx, void *data);
45+
/**
46+
* Algorithm identifiers for multi-scalar multiplication.
47+
*
48+
* TRIVIAL: Simple algorithm, no extra memory needed
49+
* STRAUSS: Strauss algorithm (efficient for small batches)
50+
* PIPPENGER_n: Pippenger algorithm with bucket window size n
51+
*/
52+
typedef enum {
53+
SECP256K1_ECMULT_MULTI_ALGO_TRIVIAL = 0,
54+
SECP256K1_ECMULT_MULTI_ALGO_STRAUSS = 1,
55+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_1 = 2,
56+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_2 = 3,
57+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_3 = 4,
58+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_4 = 5,
59+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_5 = 6,
60+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_6 = 7,
61+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_7 = 8,
62+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_8 = 9,
63+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_9 = 10,
64+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_10 = 11,
65+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_11 = 12,
66+
SECP256K1_ECMULT_MULTI_ALGO_PIPPENGER_12 = 13
67+
} secp256k1_ecmult_multi_algo;
68+
69+
#define SECP256K1_ECMULT_MULTI_NUM_ALGOS 14
4770

4871
/**
49-
* Multi-multiply: R = inp_g_sc * G + sum_i ni * Ai.
50-
* Chooses the right algorithm for a given number of points and scratch space
51-
* size. Resets and overwrites the given scratch space. If the points do not
52-
* fit in the scratch space the algorithm is repeatedly run with batches of
53-
* points. If no scratch space is given then a simple algorithm is used that
54-
* simply multiplies the points with the corresponding scalars and adds them up.
55-
* Returns: 1 on success (including when inp_g_sc is NULL and n is 0)
56-
* 0 if there is not enough scratch space for a single point or
57-
* callback returns 0
72+
* Calculate max batch size for a given memory limit.
73+
*
74+
* For each algorithm, memory usage is modeled as m(x) = A*x + B and
75+
* running time as c(x) = C*x + D, where x is the batch size. This
76+
* function finds the algorithm that minimizes time per operation
77+
* C + D/x at the maximum batch size x = (mem_limit - B) / A.
78+
*
79+
* Returns: The optimal batch size, or 0 if memory is insufficient.
5880
*/
59-
static int secp256k1_ecmult_multi_var(const secp256k1_callback* error_callback, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n);
81+
static size_t secp256k1_ecmult_multi_batch_size(size_t mem_limit);
82+
83+
/**
84+
* Select the best algorithm for a given batch size within the memory
85+
* limit.
86+
*
87+
* Among algorithms that fit within mem_limit for the given batch_size,
88+
* selects the one that minimizes time per operation C + D/batch_size.
89+
*
90+
* Returns: The optimal algorithm identifier.
91+
*/
92+
static secp256k1_ecmult_multi_algo secp256k1_ecmult_multi_select(
93+
size_t mem_limit,
94+
size_t batch_size
95+
);
96+
97+
/**
98+
* Multi-multiply: R = scalar_g * G + sum_i scalars[i] * points[i].
99+
*
100+
* Chooses the right algorithm for the given number of points.
101+
*
102+
* Returns: 1 on success, 0 on memory allocation failure.
103+
*/
104+
static int secp256k1_ecmult_multi(
105+
const secp256k1_callback *error_callback,
106+
secp256k1_gej *r,
107+
size_t n_points,
108+
const secp256k1_ge *points,
109+
const secp256k1_scalar *scalars,
110+
const secp256k1_scalar *scalar_g,
111+
size_t mem_limit
112+
);
113+
114+
/* Only for benchmarks and testing */
115+
static int secp256k1_ecmult_multi_internal(
116+
const secp256k1_callback *error_callback,
117+
secp256k1_ecmult_multi_algo algo,
118+
secp256k1_gej *r,
119+
size_t n_points,
120+
const secp256k1_ge *points,
121+
const secp256k1_scalar *scalars,
122+
const secp256k1_scalar *scalar_g
123+
);
60124

61125
#endif /* SECP256K1_ECMULT_H */

0 commit comments

Comments
 (0)