Skip to content

Commit 389afd6

Browse files
perf: remove preprocessed trace bitwiseoplookupair [cpu & gpu tracegen] (#2366)
Resolves INT-5852.
1 parent 0291751 commit 389afd6

File tree

4 files changed

+153
-68
lines changed

4 files changed

+153
-68
lines changed

crates/circuits/primitives/cuda/src/bitwise_op_lookup.cu

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
#include "fp.h"
22
#include "launcher.cuh"
3+
#include "primitives/trace_access.h"
4+
5+
constexpr uint32_t NUM_BITS = 8;
6+
7+
template <typename T> struct BitwiseOperationLookupCols {
8+
T x_bits[NUM_BITS];
9+
T y_bits[NUM_BITS];
10+
T mult_range;
11+
T mult_xor;
12+
};
313

414
__global__ void bitwise_op_lookup_tracegen(
515
const uint32_t *count,
@@ -9,9 +19,26 @@ __global__ void bitwise_op_lookup_tracegen(
919
) {
1020
uint32_t row_idx = blockIdx.x * blockDim.x + threadIdx.x;
1121
if (row_idx < num_rows) {
12-
trace[row_idx] = Fp(count[row_idx] + (cpu_count ? cpu_count[row_idx] : 0));
13-
trace[row_idx + num_rows] =
14-
Fp(count[row_idx + num_rows] + (cpu_count ? cpu_count[row_idx + num_rows] : 0));
22+
uint32_t x = row_idx >> NUM_BITS;
23+
uint32_t y = row_idx & ((1U << NUM_BITS) - 1);
24+
25+
Fp x_bits_array[NUM_BITS];
26+
Fp y_bits_array[NUM_BITS];
27+
#pragma unroll
28+
for (uint32_t i = 0; i < NUM_BITS; i++) {
29+
x_bits_array[i] = Fp((x >> i) & 1);
30+
y_bits_array[i] = Fp((y >> i) & 1);
31+
}
32+
33+
uint32_t mult_range_val = count[row_idx] + (cpu_count ? cpu_count[row_idx] : 0);
34+
uint32_t mult_xor_val = count[row_idx + num_rows] +
35+
(cpu_count ? cpu_count[row_idx + num_rows] : 0);
36+
37+
RowSlice row(trace + row_idx, num_rows);
38+
COL_WRITE_ARRAY(row, BitwiseOperationLookupCols, x_bits, x_bits_array);
39+
COL_WRITE_ARRAY(row, BitwiseOperationLookupCols, y_bits, y_bits_array);
40+
COL_WRITE_VALUE(row, BitwiseOperationLookupCols, mult_range, mult_range_val);
41+
COL_WRITE_VALUE(row, BitwiseOperationLookupCols, mult_xor, mult_xor_val);
1542
}
1643
}
1744

crates/circuits/primitives/src/bitwise_op_lookup/README.md

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,21 @@
22

33
XOR operation and range checking via lookup table
44

5-
This chip implements a lookup table approach for XOR operations and range checks for integers of size $`\texttt{NUM\_BITS}`$. The lookup table contains all possible combinations of $`x`$ and $`y`$ values (both in the range $`0..2^{\texttt{NUM\_BITS}}`$), along with their XOR result.
5+
This chip implements a lookup table approach for XOR operations and range checks for integers of size $`\texttt{NUM\_BITS}`$. The chip provides lookup table functionality for all possible combinations of $`x`$ and $`y`$ values (both in the range $`0..2^{\texttt{NUM\_BITS}}`$), enabling verification of XOR operations and range checks. In the trace, $x$ and $y$ are stored as binary decompositions (`x_bits` and `y_bits` arrays) rather than as full field elements.
66

7-
The lookup mechanism works through the Bus interface, with other circuits requesting lookups by incrementing multiplicity counters for the operations they need to perform. Each row in the lookup table corresponds to a specific $(x, y)$ pair.
7+
The lookup mechanism works through the Bus interface, with other circuits requesting lookups by incrementing multiplicity counters for the operations they need to perform. Each row in the trace corresponds to a specific $(x, y)$ pair.
88

9-
**Preprocessed Columns:**
10-
- `x`: Column containing the first input value ($0$ to $`2^{\texttt{NUM\_BITS}}-1`$)
11-
- `y`: Column containing the second input value ($0$ to $`2^{\texttt{NUM\_BITS}}-1`$)
12-
- `z_xor`: Column containing the XOR result of x and y ($x \oplus y$)
9+
Instead of relying on a preprocessed trace, the chip stores the bit decompositions in the main trace and uses gate-based constraints to enforce their contents. The trace enumerates all valid $(x, y)$ pairs in order: row $n$ corresponds to $(x, y)$ where $x = \lfloor n / 2^{\texttt{NUM\_BITS}} \rfloor$ and $y = n \bmod 2^{\texttt{NUM\_BITS}}$. The enumeration order is: $(0, 0)$, $(0, 1)$, ..., $(0, 2^{\texttt{NUM\_BITS}}-1)$, $(1, 0)$, $(1, 1)$, ..., up to $(2^{\texttt{NUM\_BITS}}-1, 2^{\texttt{NUM\_BITS}}-1)$.
1310

14-
**IO Columns:**
11+
**Columns:**
12+
- `x_bits[0..NUM_BITS-1]`: Binary decomposition of $x$ (where `x_bits[0]` is the least significant bit)
13+
- `y_bits[0..NUM_BITS-1]`: Binary decomposition of $y$ (where `y_bits[0]` is the least significant bit)
1514
- `mult_range`: Multiplicity column tracking the number of range check operations requested for each $(x, y)$ pair
1615
- `mult_xor`: Multiplicity column tracking the number of XOR operations requested for each $(x, y)$ pair
16+
17+
The constraints enforce the enumeration pattern and derive the values sent on the lookup bus by:
18+
1. Ensuring each bit is binary (0 or 1) using `assert_bool` constraints
19+
2. Reconstructing $x$ and $y$ from their binary decompositions: $x = \sum_{i=0}^{\texttt{NUM\_BITS}-1} \texttt{x\_bits}[i] \cdot 2^i$ and $y = \sum_{i=0}^{\texttt{NUM\_BITS}-1} \texttt{y\_bits}[i] \cdot 2^i$
20+
3. Computing $z_{\texttt{xor}} = x \oplus y$ algebraically from bits: $z_{\texttt{xor}} = \sum_{i=0}^{\texttt{NUM\_BITS}-1} (\texttt{x\_bits}[i] + \texttt{y\_bits}[i] - 2 \cdot \texttt{x\_bits}[i] \cdot \texttt{y\_bits}[i]) \cdot 2^i$
21+
4. Constraining that the combined index $(x \cdot 2^{\texttt{NUM\_BITS}} + y)$ increments by 1 each row using transition constraints
22+
5. Enforcing boundary conditions: first row has index 0, last row has index $2^{2 \cdot \texttt{NUM\_BITS}} - 1$

crates/circuits/primitives/src/bitwise_op_lookup/cuda.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use openvm_cuda_common::{copy::MemCopyH2D as _, d_buffer::DeviceBuffer};
55
use openvm_stark_backend::{prover::types::AirProvingContext, Chip};
66

77
use crate::{
8-
bitwise_op_lookup::{BitwiseOperationLookupChip, NUM_BITWISE_OP_LOOKUP_COLS},
8+
bitwise_op_lookup::{
9+
BitwiseOperationLookupChip, BitwiseOperationLookupCols, NUM_BITWISE_OP_LOOKUP_MULT_COLS,
10+
},
911
cuda_abi::bitwise_op_lookup::tracegen,
1012
};
1113

@@ -22,7 +24,7 @@ impl<const NUM_BITS: usize> BitwiseOperationLookupChipGPU<NUM_BITS> {
2224
pub fn new() -> Self {
2325
// The first 2^(2 * NUM_BITS) indices are for range checking, the rest are for XOR
2426
let count = Arc::new(DeviceBuffer::<F>::with_capacity(
25-
NUM_BITWISE_OP_LOOKUP_COLS * Self::num_rows(),
27+
NUM_BITWISE_OP_LOOKUP_MULT_COLS * Self::num_rows(),
2628
));
2729
count.fill_zero().unwrap();
2830
Self {
@@ -35,7 +37,7 @@ impl<const NUM_BITS: usize> BitwiseOperationLookupChipGPU<NUM_BITS> {
3537
assert_eq!(cpu_chip.count_range.len(), Self::num_rows());
3638
assert_eq!(cpu_chip.count_xor.len(), Self::num_rows());
3739
let count = Arc::new(DeviceBuffer::<F>::with_capacity(
38-
NUM_BITWISE_OP_LOOKUP_COLS * Self::num_rows(),
40+
NUM_BITWISE_OP_LOOKUP_MULT_COLS * Self::num_rows(),
3941
));
4042
count.fill_zero().unwrap();
4143
Self {
@@ -53,8 +55,9 @@ impl<const NUM_BITS: usize> Default for BitwiseOperationLookupChipGPU<NUM_BITS>
5355

5456
impl<RA, const NUM_BITS: usize> Chip<RA, GpuBackend> for BitwiseOperationLookupChipGPU<NUM_BITS> {
5557
fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<GpuBackend> {
58+
let num_cols = BitwiseOperationLookupCols::<F, NUM_BITS>::width();
5659
debug_assert_eq!(
57-
Self::num_rows() * NUM_BITWISE_OP_LOOKUP_COLS,
60+
NUM_BITWISE_OP_LOOKUP_MULT_COLS * Self::num_rows(),
5861
self.count.len()
5962
);
6063
let cpu_count = self.cpu_chip.as_ref().map(|cpu_chip| {
@@ -69,7 +72,7 @@ impl<RA, const NUM_BITS: usize> Chip<RA, GpuBackend> for BitwiseOperationLookupC
6972
});
7073
// ATTENTION: we create a new buffer to copy `count` into because this chip is stateful and
7174
// `count` will be reused.
72-
let trace = DeviceMatrix::<F>::with_capacity(Self::num_rows(), NUM_BITWISE_OP_LOOKUP_COLS);
75+
let trace = DeviceMatrix::<F>::with_capacity(Self::num_rows(), num_cols);
7376
unsafe {
7477
tracegen(&self.count, &cpu_count, trace.buffer(), NUM_BITS as u32).unwrap();
7578
}

crates/circuits/primitives/src/bitwise_op_lookup/mod.rs

Lines changed: 102 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
use std::{
22
borrow::{Borrow, BorrowMut},
3-
mem::size_of,
43
sync::{atomic::AtomicU32, Arc},
54
};
65

76
use openvm_circuit_primitives_derive::AlignedBorrow;
87
use openvm_stark_backend::{
98
config::{StarkGenericConfig, Val},
109
interaction::InteractionBuilder,
11-
p3_air::{Air, BaseAir, PairBuilder},
10+
p3_air::{Air, AirBuilder, BaseAir},
1211
p3_field::{Field, FieldAlgebra},
1312
p3_matrix::{dense::RowMajorMatrix, Matrix},
1413
prover::{cpu::CpuBackend, types::AirProvingContext},
@@ -27,27 +26,21 @@ pub use cuda::*;
2726
#[cfg(test)]
2827
mod tests;
2928

30-
#[derive(Default, AlignedBorrow, Copy, Clone)]
29+
#[derive(AlignedBorrow, Copy, Clone)]
3130
#[repr(C)]
32-
pub struct BitwiseOperationLookupCols<T> {
31+
pub struct BitwiseOperationLookupCols<T, const NUM_BITS: usize> {
32+
/// Binary decomposition of x (x_bits[0] is LSB, x_bits[NUM_BITS-1] is MSB)
33+
pub x_bits: [T; NUM_BITS],
34+
/// Binary decomposition of y (y_bits[0] is LSB, y_bits[NUM_BITS-1] is MSB)
35+
pub y_bits: [T; NUM_BITS],
3336
/// Number of range check operations requested for each (x, y) pair
3437
pub mult_range: T,
3538
/// Number of XOR operations requested for each (x, y) pair
3639
pub mult_xor: T,
3740
}
3841

39-
#[derive(Default, AlignedBorrow, Copy, Clone)]
40-
#[repr(C)]
41-
pub struct BitwiseOperationLookupPreprocessedCols<T> {
42-
pub x: T,
43-
pub y: T,
44-
/// XOR result of x and y (x ⊕ y)
45-
pub z_xor: T,
46-
}
47-
48-
pub const NUM_BITWISE_OP_LOOKUP_COLS: usize = size_of::<BitwiseOperationLookupCols<u8>>();
49-
pub const NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS: usize =
50-
size_of::<BitwiseOperationLookupPreprocessedCols<u8>>();
42+
/// Number of multiplicity columns (mult_range and mult_xor)
43+
pub const NUM_BITWISE_OP_LOOKUP_MULT_COLS: usize = 2;
5144

5245
#[derive(Clone, Copy, Debug, derive_new::new)]
5346
pub struct BitwiseOperationLookupAir<const NUM_BITS: usize> {
@@ -64,52 +57,92 @@ impl<F: Field, const NUM_BITS: usize> PartitionedBaseAir<F>
6457
}
6558
impl<F: Field, const NUM_BITS: usize> BaseAir<F> for BitwiseOperationLookupAir<NUM_BITS> {
6659
fn width(&self) -> usize {
67-
NUM_BITWISE_OP_LOOKUP_COLS
68-
}
69-
70-
fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
71-
let rows: Vec<F> = (0..(1 << NUM_BITS))
72-
.flat_map(|x: u32| {
73-
(0..(1 << NUM_BITS)).flat_map(move |y: u32| {
74-
[
75-
F::from_canonical_u32(x),
76-
F::from_canonical_u32(y),
77-
F::from_canonical_u32(x ^ y),
78-
]
79-
})
80-
})
81-
.collect();
82-
Some(RowMajorMatrix::new(
83-
rows,
84-
NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS,
85-
))
60+
BitwiseOperationLookupCols::<F, NUM_BITS>::width()
8661
}
8762
}
8863

89-
impl<AB: InteractionBuilder + PairBuilder, const NUM_BITS: usize> Air<AB>
64+
impl<AB: InteractionBuilder, const NUM_BITS: usize> Air<AB>
9065
for BitwiseOperationLookupAir<NUM_BITS>
9166
{
9267
fn eval(&self, builder: &mut AB) {
93-
let preprocessed = builder.preprocessed();
94-
let prep_local = preprocessed.row_slice(0);
95-
let prep_local: &BitwiseOperationLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
96-
9768
let main = builder.main();
98-
let local = main.row_slice(0);
99-
let local: &BitwiseOperationLookupCols<AB::Var> = (*local).borrow();
69+
let (local, next) = (main.row_slice(0), main.row_slice(1));
70+
let local: &BitwiseOperationLookupCols<AB::Var, NUM_BITS> = (*local).borrow();
71+
let next: &BitwiseOperationLookupCols<AB::Var, NUM_BITS> = (*next).borrow();
72+
73+
// 1. Binary constraints: ensure each bit is boolean
74+
for i in 0..NUM_BITS {
75+
builder.assert_bool(local.x_bits[i]);
76+
builder.assert_bool(local.y_bits[i]);
77+
}
78+
79+
// 2. Reconstruct x and y from their binary decompositions
80+
// x = Σ(x_bits[i] * 2^i), y = Σ(y_bits[i] * 2^i)
81+
let reconstruct = |bits: &[AB::Var; NUM_BITS]| {
82+
bits.iter()
83+
.enumerate()
84+
.fold(AB::Expr::ZERO, |acc, (i, &bit)| {
85+
acc + bit * AB::Expr::from_canonical_usize(1 << i)
86+
})
87+
};
88+
let x_reconstructed = reconstruct(&local.x_bits);
89+
let y_reconstructed = reconstruct(&local.y_bits);
90+
91+
// 3. Compute z_xor algebraically from bits
92+
// z_xor_bits[i] = x_bits[i] ^ y_bits[i] = x_bits[i] + y_bits[i] - 2 * x_bits[i] * y_bits[i]
93+
// z_xor = Σ(z_xor_bits[i] * 2^i)
94+
let z_xor_reconstructed = local
95+
.x_bits
96+
.iter()
97+
.zip(local.y_bits.iter())
98+
.enumerate()
99+
.fold(AB::Expr::ZERO, |acc, (i, (&x_bit, &y_bit))| {
100+
let xor_bit = x_bit + y_bit - AB::Expr::TWO * x_bit * y_bit;
101+
acc + xor_bit * AB::Expr::from_canonical_usize(1 << i)
102+
});
103+
104+
// 4. Combined index: idx = x * (2^NUM_BITS) + y
105+
let combined_idx = x_reconstructed.clone() * AB::Expr::from_canonical_usize(1 << NUM_BITS)
106+
+ y_reconstructed.clone();
107+
let next_combined_idx = reconstruct(&next.x_bits)
108+
* AB::Expr::from_canonical_usize(1 << NUM_BITS)
109+
+ reconstruct(&next.y_bits);
110+
111+
// 5. Constrain that combined index increments by 1 each row
112+
builder
113+
.when_transition()
114+
.assert_one(next_combined_idx.clone() - combined_idx.clone());
115+
116+
// 6. Boundary constraints: first row has idx = 0, last row has idx = 2^(2*NUM_BITS) - 1
117+
builder.when_first_row().assert_zero(combined_idx.clone());
118+
builder.when_last_row().assert_eq(
119+
combined_idx,
120+
AB::Expr::from_canonical_usize((1 << (2 * NUM_BITS)) - 1),
121+
);
100122

123+
// 7. Use reconstructed values for lookup bus interactions
101124
self.bus
102-
.receive(prep_local.x, prep_local.y, AB::F::ZERO, AB::F::ZERO)
125+
.receive(
126+
x_reconstructed.clone(),
127+
y_reconstructed.clone(),
128+
AB::F::ZERO,
129+
AB::F::ZERO,
130+
)
103131
.eval(builder, local.mult_range);
104132
self.bus
105-
.receive(prep_local.x, prep_local.y, prep_local.z_xor, AB::F::ONE)
133+
.receive(
134+
x_reconstructed,
135+
y_reconstructed,
136+
z_xor_reconstructed,
137+
AB::F::ONE,
138+
)
106139
.eval(builder, local.mult_xor);
107140
}
108141
}
109142

110-
// Lookup chip for operations on size NUM_BITS integers. Currently has pre-processed columns
111-
// for x ^ y and range check. Interactions are of form [x, y, z] where z is either x ^ y for
112-
// XOR or 0 for range check.
143+
// Lookup chip for operations on size NUM_BITS integers. Uses gate-based constraints
144+
// with binary decomposition instead of preprocessed trace. Interactions are of form [x, y, z]
145+
// where z is either x ^ y for XOR or 0 for range check.
113146

114147
pub struct BitwiseOperationLookupChip<const NUM_BITS: usize> {
115148
pub air: BitwiseOperationLookupAir<NUM_BITS>,
@@ -137,7 +170,7 @@ impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
137170
}
138171

139172
pub fn air_width(&self) -> usize {
140-
NUM_BITWISE_OP_LOOKUP_COLS
173+
BitwiseOperationLookupCols::<u8, NUM_BITS>::width()
141174
}
142175

143176
pub fn request_range(&self, x: u32, y: u32) {
@@ -164,17 +197,33 @@ impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
164197

165198
/// Generates trace and resets all internal counters to 0.
166199
pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
167-
let mut rows = F::zero_vec(self.count_range.len() * NUM_BITWISE_OP_LOOKUP_COLS);
168-
for (n, row) in rows.chunks_mut(NUM_BITWISE_OP_LOOKUP_COLS).enumerate() {
169-
let cols: &mut BitwiseOperationLookupCols<F> = row.borrow_mut();
200+
let num_cols = BitwiseOperationLookupCols::<F, NUM_BITS>::width();
201+
let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS);
202+
let mut rows = F::zero_vec(num_rows * num_cols);
203+
204+
for (n, row) in rows.chunks_mut(num_cols).enumerate() {
205+
let cols: &mut BitwiseOperationLookupCols<F, NUM_BITS> = row.borrow_mut();
206+
207+
// Compute x and y from row index: row n corresponds to (x, y) where
208+
// x = n / (2^NUM_BITS), y = n % (2^NUM_BITS)
209+
let x = (n / (1 << NUM_BITS)) as u32;
210+
let y = (n % (1 << NUM_BITS)) as u32;
211+
212+
// Set x_bits and y_bits: decompose x and y into binary
213+
for i in 0..NUM_BITS {
214+
cols.x_bits[i] = F::from_canonical_u32((x >> i) & 1);
215+
cols.y_bits[i] = F::from_canonical_u32((y >> i) & 1);
216+
}
217+
218+
// Set multiplicities
170219
cols.mult_range = F::from_canonical_u32(
171220
self.count_range[n].swap(0, std::sync::atomic::Ordering::SeqCst),
172221
);
173222
cols.mult_xor = F::from_canonical_u32(
174223
self.count_xor[n].swap(0, std::sync::atomic::Ordering::SeqCst),
175224
);
176225
}
177-
RowMajorMatrix::new(rows, NUM_BITWISE_OP_LOOKUP_COLS)
226+
RowMajorMatrix::new(rows, num_cols)
178227
}
179228

180229
fn idx(x: u32, y: u32) -> usize {
@@ -203,6 +252,6 @@ impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_B
203252
1 << (2 * NUM_BITS)
204253
}
205254
fn trace_width(&self) -> usize {
206-
NUM_BITWISE_OP_LOOKUP_COLS
255+
BitwiseOperationLookupCols::<u8, NUM_BITS>::width()
207256
}
208257
}

0 commit comments

Comments
 (0)