Skip to content

Commit 07b6a29

Browse files
committed
prototype for ALU chip
1 parent 60073d7 commit 07b6a29

File tree

6 files changed

+91
-47
lines changed

6 files changed

+91
-47
lines changed

crates/circuits/primitives/cuda/include/primitives/trace_access.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ struct RowSlice {
7171
/// Write a single value into `FIELD` of struct `STRUCT<T>` at a given row.
7272
#define COL_WRITE_VALUE(ROW, STRUCT, FIELD, VALUE) (ROW).write(COL_INDEX(STRUCT, FIELD), VALUE)
7373

74+
/// Write a single value into `FIELD` of struct `STRUCT<T>` at a given row.
75+
#define COL_WRITE_VALUE_APC(APC_ROW, STRUCT, FIELD, VALUE, SUB, OFFSET) {
76+
if SUB[COL_INDEX(STRUCT, FIELD) + OFFSET] != UINT32_MAX {
77+
(APC_ROW).write(
78+
COL_INDEX(STRUCT, FIELD) + OFFSET,
79+
VALUE
80+
);
81+
}
82+
}
83+
7484
/// Write an array of values into the fixed‐length `FIELD` array of `STRUCT<T>` for one row.
7585
#define COL_WRITE_ARRAY(ROW, STRUCT, FIELD, VALUES) \
7686
(ROW).write_array(COL_INDEX(STRUCT, FIELD), COL_ARRAY_LEN(STRUCT, FIELD), VALUES)

