Skip to content

Commit 2414f13

Browse files
slumbersjudson
authored andcommitted
omnibus prover2 components fixes (#713)
* omnibus prover components fixes * ram init final timestamp size
1 parent e46740f commit 2414f13

File tree

10 files changed

+96
-77
lines changed

10 files changed

+96
-77
lines changed

prover2/machine/src/components/execution/branch_eq/mod.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,10 @@ impl<T: BranchOp> BuiltInComponent for BranchEq<T> {
262262
// fill padding
263263
for row_idx in num_steps..1 << log_size {
264264
common_trace.fill_columns(row_idx, true, Column::IsLocalPad);
265+
common_trace.fill_columns(row_idx, 1u8, Column::HNeq12FlagAux);
266+
common_trace.fill_columns(row_idx, 1u8, Column::HNeq12FlagAuxInv);
267+
common_trace.fill_columns(row_idx, 1u8, Column::HNeq34FlagAux);
268+
common_trace.fill_columns(row_idx, 1u8, Column::HNeq34FlagAuxInv);
265269
if T::OPCODE == BuiltinOpcode::BNE {
266270
// (1 - h-neq-flag) * 4 term in pc-next constraint is non-zero on padding
267271
common_trace.fill_columns(row_idx, [4u16, 0], Column::PcNext);
@@ -310,7 +314,6 @@ impl<T: BranchOp> BuiltInComponent for BranchEq<T> {
310314
}
311315
.constrain(eval, &trace_eval);
312316

313-
let [is_local_pad] = trace_eval!(trace_eval, Column::IsLocalPad);
314317
let pc = trace_eval!(trace_eval, Column::Pc);
315318
let pc_next = trace_eval!(trace_eval, Column::PcNext);
316319

@@ -362,15 +365,9 @@ impl<T: BranchOp> BuiltInComponent for BranchEq<T> {
362365
// enforcing h-neq12-flag-aux != 0, h-neq34-flag-aux != 0
363366
//
364367
// (1 − is-local-pad) · (h-neq12-flag-aux · h-neq12-flag-aux-inv − 1) = 0
365-
eval.add_constraint(
366-
(E::F::one() - is_local_pad.clone())
367-
* (h_neq12_flag_aux.clone() * h_neq12_flag_aux_inv.clone() - E::F::one()),
368-
);
368+
eval.add_constraint(h_neq12_flag_aux.clone() * h_neq12_flag_aux_inv.clone() - E::F::one());
369369
// (1 − is-local-pad) · (h-neq34-flag-aux · h-neq34-flag-aux-inv − 1) = 0
370-
eval.add_constraint(
371-
(E::F::one() - is_local_pad.clone())
372-
* (h_neq34_flag_aux.clone() * h_neq34_flag_aux_inv.clone() - E::F::one()),
373-
);
370+
eval.add_constraint(h_neq34_flag_aux.clone() * h_neq34_flag_aux_inv.clone() - E::F::one());
374371

375372
// (1 − is-local-pad) · (
376373
// (1 − h-neq12-flag) · (1 − h-neq34-flag)

prover2/machine/src/components/execution/decoding/type_b.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ impl CVal {
7070
let [op_c12] = trace_eval.column_eval(DecodingColumn::OpC12);
7171

7272
[
73-
op_c1_4 * BaseField::from(1 << 1) + op_c5_7 * BaseField::from(1 << 4),
73+
op_c1_4 * BaseField::from(1 << 1) + op_c5_7 * BaseField::from(1 << 5),
7474
op_c8_10
7575
+ op_c11 * BaseField::from(1 << 3)
7676
+ op_c12.clone() * BaseField::from(((1 << 4) - 1) * (1 << 4)),

prover2/machine/src/components/execution/jalr/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ impl BuiltInComponent for Jalr {
396396
+ c_val[3].clone() * BaseField::from(1 << 8)
397397
+ b_val[2].clone()
398398
+ b_val[3].clone() * BaseField::from(1 << 8)
399-
+ pc_carry_1.clone()
399+
+ pc_carry_2.clone()
400400
- pc_carry_4.clone() * BaseField::from(1 << 8).pow(2)
401401
- pc_next_high.clone()),
402402
);

prover2/machine/src/components/execution/load/mod.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,12 @@ impl<T: LoadOp> BuiltInComponent for Load<T> {
287287
}
288288
.constrain(eval, &trace_eval);
289289

290-
// (1 − is-local-pad) *
291-
// (h_ram_base_addr(1) + h_ram_base_addr(2) * 2^8 − b-val(1) − b-val(2) * 2^8 − c-val(1) − c-val(2) * 2^8 + h_carry(1) * 2^16) = 0
290+
// (1 − is-local-pad) · (
291+
// h-ram-base-addr(1) + h-ram-base-addr(2) · 2^8
292+
// − b-val(1) − b-val(2) · 2^8
293+
// − c-val(1) − c-val(2) · 2^8
294+
// + h-carry(1) · 2^16
295+
// ) = 0
292296
eval.add_constraint(
293297
(E::F::one() - is_local_pad.clone())
294298
* (h_ram_base_addr[0].clone()
@@ -299,12 +303,18 @@ impl<T: LoadOp> BuiltInComponent for Load<T> {
299303
- c_val[1].clone() * BaseField::from(1 << 8)
300304
+ h_carry[0].clone() * BaseField::from(1 << 16)),
301305
);
302-
// (1 − is-local-pad) *
303-
// (h_ram_base_addr(3) + h_ram_base_addr(4) * 2^8 − b-val(3) − b-val(4) * 2^8 − c-val(3) − c-val(4) * 2^8 + h_carry(2) * 2^16) = 0
306+
// (1 − is-local-pad) · (
307+
// h-ram-base-addr(3) + h-ram-base-addr(4) · 2^8
308+
// − h-carry(1)
309+
// − b-val(3) − b-val(4) · 2^8
310+
// − c-val(3) − c-val(4) · 2^8
311+
// + h-carry(2) · 2^16
312+
// ) = 0
304313
eval.add_constraint(
305314
(E::F::one() - is_local_pad.clone())
306315
* (h_ram_base_addr[2].clone()
307316
+ h_ram_base_addr[3].clone() * BaseField::from(1 << 8)
317+
- h_carry[0].clone()
308318
- b_val[2].clone()
309319
- b_val[3].clone() * BaseField::from(1 << 8)
310320
- c_val[2].clone()

prover2/machine/src/components/execution/store/columns.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,28 +93,6 @@ pub const OP_B: RegSplitAt4<Column> = RegSplitAt4 {
9393
bits_0_3: Column::OpB0_3,
9494
bit_4: Column::OpB4,
9595
};
96-
pub const OP_C: OpC = OpC;
97-
98-
pub struct OpC;
99-
100-
impl OpC {
101-
pub fn eval<E: EvalAtRow>(
102-
&self,
103-
trace_eval: &TraceEval<PreprocessedColumn, Column, E>,
104-
) -> E::F {
105-
let [op_c0] = trace_eval.column_eval(Column::OpC0);
106-
let [op_c1_4] = trace_eval.column_eval(Column::OpC1_4);
107-
let [op_c5_7] = trace_eval.column_eval(Column::OpC5_7);
108-
let [op_c8_10] = trace_eval.column_eval(Column::OpC8_10);
109-
let [op_c11] = trace_eval.column_eval(Column::OpC11);
110-
111-
op_c0.clone()
112-
+ op_c1_4.clone() * BaseField::from(2)
113-
+ op_c5_7.clone() * BaseField::from(1 << 5)
114-
+ op_c8_10.clone() * BaseField::from(1 << 8)
115-
+ op_c11.clone() * BaseField::from(1 << 11)
116-
}
117-
}
11896

11997
pub struct InstrVal {
12098
opcode: u8,

prover2/machine/src/components/execution/store/mod.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ impl<T: StoreOp> BuiltInComponent for Store<T> {
214214
n => b_val[..n].into(),
215215
};
216216
ram_values.resize(WORD_SIZE, BaseField::zero().into());
217-
ram_values.resize(WORD_SIZE, BaseField::zero().into());
218217
// provide(
219218
// rel-inst-to-ram,
220219
// 1 − is-local-pad,
@@ -288,8 +287,12 @@ impl<T: StoreOp> BuiltInComponent for Store<T> {
288287
}
289288
.constrain(eval, &trace_eval);
290289

291-
// (1 − is-local-pad) *
292-
// (h_ram_base_addr(1) + h_ram_base_addr(2) * 2^8 − a-val(1) − a-val(2) * 2^8 − c-val(1) − c-val(2) * 2^8 + h_carry(1) * 2^16) = 0
290+
// (1 − is-local-pad) · (
291+
// h-ram-base-addr(1) + h-ram-base-addr(2) · 2^8
292+
// − a-val(1) − a-val(2) · 2^8
293+
// − c-val(1) − c-val(2) · 2^8
294+
// + h-carry(1) · 2^16
295+
// ) = 0
293296
eval.add_constraint(
294297
(E::F::one() - is_local_pad.clone())
295298
* (h_ram_base_addr[0].clone()
@@ -300,12 +303,18 @@ impl<T: StoreOp> BuiltInComponent for Store<T> {
300303
- c_val[1].clone() * BaseField::from(1 << 8)
301304
+ h_carry[0].clone() * BaseField::from(1 << 16)),
302305
);
303-
// (1 − is-local-pad) *
304-
// (h_ram_base_addr(3) + h_ram_base_addr(4) * 2^8 − a-val(3) − a-val(4) * 2^8 − c-val(3) − c-val(4) * 2^8 + h_carry(2) * 2^16) = 0
306+
// (1 − is-local-pad) · (
307+
// h-ram-base-addr(3) + h-ram-base-addr(4) · 2^8
308+
// − h-carry(1)
309+
// − a-val(3) − a-val(4) · 2^8
310+
// − c-val(3) − c-val(4) · 2^8
311+
// + h-carry(2) · 2^16
312+
// ) = 0
305313
eval.add_constraint(
306314
(E::F::one() - is_local_pad.clone())
307315
* (h_ram_base_addr[2].clone()
308316
+ h_ram_base_addr[3].clone() * BaseField::from(1 << 8)
317+
- h_carry[0].clone()
309318
- a_val[2].clone()
310319
- a_val[3].clone() * BaseField::from(1 << 8)
311320
- c_val[2].clone()
@@ -337,7 +346,7 @@ impl<T: StoreOp> BuiltInComponent for Store<T> {
337346
columns::InstrVal::new(T::OPCODE.raw(), T::OPCODE.fn3().value()).eval(&trace_eval);
338347
let op_a = columns::OP_A.eval(&trace_eval);
339348
let op_b = columns::OP_B.eval(&trace_eval);
340-
let op_c = columns::OP_C.eval(&trace_eval);
349+
let op_c = E::F::zero();
341350

342351
let ram2_accessed = E::F::from(BaseField::from(T::RAM2_ACCESSED as u32));
343352
let ram3_4accessed = E::F::from(BaseField::from(T::RAM3_4ACCESSED as u32));

prover2/machine/src/components/read_write_memory_boundary/columns.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub enum Column {
3333
#[size = 1]
3434
RamValInit,
3535
/// The timestamp associated with the last access to address ram-init-final-addr
36-
#[size = 4]
36+
#[size = 2]
3737
RamTsFinal,
3838
/// A flag indicating whether ram-final, ram-init columns on the current row are being used
3939
#[size = 1]

prover2/machine/src/components/read_write_memory_boundary/mod.rs

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
33
use num_traits::Zero;
44
use stwo_prover::{
5-
constraint_framework::{EvalAtRow, RelationEntry},
5+
constraint_framework::{logup::LogupTraceGenerator, EvalAtRow, Relation, RelationEntry},
66
core::{
7-
backend::simd::{column::BaseColumn, m31::LOG_N_LANES, SimdBackend},
7+
backend::simd::{
8+
column::BaseColumn,
9+
m31::{PackedBaseField, LOG_N_LANES},
10+
SimdBackend,
11+
},
812
fields::{
913
m31::{self, BaseField},
1014
qm31::SecureField,
@@ -14,6 +18,7 @@ use stwo_prover::{
1418
},
1519
};
1620

21+
use nexus_common::constants::WORD_SIZE_HALVED;
1722
use nexus_vm::{
1823
emulator::{MemoryInitializationEntry, PublicOutputEntry},
1924
WORD_SIZE,
@@ -26,8 +31,9 @@ use nexus_vm_prover_trace::{
2631
};
2732

2833
use crate::{
34+
components::utils::u32_to_16bit_parts_le,
2935
framework::BuiltInComponent,
30-
lookups::{AllLookupElements, LogupTraceBuilder, RamReadWriteLookupElements},
36+
lookups::{AllLookupElements, RamReadWriteLookupElements},
3137
side_note::{program::ProgramTraceRef, SideNote},
3238
};
3339

@@ -115,7 +121,8 @@ impl BuiltInComponent for ReadWriteMemoryBoundary {
115121
trace.fill_columns(row_idx, true, Column::RamInitFinalFlag);
116122
assert!(*last_access < m31::P, "Access counter overflow");
117123

118-
trace.fill_columns(row_idx, *last_access, Column::RamTsFinal);
124+
let ts_final = u32_to_16bit_parts_le(*last_access);
125+
trace.fill_columns(row_idx, ts_final, Column::RamTsFinal);
119126
trace.fill_columns(row_idx, *last_value, Column::RamValFinal);
120127
}
121128

@@ -131,41 +138,55 @@ impl BuiltInComponent for ReadWriteMemoryBoundary {
131138
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
132139
SecureField,
133140
) {
141+
// TODO: support non-optimized logup trace generation
142+
134143
let rel_ram_read_write: &Self::LookupElements = lookup_elements.as_ref();
135-
let mut logup_trace_builder = LogupTraceBuilder::new(component_trace.log_size());
144+
let log_size = component_trace.log_size();
136145

137146
let [ram_init_final_flag] =
138147
original_base_column!(component_trace, Column::RamInitFinalFlag);
139148
let ram_init_final_addr = original_base_column!(component_trace, Column::RamInitFinalAddr);
140149
let [ram_val_final] = original_base_column!(component_trace, Column::RamValFinal);
141150
let ram_ts_final = original_base_column!(component_trace, Column::RamTsFinal);
151+
let ram_val_init = ReadWriteMemoryBoundary::combine_ram_val_init(&component_trace);
142152

143-
// consume(rel-ram-read-write, ram-init-final-flag, (ram-init-final-addr, ram-val-final, ram-ts-final))
144-
logup_trace_builder.add_to_relation_with(
145-
rel_ram_read_write,
146-
[ram_init_final_flag.clone()],
147-
|[ram_init_final_flag]| (-ram_init_final_flag).into(),
148-
&[
149-
ram_init_final_addr.as_slice(),
150-
std::slice::from_ref(&ram_val_final),
151-
&ram_ts_final,
152-
]
153-
.concat(),
154-
);
153+
let mut logup_trace_gen = LogupTraceGenerator::new(log_size);
155154

156-
let ram_val_init = ReadWriteMemoryBoundary::combine_ram_val_init(&component_trace);
155+
let final_values = [
156+
ram_init_final_addr.as_slice(),
157+
std::slice::from_ref(&ram_val_final),
158+
&ram_ts_final,
159+
]
160+
.concat();
161+
// consume(rel-ram-read-write, ram-init-final-flag, (ram-init-final-addr, ram-val-final, ram-ts-final))
162+
let mut logup_col_gen = logup_trace_gen.new_col();
163+
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
164+
let tuple: Vec<PackedBaseField> =
165+
final_values.iter().map(|col| col.at(vec_row)).collect();
166+
let denom = rel_ram_read_write.combine(&tuple);
167+
let numerator = ram_init_final_flag.at(vec_row);
168+
logup_col_gen.write_frac(vec_row, (-numerator).into(), denom);
169+
}
170+
logup_col_gen.finalize_col();
171+
172+
let init_values = [
173+
ram_init_final_addr.as_slice(),
174+
std::slice::from_ref(&ram_val_init),
175+
vec![BaseField::zero().into(); WORD_SIZE].as_slice(),
176+
]
177+
.concat();
157178
// provide(rel-ram-read-write, ram-init-final-flag, (ram-init-final-addr, ram-val-init, 0))
158-
logup_trace_builder.add_to_relation(
159-
rel_ram_read_write,
160-
ram_init_final_flag,
161-
&[
162-
ram_init_final_addr.as_slice(),
163-
std::slice::from_ref(&ram_val_init),
164-
vec![BaseField::zero().into(); WORD_SIZE].as_slice(),
165-
]
166-
.concat(),
167-
);
168-
logup_trace_builder.finalize()
179+
let mut logup_col_gen = logup_trace_gen.new_col();
180+
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
181+
let tuple: Vec<PackedBaseField> =
182+
init_values.iter().map(|col| col.at(vec_row)).collect();
183+
let denom = rel_ram_read_write.combine(&tuple);
184+
let numerator = ram_init_final_flag.at(vec_row);
185+
logup_col_gen.write_frac(vec_row, numerator.into(), denom);
186+
}
187+
logup_col_gen.finalize_col();
188+
189+
logup_trace_gen.finalize_last()
169190
}
170191

171192
fn add_constraints<E: EvalAtRow>(
@@ -225,12 +246,13 @@ impl BuiltInComponent for ReadWriteMemoryBoundary {
225246
&[
226247
ram_init_final_addr.as_slice(),
227248
std::slice::from_ref(&ram_val_init),
228-
vec![E::F::zero(); WORD_SIZE].as_slice(),
249+
vec![E::F::zero(); WORD_SIZE_HALVED].as_slice(),
229250
]
230251
.concat(),
231252
));
232253

233-
eval.finalize_logup_in_pairs();
254+
// avoid in-pairs optimizations to keep low degree
255+
eval.finalize_logup();
234256
}
235257
}
236258

prover2/machine/src/prove.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,6 @@ mod tests {
154154
k_trace_direct(&basic_block, 1).expect("error generating trace");
155155

156156
let proof = prove(&program_trace, &view).unwrap();
157-
verify(proof).unwrap();
157+
verify(proof, &[]).unwrap();
158158
}
159159
}

prover2/machine/src/verify.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use nexus_vm_prover_trace::eval::{
1818
use super::{Proof, BASE_COMPONENTS};
1919
use crate::lookups::AllLookupElements;
2020

21-
pub fn verify(proof: Proof) -> Result<(), VerificationError> {
21+
pub fn verify(proof: Proof, ad: &[u8]) -> Result<(), VerificationError> {
2222
let components = BASE_COMPONENTS;
2323
let Proof {
2424
stark_proof: proof,
@@ -44,6 +44,9 @@ pub fn verify(proof: Proof) -> Result<(), VerificationError> {
4444

4545
let config = PcsConfig::default();
4646
let verifier_channel = &mut Blake2sChannel::default();
47+
for &byte in ad {
48+
verifier_channel.mix_u64(byte.into());
49+
}
4750

4851
claimed_log_sizes.iter().for_each(|log_size| {
4952
verifier_channel.mix_u64(*log_size as u64);

0 commit comments

Comments
 (0)