Commit 644847c

iakovenkos and claude authored

chore: pippenger int audit (#19302)

clean up + docs + a couple of edge case tests. Closes AztecProtocol/barretenberg#486

---------

Co-authored-by: Claude Sonnet 4.5 <[email protected]>

1 parent 7ea151a commit 644847c

File tree: 9 files changed, +1322 −1144 lines

barretenberg/cpp/CLAUDE.md

Lines changed: 8 additions & 0 deletions

@@ -2,6 +2,14 @@ succint aztec-packages cheat sheet.
 
 THE PROJECT ROOT IS AT TWO LEVELS ABOVE THIS FOLDER. Typically, the repository is at ~/aztec-packages. all advice is from the root.
 
+# Git workflow for barretenberg
+
+**IMPORTANT**: When comparing branches or looking at diffs for barretenberg work, use `merge-train/barretenberg` as the base branch, NOT `master`. The master branch is often outdated for barretenberg development.
+
+Examples:
+- `git diff merge-train/barretenberg...HEAD` (not `git diff master...HEAD`)
+- `git log merge-train/barretenberg..HEAD` (not `git log master..HEAD`)
+
 Run ./bootstrap.sh at the top-level to be sure the repo fully builds.
 Bootstrap scripts can be called with relative paths e.g. ../barretenberg/bootstrap.sh

barretenberg/cpp/scripts/compare_branch_vs_baseline_remote.sh

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ PRESET=${3:-clang20}
 BUILD_DIR=${4:-build}
 HARDWARE_CONCURRENCY=${HARDWARE_CONCURRENCY:-16}
 
-BASELINE_BRANCH="master"
+BASELINE_BRANCH="${BASELINE_BRANCH:-merge-train/barretenberg}"
 BENCH_TOOLS_DIR="$BUILD_DIR/_deps/benchmark-src/tools"
 
 if [ ! -z "$(git status --untracked-files=no --porcelain)" ]; then
barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/README.md

Lines changed: 183 additions & 0 deletions

@@ -0,0 +1,183 @@
# Pippenger Multi-Scalar Multiplication (MSM)

## Overview
The Pippenger algorithm computes multi-scalar multiplications:

$$\text{MSM}(\vec{s}, \vec{P}) = \sum_{i=0}^{n-1} s_i \cdot P_i$$

**Complexity**: Let $q = \lceil \log_2(\text{field modulus}) \rceil$ be the scalar bit-length, $|A|$ the cost of a group addition, and $|D|$ the cost of a doubling.

- **Pippenger**: $O\left(\frac{q}{c} \cdot \left((n + 2^c) \cdot |A| + c \cdot |D|\right)\right)$
- **Naive**: $O(n \cdot q \cdot |D| + n \cdot q \cdot |A| / 2)$

With $c \approx \frac{1}{2} \log_2 n$, Pippenger achieves roughly $O(n \cdot q / \log n)$ vs $O(n \cdot q)$ for naive scalar multiplication.
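As a sanity check, the two cost formulas above can be evaluated directly. The sketch below is illustrative only: it assumes unit operation costs $|A| = |D| = 1$, and the function names are not from the codebase.

```cpp
#include <cassert>
#include <cmath>

// Evaluates the Pippenger cost formula with unit costs |A| = |D| = 1.
// Illustrative model only, not the library's cost model.
double pippenger_cost(double n, double q, double c)
{
    const double rounds = std::ceil(q / c);
    // Per round: (n + 2^c) additions plus c doublings.
    return rounds * ((n + std::pow(2.0, c)) + c);
}

// Naive per-point double-and-add: q doublings plus ~q/2 additions per scalar.
double naive_cost(double n, double q)
{
    return n * q + n * q / 2.0;
}
```

For large $n$ (say $n = 2^{20}$, $q = 254$, $c = 10$) the Pippenger estimate is tens of millions of operations versus hundreds of millions for the naive sum, while for tiny $n$ the naive method wins, which is why a `PIPPENGER_THRESHOLD` exists.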
## Algorithm

### Step 1: Scalar Decomposition

**Implementation**: `get_scalar_slice(scalar, round_index, bits_per_slice)`

Each scalar $s_i$ is decomposed into $r$ slices of $c$ bits each, processed **MSB-first**:

$$s_i = \sum_{j=0}^{r-1} s_i^{(j)} \cdot 2^{c(r-1-j)}$$

- $c$ = bits per slice (from `get_optimal_log_num_buckets`, which brute-force searches for the minimum-cost slice width)
- $r = \lceil$ `NUM_BITS_IN_FIELD` $/ c \rceil$ = number of rounds
- Round 0 extracts the most significant bits
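A minimal sketch of MSB-first slice extraction, shrunk to a 64-bit scalar for illustration (the real `get_scalar_slice` operates on 254-bit field elements, so this toy version only mirrors the indexing):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Extract the c-bit slice for a given round, MSB-first: round 0 returns the
// most significant slice. Toy version over uint64_t, not the library routine.
uint64_t get_scalar_slice(uint64_t scalar, size_t round_index, size_t bits_per_slice, size_t num_bits)
{
    const size_t num_rounds = (num_bits + bits_per_slice - 1) / bits_per_slice; // r = ceil(q / c)
    const size_t shift = bits_per_slice * (num_rounds - 1 - round_index);       // c * (r - 1 - j)
    const uint64_t mask = (1ULL << bits_per_slice) - 1;
    return (scalar >> shift) & mask;
}
```

For example, with $c = 8$ and a 64-bit scalar, round 0 reads bits 56..63 and round 7 reads bits 0..7, matching the decomposition above.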
### Step 2: Bucket Accumulation

For each round $j$, points are added into **buckets** based on their scalar slice. Bucket $k$ accumulates all points whose slice value equals $k$:

$$B_k^{(j)} = \sum_{\{i : s_i^{(j)} = k\}} P_i$$

**Two implementation paths:**

- **Affine**: Sorts points by bucket and uses batched affine additions
- **Jacobian**: Direct bucket accumulation in Jacobian coordinates
### Step 3: Bucket Reduction

**Implementation**: `accumulate_buckets(bucket_accumulators)`

Computes the weighted sum using a suffix sum (high to low):

$$R^{(j)} = \sum_{k=1}^{2^c - 1} k \cdot B_k^{(j)} = \sum_{k=1}^{2^c - 1} \left( \sum_{m=k}^{2^c - 1} B_m^{(j)} \right)$$

An offset generator is added and subtracted to avoid rare accumulator edge cases, a probabilistic mitigation that simplifies the accumulation logic.
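The suffix-sum identity can be checked with an integer stand-in for the group: each bucket holds an `int64_t` instead of a curve point, and `+` models group addition. This is a model for the identity only, not the library code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Computes sum_k k * B_k using only additions: a running suffix sum plus an
// accumulator, two group additions per bucket instead of a scalar mul each.
int64_t accumulate_buckets(const std::vector<int64_t>& buckets)
{
    int64_t running_sum = 0; // B_k + B_{k+1} + ... + B_{2^c - 1}
    int64_t result = 0;      // accumulates the running sums
    for (size_t k = buckets.size() - 1; k >= 1; --k) {
        running_sum += buckets[k];
        result += running_sum;
    }
    return result; // bucket 0 is skipped since 0 * B_0 contributes nothing
}
```

For buckets `{B_0, B_1, B_2, B_3} = {0, 3, 5, 2}` the weighted sum is $1 \cdot 3 + 2 \cdot 5 + 3 \cdot 2 = 19$, and the suffix-sum pass produces the same value without any multiplications.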
### Step 4: Round Combination

Combines all rounds using Horner's method (MSB-first):

```cpp
msm_accumulator = point_at_infinity
for j = 0 to r-1:
    repeat c doublings (or fewer for the final round)
    msm_accumulator += bucket_result[j]
```
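Putting the four steps together over the same integer stand-in for the group (points are `int64_t`, addition is `+`, doubling is `*2`); `pippenger_toy` and `naive_msm` are hypothetical helpers for illustration, not the library API:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference result: direct sum of s_i * P_i.
int64_t naive_msm(const std::vector<uint64_t>& scalars, const std::vector<int64_t>& points)
{
    int64_t sum = 0;
    for (size_t i = 0; i < scalars.size(); ++i) {
        sum += static_cast<int64_t>(scalars[i]) * points[i];
    }
    return sum;
}

// Toy Pippenger: slice scalars MSB-first, accumulate buckets, reduce with a
// suffix sum, then combine rounds Horner-style with c doublings per round.
int64_t pippenger_toy(const std::vector<uint64_t>& scalars,
                      const std::vector<int64_t>& points,
                      size_t c,
                      size_t num_bits)
{
    const size_t num_rounds = (num_bits + c - 1) / c;
    int64_t accumulator = 0;
    for (size_t j = 0; j < num_rounds; ++j) {
        // Steps 1-2: slice extraction + bucket accumulation for this round.
        std::vector<int64_t> buckets(1ULL << c, 0);
        const size_t shift = c * (num_rounds - 1 - j);
        for (size_t i = 0; i < scalars.size(); ++i) {
            buckets[(scalars[i] >> shift) & ((1ULL << c) - 1)] += points[i];
        }
        // Step 3: suffix-sum bucket reduction, sum_k k * B_k.
        int64_t running_sum = 0;
        int64_t round_result = 0;
        for (size_t k = buckets.size() - 1; k >= 1; --k) {
            running_sum += buckets[k];
            round_result += running_sum;
        }
        // Step 4: Horner combination, c doublings then add this round's result.
        for (size_t d = 0; d < c; ++d) {
            accumulator *= 2;
        }
        accumulator += round_result;
    }
    return accumulator;
}
```

The toy version agrees with the naive sum for any inputs that fit `num_bits`, which is exactly the correctness property the real implementation's tests check against curve points.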
## Algorithm Variants

### Entry Points and Safety

| Entry Point | Default | Safety |
|-------------|---------|--------|
| `msm()` | `handle_edge_cases=false` | ⚠️ **Unsafe** |
| `pippenger()` | `handle_edge_cases=true` | ✓ Safe |
| `pippenger_unsafe()` | `handle_edge_cases=false` | ⚠️ Unsafe |
| `batch_multi_scalar_mul()` | `handle_edge_cases=true` | ✓ Safe |

### Edge Cases

Affine addition fails for **P = Q** (doubling), **P = −Q** (inverse), and **P = O** (identity). Jacobian coordinates handle these correctly at higher cost (~2-3× slower).

⚠️ **Use `msm()` or `pippenger_unsafe()` only when points are guaranteed linearly independent** (e.g., SRS points). For user-controlled or potentially duplicate points, use `pippenger()`.
### Affine Pippenger (`handle_edge_cases=false`)

Uses affine coordinates with Montgomery's batch inversion trick: replaces $m$ inversions with **1 inversion + O(m) multiplications**, yielding a ~2-3× speedup over Jacobian.
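A sketch of Montgomery's batch inversion over a toy prime field (the real code inverts base-field elements for batched affine point addition; `P_TOY`, `pow_mod`, and `batch_invert` are illustrative names, not the library API):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr uint64_t P_TOY = 1000003; // small prime, stand-in for the base field modulus

// Single inversion via Fermat's little theorem: a^(p-2) mod p.
uint64_t pow_mod(uint64_t base, uint64_t exp)
{
    uint64_t result = 1;
    base %= P_TOY;
    while (exp > 0) {
        if (exp & 1) { result = result * base % P_TOY; }
        base = base * base % P_TOY;
        exp >>= 1;
    }
    return result;
}

// Montgomery's trick: m inversions -> 1 inversion + O(m) multiplications.
std::vector<uint64_t> batch_invert(std::vector<uint64_t> values)
{
    const size_t m = values.size();
    std::vector<uint64_t> prefix(m);
    uint64_t running = 1;
    for (size_t i = 0; i < m; ++i) {
        prefix[i] = running; // product of values[0..i-1]
        running = running * values[i] % P_TOY;
    }
    uint64_t inv = pow_mod(running, P_TOY - 2); // one inversion of the full product
    for (size_t i = m; i-- > 0;) {
        const uint64_t tmp = inv * prefix[i] % P_TOY; // = values[i]^{-1}
        inv = inv * values[i] % P_TOY;                // strip values[i] from the inverse
        values[i] = tmp;
    }
    return values;
}
```

The speedup comes from inversion being far more expensive than multiplication in the base field, so amortizing one inversion across a whole batch of affine additions is a large win.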
### Jacobian Pippenger (`handle_edge_cases=true`)

Uses Jacobian coordinates for bucket accumulators. Handles all edge cases correctly.

## Tuning Constants

| Constant | Value | Purpose |
|----------|-------|---------|
| `PIPPENGER_THRESHOLD` | 16 | Below this, use naive scalar multiplication |
| `AFFINE_TRICK_THRESHOLD` | 128 | Below this, batch inversion overhead exceeds its savings |
| `MAX_SLICE_BITS` | 20 | Upper bound on the bucket count exponent |
| `BATCH_SIZE` | 2048 | Points per batch inversion (fits L2 cache) |
| `RADIX_BITS` | 8 | Bits per radix sort pass |
<details>
<summary>Cost model constants and derivations</summary>

| Constant | Value | Derivation |
|----------|-------|------------|
| `BUCKET_ACCUMULATION_COST` | 5 | 2 Jacobian adds/bucket × 2.5× cost ratio |
| `AFFINE_TRICK_SAVINGS_PER_OP` | 5 | ~10 muls saved − ~3 muls for the product tree |
| `JACOBIAN_Z_NOT_ONE_PENALTY` | 5 | Extra field ops when Z ≠ 1 |
| `INVERSION_TABLE_COST` | 14 | 4-bit lookup table for modular exponentiation |

**BATCH_SIZE=2048**: Each `AffineElement` is 64 bytes, so 2048 points = 128 KB, fitting in L2 cache.

**RADIX_BITS=8**: 256 radix buckets × 4 bytes = 1 KB counting array, which fits in L1 cache.

</details>
## Implementation Notes

### Zero Scalar Filtering

`transform_scalar_and_get_nonzero_scalar_indices` filters out zero scalars before processing (since $0 \cdot P_i = \mathcal{O}$). Scalars are converted from Montgomery form in-place to avoid doubling memory usage.
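The filtering step reduces to collecting the indices of nonzero scalars. A simplified sketch (a hypothetical helper with a reduced signature; the real function also converts scalars out of Montgomery form in place):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the indices of scalars that actually contribute to the MSM;
// zero scalars contribute only the group identity and are skipped.
std::vector<uint32_t> get_nonzero_scalar_indices(const std::vector<uint64_t>& scalars)
{
    std::vector<uint32_t> indices;
    for (uint32_t i = 0; i < static_cast<uint32_t>(scalars.size()); ++i) {
        if (scalars[i] != 0) {
            indices.push_back(i);
        }
    }
    return indices;
}
```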
### Bucket Existence Tracking

A `BitVector` bitmap tracks which buckets are populated, avoiding expensive full-array clears between rounds. Clearing the bitmap costs $O(2^c / 64)$ word writes vs $O(2^c)$ for the full bucket array.
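The idea behind the bitmap can be sketched in a few lines (a toy `ToyBitVector`, not the library's `BitVector`):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// One bit per bucket: clearing between rounds touches 2^c / 64 machine words
// instead of 2^c bucket entries (each a multi-word curve point).
struct ToyBitVector {
    std::vector<uint64_t> words;
    explicit ToyBitVector(size_t num_bits) : words((num_bits + 63) / 64, 0) {}
    void set(size_t i) { words[i / 64] |= (1ULL << (i % 64)); }
    bool get(size_t i) const { return ((words[i / 64] >> (i % 64)) & 1) != 0; }
    void clear() { std::fill(words.begin(), words.end(), 0); }
};
```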
### Point Scheduling (Affine Variant Only)

Entries are packed as `(point_index << 32) | bucket_index` into 64-bit values. Since bucket indices fit in $c$ bits (typically 8-16), they occupy only the lowest bits of the packed entry. An **in-place MSD radix sort** on the low $c$ bits groups points by bucket for efficient batch processing. The sort also detects entries with `bucket_index == 0` during the final radix pass, allowing zero-bucket entries to be skipped without a separate scan.
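The packing and the sort key can be sketched as follows; `std::stable_sort` stands in for the in-place MSD radix sort used by the real code, and the helper names are illustrative:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Pack a schedule entry: point index in the high 32 bits, bucket index low.
uint64_t pack_entry(uint32_t point_index, uint32_t bucket_index)
{
    return (static_cast<uint64_t>(point_index) << 32) | bucket_index;
}

uint32_t bucket_of(uint64_t entry) { return static_cast<uint32_t>(entry & 0xffffffffULL); }
uint32_t point_of(uint64_t entry) { return static_cast<uint32_t>(entry >> 32); }

// Grouping points by bucket only requires ordering on the low c bits of each
// packed entry; the point index rides along in the high bits.
std::vector<uint64_t> group_by_bucket(std::vector<uint64_t> schedule, uint32_t bucket_bits)
{
    const uint64_t mask = (1ULL << bucket_bits) - 1;
    std::stable_sort(schedule.begin(), schedule.end(),
                     [mask](uint64_t a, uint64_t b) { return (a & mask) < (b & mask); });
    return schedule;
}
```

After grouping, all points destined for the same bucket are adjacent, which is what makes the batched affine addition in the next section effective.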
### Batched Affine Addition

`batch_accumulate_points_into_buckets` processes sorted points iteratively:

- Same-bucket pairs → queue for batch addition
- Different buckets → cache in bucket or queue with the existing accumulator
- Uses branchless conditional moves to minimize pipeline stalls
- Prefetches future points to hide memory latency
- Recirculates results to maximize batch efficiency before writing to buckets
<details>
<summary>Batch accumulation case analysis</summary>

| Condition | Action | Iterator Update |
|-----------|--------|-----------------|
| `bucket[i] == bucket[i+1]` | Queue both points for batch add | `point_it += 2` |
| Different buckets, accumulator exists | Queue point + accumulator | `point_it += 1` |
| Different buckets, no accumulator | Cache point into bucket | `point_it += 1` |

After batch addition, results targeting the same bucket are paired again before writing to bucket accumulators, reducing random memory access by ~50%.

</details>
## Parallelization

Uses **per-thread buffers** (bucket accumulators, scratch space) to eliminate contention.

For `batch_multi_scalar_mul()`, work is distributed via `MSMWorkUnit` structures that can split a single MSM across multiple threads. Each thread computes partial results on point subsets, which are combined in a final reduction.
<details>
<summary>Per-call buffer sizes</summary>

| Buffer | Size | Purpose |
|--------|------|---------|
| `BucketAccumulators` (affine) | $2^c \times 64$ bytes | Affine bucket array + bitmap |
| `JacobianBucketAccumulators` | $2^c \times 96$ bytes | Jacobian bucket array + bitmap |
| `AffineAdditionData` | ~400 KB | Scratch for batch inversion |
| `point_schedule` | $n \times 8$ bytes | Per-MSM point schedule |

Buffers are allocated per-call for WASM compatibility. Memory scales with thread count during parallel execution.

</details>
## File Structure

```
scalar_multiplication/
├── scalar_multiplication.hpp   # MSM class, data structures
├── scalar_multiplication.cpp   # Core algorithm
├── process_buckets.hpp/cpp     # Radix sort
├── bitvector.hpp               # Bit vector for bucket tracking
└── README.md                   # This file
```
## References

1. Pippenger, N. (1976). "On the evaluation of powers and related problems"
2. Bernstein, D. J. et al. "Faster batch forgery identification" (batch inversion)

barretenberg/cpp/src/barretenberg/ecc/scalar_multiplication/process_buckets.cpp

Lines changed: 50 additions & 42 deletions
@@ -10,89 +10,97 @@
 
 namespace bb::scalar_multiplication {
 
-// NOLINTNEXTLINE(misc-no-recursion) recursion is fine here, max recursion depth is 8 (64 bit int / 8 bits per call)
+// NOLINTNEXTLINE(misc-no-recursion) recursion is fine here, max depth is 4 (32-bit bucket index / 8 bits per call)
 void radix_sort_count_zero_entries(uint64_t* keys,
                                    const size_t num_entries,
                                    const uint32_t shift,
                                    size_t& num_zero_entries,
-                                   const uint32_t total_bits,
-                                   const uint64_t* start_pointer) noexcept
+                                   const uint32_t bucket_index_bits,
+                                   const uint64_t* top_level_keys) noexcept
 {
-    constexpr size_t num_bits = 8;
-    constexpr size_t num_buckets = 1UL << num_bits;
-    constexpr uint32_t mask = static_cast<uint32_t>(num_buckets) - 1U;
-    std::array<uint32_t, num_buckets> bucket_counts{};
+    constexpr size_t NUM_RADIX_BUCKETS = 1UL << RADIX_BITS;
+    constexpr uint32_t RADIX_MASK = static_cast<uint32_t>(NUM_RADIX_BUCKETS) - 1U;
 
+    // Step 1: Count entries in each radix bucket
+    std::array<uint32_t, NUM_RADIX_BUCKETS> bucket_counts{};
     for (size_t i = 0; i < num_entries; ++i) {
-        bucket_counts[(keys[i] >> shift) & mask]++;
+        bucket_counts[(keys[i] >> shift) & RADIX_MASK]++;
     }
 
-    std::array<uint32_t, num_buckets + 1> offsets;
-    std::array<uint32_t, num_buckets + 1> offsets_copy;
+    // Step 2: Convert counts to cumulative offsets (prefix sum)
+    std::array<uint32_t, NUM_RADIX_BUCKETS + 1> offsets;
+    std::array<uint32_t, NUM_RADIX_BUCKETS + 1> offsets_copy;
     offsets[0] = 0;
-
-    for (size_t i = 0; i < num_buckets - 1; ++i) {
+    for (size_t i = 0; i < NUM_RADIX_BUCKETS - 1; ++i) {
         bucket_counts[i + 1] += bucket_counts[i];
     }
-    if ((shift == 0) && (keys == start_pointer)) {
+
+    // Count zero entries only at the final recursion level (shift == 0) and only for the full array
+    if ((shift == 0) && (keys == top_level_keys)) {
         num_zero_entries = bucket_counts[0];
     }
-    for (size_t i = 1; i < num_buckets + 1; ++i) {
+
+    for (size_t i = 1; i < NUM_RADIX_BUCKETS + 1; ++i) {
         offsets[i] = bucket_counts[i - 1];
     }
-    for (size_t i = 0; i < num_buckets + 1; ++i) {
+    for (size_t i = 0; i < NUM_RADIX_BUCKETS + 1; ++i) {
         offsets_copy[i] = offsets[i];
     }
-    uint64_t* start = &keys[0];
 
-    for (size_t i = 0; i < num_buckets; ++i) {
+    // Step 3: In-place permutation using cycle sort
+    // For each radix bucket, repeatedly swap elements to their correct positions until all elements
+    // in that bucket's range belong there. The offsets array tracks the next write position for each bucket.
+    uint64_t* start = &keys[0];
+    for (size_t i = 0; i < NUM_RADIX_BUCKETS; ++i) {
         uint64_t* bucket_start = &keys[offsets[i]];
         const uint64_t* bucket_end = &keys[offsets_copy[i + 1]];
         while (bucket_start != bucket_end) {
             for (uint64_t* it = bucket_start; it < bucket_end; ++it) {
-                const size_t value = (*it >> shift) & mask;
+                const size_t value = (*it >> shift) & RADIX_MASK;
                 const uint64_t offset = offsets[value]++;
                 std::iter_swap(it, start + offset);
             }
             bucket_start = &keys[offsets[i]];
         }
     }
+
+    // Step 4: Recursively sort each bucket by the next less-significant byte
     if (shift > 0) {
-        for (size_t i = 0; i < num_buckets; ++i) {
-            if (offsets_copy[i + 1] - offsets_copy[i] > 1) {
-                radix_sort_count_zero_entries(&keys[offsets_copy[i]],
-                                              offsets_copy[i + 1] - offsets_copy[i],
-                                              shift - 8,
-                                              num_zero_entries,
-                                              total_bits,
-                                              keys);
+        for (size_t i = 0; i < NUM_RADIX_BUCKETS; ++i) {
+            const size_t bucket_size = offsets_copy[i + 1] - offsets_copy[i];
+            if (bucket_size > 1) {
+                radix_sort_count_zero_entries(
+                    &keys[offsets_copy[i]], bucket_size, shift - RADIX_BITS, num_zero_entries, bucket_index_bits, keys);
             }
         }
     }
 }
 
-size_t process_buckets_count_zero_entries(uint64_t* wnaf_entries,
-                                          const size_t num_entries,
-                                          const uint32_t num_bits) noexcept
+size_t sort_point_schedule_and_count_zero_buckets(uint64_t* point_schedule,
+                                                  const size_t num_entries,
+                                                  const uint32_t bucket_index_bits) noexcept
 {
     if (num_entries == 0) {
         return 0;
     }
-    const uint32_t bits_per_round = 8;
-    const uint32_t base = num_bits & 7;
-    const uint32_t total_bits = (base == 0) ? num_bits : num_bits - base + 8;
-    const uint32_t shift = total_bits - bits_per_round;
+
+    // Round bucket_index_bits up to next multiple of RADIX_BITS for proper MSD radix sort alignment.
+    // E.g., if bucket_index_bits=10, we need to start sorting from bit 16 (2 bytes) not bit 10.
+    const uint32_t remainder = bucket_index_bits % RADIX_BITS;
+    const uint32_t padded_bits = (remainder == 0) ? bucket_index_bits : bucket_index_bits - remainder + RADIX_BITS;
+    const uint32_t initial_shift = padded_bits - RADIX_BITS;
+
     size_t num_zero_entries = 0;
-    radix_sort_count_zero_entries(wnaf_entries, num_entries, shift, num_zero_entries, num_bits, wnaf_entries);
-
-    // inside radix_sort_count_zero_entries, if the least significant *byte* of `wnaf_entries[0] == 0`,
-    // then num_nonzero_entries = number of entries that share the same value as wnaf_entries[0].
-    // If wnaf_entries[0] != 0, we must manually set num_zero_entries = 0
-    if (num_entries > 0) {
-        if ((wnaf_entries[0] & 0xffffffff) != 0) {
-            num_zero_entries = 0;
-        }
+    radix_sort_count_zero_entries(
+        point_schedule, num_entries, initial_shift, num_zero_entries, bucket_index_bits, point_schedule);
+
+    // The radix sort counts entries where the least significant BYTE is zero, but we need entries where
+    // the entire bucket_index (lower 32 bits) is zero. Verify the first entry after sorting.
+    if ((point_schedule[0] & BUCKET_INDEX_MASK) != 0) {
+        num_zero_entries = 0;
     }
+
     return num_zero_entries;
 }
+
 } // namespace bb::scalar_multiplication