extensions/native/circuit/cuda/include/native/adapters/alu_native_adapter.cuh

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -36,50 +36,59 @@ struct AluNativeAdapter {
3636
__device__ AluNativeAdapter(VariableRangeChecker rc, uint32_t timestamp_max_bits)
3737
: mem_helper(rc, timestamp_max_bits) {}
3838

39-
__device__ void fill_trace_row(RowSlice row, AluNativeAdapterRecord const &rec) {
40-
COL_WRITE_VALUE(row, AluNativeAdapterCols, from_state.pc, Fp(rec.from_pc));
41-
COL_WRITE_VALUE(row, AluNativeAdapterCols, from_state.timestamp, Fp(rec.from_timestamp));
42-
COL_WRITE_VALUE(row, AluNativeAdapterCols, a_pointer, Fp::fromRaw(rec.a_ptr));
43-
COL_WRITE_VALUE(row, AluNativeAdapterCols, b_pointer, Fp::fromRaw(rec.b));
44-
COL_WRITE_VALUE(row, AluNativeAdapterCols, c_pointer, Fp::fromRaw(rec.c));
45-
46-
// Fill read auxiliary columns for two operands (b and c)
47-
const Fp native_as = Fp(AS_NATIVE);
48-
for (int i = 0; i < 2; i++) {
49-
const uint32_t prev_timestamp = rec.reads_aux[i].prev_timestamp;
50-
const uint32_t current_timestamp = rec.from_timestamp + i;
51-
RowSlice aux_slice = row.slice_from(COL_INDEX(AluNativeAdapterCols, reads_aux[i]));
52-
53-
if (prev_timestamp == UINT32_MAX) {
54-
// Immediate
55-
mem_helper.fill(aux_slice, 0, current_timestamp);
56-
COL_WRITE_VALUE(row, AluNativeAdapterCols, reads_aux[i].is_zero_aux, Fp::zero());
57-
COL_WRITE_VALUE(row, AluNativeAdapterCols, reads_aux[i].is_immediate, Fp::one());
58-
if (i == 0) {
59-
COL_WRITE_VALUE(row, AluNativeAdapterCols, e_as, Fp(AS_IMMEDIATE));
39+
__device__ void fill_trace_row(RowSlice row, AluNativeAdapterRecord const &rec, RowSlice apc_row, uint32_t *sub, uint32_t offset) {
40+
if !apc_row.is_null() {
41+
COL_WRITE_VALUE_APC(apc_row, AluNativeAdapterCols, from_state.timestamp, Fp(rec.from_timestamp), sub, offset);
42+
COL_WRITE_VALUE_APC(row, AluNativeAdapterCols, a_pointer, Fp::fromRaw(rec.a_ptr), sub, offset);
43+
COL_WRITE_VALUE_APC(row, AluNativeAdapterCols, b_pointer, Fp::fromRaw(rec.b), sub, offset);
44+
COL_WRITE_VALUE_APC(row, AluNativeAdapterCols, c_pointer, Fp::fromRaw(rec.c), sub, offset);
45+
46+
// TODO: adapt the rest similar to above
47+
} else {
48+
COL_WRITE_VALUE(row, AluNativeAdapterCols, from_state.timestamp, Fp(rec.from_timestamp));
49+
COL_WRITE_VALUE(row, AluNativeAdapterCols, a_pointer, Fp::fromRaw(rec.a_ptr));
50+
COL_WRITE_VALUE(row, AluNativeAdapterCols, b_pointer, Fp::fromRaw(rec.b));
51+
COL_WRITE_VALUE(row, AluNativeAdapterCols, c_pointer, Fp::fromRaw(rec.c));
52+
53+
// Fill read auxiliary columns for two operands (b and c)
54+
const Fp native_as = Fp(AS_NATIVE);
55+
for (int i = 0; i < 2; i++) {
56+
const uint32_t prev_timestamp = rec.reads_aux[i].prev_timestamp;
57+
const uint32_t current_timestamp = rec.from_timestamp + i;
58+
RowSlice aux_slice = row.slice_from(COL_INDEX(AluNativeAdapterCols, reads_aux[i]));
59+
60+
if (prev_timestamp == UINT32_MAX) {
61+
// Immediate
62+
mem_helper.fill(aux_slice, 0, current_timestamp);
63+
COL_WRITE_VALUE(row, AluNativeAdapterCols, reads_aux[i].is_zero_aux, Fp::zero());
64+
COL_WRITE_VALUE(row, AluNativeAdapterCols, reads_aux[i].is_immediate, Fp::one());
65+
if (i == 0) {
66+
COL_WRITE_VALUE(row, AluNativeAdapterCols, e_as, Fp(AS_IMMEDIATE));
67+
} else {
68+
COL_WRITE_VALUE(row, AluNativeAdapterCols, f_as, Fp(AS_IMMEDIATE));
69+
}
6070
} else {
61-
COL_WRITE_VALUE(row, AluNativeAdapterCols, f_as, Fp(AS_IMMEDIATE));
62-
}
63-
} else {
64-
// Memory
65-
mem_helper.fill(aux_slice, prev_timestamp, current_timestamp);
66-
COL_WRITE_VALUE(
67-
row, AluNativeAdapterCols, reads_aux[i].is_zero_aux, inv(native_as)
68-
);
69-
COL_WRITE_VALUE(row, AluNativeAdapterCols, reads_aux[i].is_immediate, Fp::zero());
70-
if (i == 0) {
71-
COL_WRITE_VALUE(row, AluNativeAdapterCols, e_as, native_as);
72-
} else {
73-
COL_WRITE_VALUE(row, AluNativeAdapterCols, f_as, native_as);
71+
// Memory
72+
mem_helper.fill(aux_slice, prev_timestamp, current_timestamp);
73+
COL_WRITE_VALUE(
74+
row, AluNativeAdapterCols, reads_aux[i].is_zero_aux, inv(native_as)
75+
);
76+
COL_WRITE_VALUE(row, AluNativeAdapterCols, reads_aux[i].is_immediate, Fp::zero());
77+
if (i == 0) {
78+
COL_WRITE_VALUE(row, AluNativeAdapterCols, e_as, native_as);
79+
} else {
80+
COL_WRITE_VALUE(row, AluNativeAdapterCols, f_as, native_as);
81+
}
7482
}
7583
}
84+
85+
COL_WRITE_ARRAY(row, AluNativeAdapterCols, write_aux.prev_data, rec.write_aux.prev_data);
86+
mem_helper.fill(
87+
row.slice_from(COL_INDEX(AluNativeAdapterCols, write_aux)),
88+
rec.write_aux.prev_timestamp,
89+
rec.from_timestamp + 2
90+
);
7691
}
77-
78-
COL_WRITE_ARRAY(row, AluNativeAdapterCols, write_aux.prev_data, rec.write_aux.prev_data);
79-
mem_helper.fill(
80-
row.slice_from(COL_INDEX(AluNativeAdapterCols, write_aux)),
81-
rec.write_aux.prev_timestamp,
82-
rec.from_timestamp + 2
83-
);
92+
8493
}
8594
};

extensions/rv32im/circuit/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ test-case.workspace = true
4040
openvm-cuda-builder = { workspace = true, optional = true }
4141

4242
[features]
43-
default = ["parallel", "jemalloc"]
43+
default = ["parallel", "jemalloc", "cuda"]
4444
parallel = ["openvm-circuit/parallel"]
4545
test-utils = ["openvm-circuit/test-utils", "dep:openvm-stark-sdk"]
4646
tco = ["openvm-circuit/tco"]

extensions/rv32im/circuit/cuda/src/alu.cu

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,27 @@ __global__ void alu_tracegen(
3131
uint32_t *d_bitwise_lookup_ptr,
3232
size_t bitwise_num_bits,
3333
uint32_t timestamp_max_bits
34+
Fp *d_apc_trace,
35+
uint32_t *subs, // same length as dummy width
36+
size_t width, // dummy width
37+
uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
3438
) {
3539
uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
3640
RowSlice row(d_trace + idx, height);
3741
if (idx < d_records.len()) {
3842
auto const &rec = d_records[idx];
43+
RowSlice apc_row(d_apc_trace + apc_row_index[idx], height);
44+
auto const sub = subs[idx * width]; // offset the subs to the corresponding dummy row
3945

4046
Rv32BaseAluAdapter adapter(
4147
VariableRangeChecker(d_range_checker_ptr, range_checker_bins),
4248
BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits),
4349
timestamp_max_bits
4450
);
45-
adapter.fill_trace_row(row, rec.adapter);
51+
adapter.fill_trace_row(row, rec.adapter, apc_row, sub, 0); // sub offset is 0
4652

4753
Rv32BaseAluCore core(BitwiseOperationLookup(d_bitwise_lookup_ptr, bitwise_num_bits));
48-
core.fill_trace_row(row.slice_from(COL_INDEX(Rv32BaseAluCols, core)), rec.core);
54+
core.fill_trace_row(row.slice_from(COL_INDEX(Rv32BaseAluCols, core)), rec.core, apc_row, sub, COL_INDEX(Rv32BaseAluCols, core)); // has sub offset
4955
} else {
5056
row.fill_zero(0, sizeof(Rv32BaseAluCols<uint8_t>));
5157
}
@@ -60,7 +66,10 @@ extern "C" int _alu_tracegen(
6066
size_t range_checker_bins,
6167
uint32_t *d_bitwise_lookup_ptr,
6268
size_t bitwise_num_bits,
63-
uint32_t timestamp_max_bits
69+
uint32_t timestamp_max_bits,
70+
Fp *d_apc_trace,
71+
uint32_t *subs, // same length as dummy width
72+
uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
6473
) {
6574
assert((height & (height - 1)) == 0);
6675
assert(height >= d_records.len());
@@ -74,7 +83,11 @@ extern "C" int _alu_tracegen(
7483
range_checker_bins,
7584
d_bitwise_lookup_ptr,
7685
bitwise_num_bits,
77-
timestamp_max_bits
86+
timestamp_max_bits,
87+
Fp *d_apc_trace,
88+
uint32_t *subs, // same length as dummy width
89+
size_t width, // dummy width
90+
uint32_t *apc_row_index, // dummy row mapping to apc row same length as d_records
7891
);
7992
return CHECK_KERNEL();
8093
}

extensions/rv32im/circuit/src/base_alu/cuda.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub struct Rv32BaseAluChipGpu {
2727
}
2828

2929
impl Chip<DenseRecordArena, GpuBackend> for Rv32BaseAluChipGpu {
30-
fn generate_proving_ctx(&self, arena: DenseRecordArena) -> AirProvingContext<GpuBackend> {
30+
fn generate_proving_ctx(&self, arena: DenseRecordArena, d_apc_trace: DeviceMatrix<F>, subs: Option<Vec<Vec<u32>>>, apc_row_index: Option<Vec<u32>>) -> AirProvingContext<GpuBackend> {
3131
const RECORD_SIZE: usize = size_of::<(
3232
Rv32BaseAluAdapterRecord,
3333
BaseAluCoreRecord<RV32_REGISTER_NUM_LIMBS>,
@@ -55,6 +55,9 @@ impl Chip<DenseRecordArena, GpuBackend> for Rv32BaseAluChipGpu {
5555
&self.bitwise_lookup.count,
5656
RV32_CELL_BITS,
5757
self.timestamp_max_bits as u32,
58+
d_apc_trace.buffer(),
59+
subs, // same length as dummy width
60+
apc_row_index, // dummy row mapping to apc row same length as d_records
5861
)
5962
.unwrap();
6063
}

extensions/rv32im/circuit/src/cuda_abi.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ pub mod alu_cuda {
322322
d_bitwise_lookup: *mut u32,
323323
bitwise_num_bits: usize,
324324
timestamp_max_bits: u32,
325+
d_apc_trace: *mut F,
326+
subs: *mut u32,
327+
apc_row_index: *mut u32,
325328
) -> i32;
326329
}
327330

@@ -334,6 +337,9 @@ pub mod alu_cuda {
334337
d_bitwise_lookup: &DeviceBuffer<F>,
335338
bitwise_num_bits: usize,
336339
timestamp_max_bits: u32,
340+
d_apc_trace: &DeviceBuffer<F>,
341+
subs: Option<Vec<Vec<u32>>>,
342+
apc_row_index: Option<Vec<u32>>,
337343
) -> Result<(), CudaError> {
338344
let width = d_trace.len() / height;
339345
CudaError::from_result(_alu_tracegen(
@@ -346,6 +352,9 @@ pub mod alu_cuda {
346352
d_bitwise_lookup.as_mut_ptr() as *mut u32,
347353
bitwise_num_bits,
348354
timestamp_max_bits,
355+
d_apc_trace.as_mut_ptr(),
356+
subs,
357+
apc_row_index,
349358
))
350359
}
351360
}

0 commit comments

Comments
 (0)