Skip to content

Commit 23390d7

Browse files
authored
fix: program termination is enforced in AIR (#5)
1 parent a06e460 commit 23390d7

File tree

16 files changed

+271
-36
lines changed

16 files changed

+271
-36
lines changed

crates/core/executor/src/executor.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ impl Executor {
7575
self.record.cpu_memory_access.push(event);
7676
}
7777

78+
self.record.public_values.start_pc = self.record.cpu_events[0].pc;
79+
self.record.public_values.next_pc = self.record.cpu_events.last().unwrap().next_pc;
80+
7881
Ok(())
7982
}
8083

crates/core/executor/src/record.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use std::sync::Arc;
22

33
use hashbrown::HashMap;
4+
use p3_field::FieldAlgebra;
45
use serde::{Deserialize, Serialize};
56

6-
use bf_stark::MachineRecord;
7+
use bf_stark::{MachineRecord, PublicValues};
78

89
use crate::events::*;
910
use crate::program::Program;
@@ -31,6 +32,8 @@ pub struct ExecutionRecord {
3132
pub cpu_memory_access: Vec<MemoryEvent>,
3233
/// A trace of the byte lookups that are needed.
3334
pub byte_lookups: HashMap<ByteLookupEvent, usize>,
35+
/// The public values.
36+
pub public_values: PublicValues<u32>,
3437
}
3538

3639
/// A memory access record.
@@ -85,4 +88,9 @@ impl MachineRecord for ExecutionRecord {
8588

8689
self.cpu_memory_access.append(&mut other.cpu_memory_access);
8790
}
91+
92+
/// Retrieves the public values. This method is needed for the `MachineRecord` trait, since
93+
fn public_values<F: FieldAlgebra>(&self) -> Vec<F> {
94+
self.public_values.to_vec()
95+
}
8896
}

crates/core/machine/src/cpu/air.rs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
use core::borrow::Borrow;
2-
use p3_air::{Air, AirBuilder, BaseAir};
2+
use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir};
33
use p3_field::FieldAlgebra;
44
use p3_matrix::Matrix;
55

6-
use bf_stark::air::{BaseAirBuilder, BfAirBuilder};
6+
use bf_stark::{
7+
air::{BaseAirBuilder, BfAirBuilder},
8+
PublicValues, PROOF_NUM_PV_ELTS,
9+
};
710

811
use crate::{
912
air::{BfCoreAirBuilder, MemoryAirBuilder, U8AirBuilder},
@@ -21,7 +24,7 @@ impl<F> BaseAir<F> for CpuChip {
2124

2225
impl<AB> Air<AB> for CpuChip
2326
where
24-
AB: BfCoreAirBuilder,
27+
AB: BfCoreAirBuilder + AirBuilderWithPublicValues,
2528
AB::Var: Sized,
2629
{
2730
#[inline(never)]
@@ -31,6 +34,10 @@ where
3134
let local: &CpuCols<AB::Var> = (*local).borrow();
3235
let next: &CpuCols<AB::Var> = (*next).borrow();
3336

37+
let public_values_slice: [AB::PublicVar; PROOF_NUM_PV_ELTS] =
38+
core::array::from_fn(|i| builder.public_values()[i]);
39+
let public_values: &PublicValues<AB::PublicVar> = public_values_slice.as_slice().borrow();
40+
3441
let clk =
3542
AB::Expr::from_canonical_u32(1u32 << 16) * local.clk_8bit_limb + local.clk_16bit_limb;
3643

@@ -47,7 +54,7 @@ where
4754
self.eval_clk(builder, local, next, clk.clone());
4855

4956
// Check that the pc is updated correctly.
50-
self.eval_pc(builder, local, next);
57+
self.eval_pc(builder, local, next, public_values);
5158

5259
// Check that the is_real flag is correct.
5360
self.eval_is_real(builder, local, next);
@@ -128,6 +135,7 @@ impl CpuChip {
128135
builder: &mut AB,
129136
local: &CpuCols<AB::Var>,
130137
next: &CpuCols<AB::Var>,
138+
public_values: &PublicValues<AB::PublicVar>,
131139
) {
132140
builder.when_transition().when(next.is_real).assert_eq(local.next_pc, next.pc);
133141

@@ -136,6 +144,19 @@ impl CpuChip {
136144
.when(local.is_real)
137145
.when_not(local.is_jump)
138146
.assert_eq(local.next_pc, local.pc + AB::Expr::from_canonical_u32(1));
147+
148+
// Verify the public value's next pc. We need to handle two cases:
149+
// 1. The last real row is a transition row.
150+
// 2. The last real row is the last row.
151+
152+
// If the last real row is a transition row, verify the public value's next pc.
153+
builder
154+
.when_transition()
155+
.when(local.is_real - next.is_real)
156+
.assert_eq(public_values.next_pc, local.next_pc);
157+
158+
// If the last real row is the last row, verify the public value's next pc.
159+
builder.when_last_row().when(local.is_real).assert_eq(public_values.next_pc, local.next_pc);
139160
}
140161

141162
/// Constraints related to the is_real column.

crates/core/machine/src/utils/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub const fn indices_arr<const N: usize>() -> [usize; N] {
2323
}
2424

2525
pub fn pad_to_power_of_two<const N: usize, T: Clone + Default>(values: &mut Vec<T>) {
26-
debug_assert!(values.len() % N == 0);
26+
debug_assert!(values.len().is_multiple_of(N));
2727
let mut n_real_rows = values.len() / N;
2828
if n_real_rows < 16 {
2929
n_real_rows = 16;
@@ -75,7 +75,7 @@ where
7575
P: Fn(usize, &mut [F]) + Send + Sync,
7676
{
7777
// Split the vector into `num_cpus` chunks, but at least `num_cpus` rows per chunk.
78-
assert!(vec.len() % num_elements_per_event == 0);
78+
assert!(vec.len().is_multiple_of(num_elements_per_event));
7979
let len = vec.len() / num_elements_per_event;
8080
let cpus = num_cpus::get();
8181
let ceil_div = len.div_ceil(cpus);

crates/stark/src/air/builder.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::iter::once;
22

3-
use p3_air::{AirBuilder, FilteredAirBuilder, PermutationAirBuilder};
3+
use p3_air::{AirBuilder, AirBuilderWithPublicValues, FilteredAirBuilder, PermutationAirBuilder};
44
use p3_field::{Field, FieldAlgebra};
55
use p3_uni_stark::{
66
ProverConstraintFolder, StarkGenericConfig, SymbolicAirBuilder, VerifierConstraintFolder,
@@ -28,7 +28,7 @@ pub trait MessageBuilder<M> {
2828
/// A trait which contains basic methods for building an AIR.
2929
pub trait BaseAirBuilder: AirBuilder + MessageBuilder<AirLookup<Self::Expr>> {
3030
/// Returns a sub-builder whose constraints are enforced only when `condition` is not one.
31-
fn when_not<I: Into<Self::Expr>>(&mut self, condition: I) -> FilteredAirBuilder<Self> {
31+
fn when_not<I: Into<Self::Expr>>(&mut self, condition: I) -> FilteredAirBuilder<'_, Self> {
3232
self.when_ne(condition, Self::F::ONE)
3333
}
3434

@@ -248,7 +248,7 @@ pub trait MultiTableAirBuilder<'a>: PermutationAirBuilder {
248248
}
249249

250250
/// A trait that contains the common helper methods for building machine AIRs.
251-
pub trait MachineAirBuilder: BaseAirBuilder {}
251+
pub trait MachineAirBuilder: BaseAirBuilder + AirBuilderWithPublicValues {}
252252

253253
/// A trait which contains all helper methods for building machine AIRs.
254254
pub trait BfAirBuilder: MachineAirBuilder + ByteAirBuilder + InstructionAirBuilder {}
@@ -267,8 +267,8 @@ impl<AB: AirBuilder + MessageBuilder<AirLookup<AB::Expr>>> BaseAirBuilder for AB
267267
impl<AB: BaseAirBuilder> ByteAirBuilder for AB {}
268268
impl<AB: BaseAirBuilder> InstructionAirBuilder for AB {}
269269

270-
impl<AB: BaseAirBuilder> MachineAirBuilder for AB {}
271-
impl<AB: BaseAirBuilder> BfAirBuilder for AB {}
270+
impl<AB: BaseAirBuilder + AirBuilderWithPublicValues> MachineAirBuilder for AB {}
271+
impl<AB: BaseAirBuilder + AirBuilderWithPublicValues> BfAirBuilder for AB {}
272272

273273
impl<SC: StarkGenericConfig> EmptyMessageBuilder for ProverConstraintFolder<'_, SC> {}
274274
impl<SC: StarkGenericConfig> EmptyMessageBuilder for VerifierConstraintFolder<'_, SC> {}

crates/stark/src/air/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
mod builder;
44
mod lookup;
55
mod machine;
6+
mod public_values;
67

78
pub use builder::*;
89
pub use lookup::*;
910
pub use machine::*;
11+
pub use public_values::*;
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use core::{fmt::Debug, mem::size_of};
2+
use std::borrow::{Borrow, BorrowMut};
3+
4+
use p3_field::FieldAlgebra;
5+
use serde::{Deserialize, Serialize};
6+
7+
use crate::PROOF_MAX_NUM_PVS;
8+
9+
/// The number of non padded elements in the Ziren proofs public values vec.
10+
pub const PROOF_NUM_PV_ELTS: usize = size_of::<PublicValues<u8>>();
11+
12+
/// Stores all of a shard proof's public values.
13+
#[derive(Serialize, Deserialize, Clone, Copy, Default, Debug)]
14+
#[repr(C)]
15+
pub struct PublicValues<T> {
16+
/// The shard's start program counter.
17+
pub start_pc: T,
18+
19+
/// The expected start program counter for the next shard.
20+
pub next_pc: T,
21+
22+
/// This field is here to ensure that the size of the public values struct is a multiple of 8.
23+
pub empty: [T; 6],
24+
}
25+
26+
impl PublicValues<u32> {
27+
/// Convert the public values into a vector of field elements. This function will pad the
28+
/// vector to the maximum number of public values.
29+
#[must_use]
30+
pub fn to_vec<F: FieldAlgebra>(&self) -> Vec<F> {
31+
let mut ret = vec![F::ZERO; PROOF_MAX_NUM_PVS];
32+
33+
let field_values = PublicValues::<F> {
34+
start_pc: F::from_canonical_u32(self.start_pc),
35+
next_pc: F::from_canonical_u32(self.next_pc),
36+
empty: [F::ZERO; 6],
37+
};
38+
let ret_ref_mut: &mut PublicValues<F> = ret.as_mut_slice().borrow_mut();
39+
*ret_ref_mut = field_values;
40+
ret
41+
}
42+
43+
/// Resets the public values to zero.
44+
#[must_use]
45+
pub fn reset(&self) -> Self {
46+
let mut copy = *self;
47+
copy.start_pc = 0;
48+
copy.next_pc = 0;
49+
copy
50+
}
51+
}
52+
53+
impl<T: Clone> Borrow<PublicValues<T>> for [T] {
54+
fn borrow(&self) -> &PublicValues<T> {
55+
let size = std::mem::size_of::<PublicValues<u8>>();
56+
debug_assert!(self.len() >= size);
57+
let slice = &self[0..size];
58+
let (prefix, shorts, _suffix) = unsafe { slice.align_to::<PublicValues<T>>() };
59+
debug_assert!(prefix.is_empty(), "Alignment should match");
60+
debug_assert_eq!(shorts.len(), 1);
61+
&shorts[0]
62+
}
63+
}
64+
65+
impl<T: Clone> BorrowMut<PublicValues<T>> for [T] {
66+
fn borrow_mut(&mut self) -> &mut PublicValues<T> {
67+
let size = std::mem::size_of::<PublicValues<u8>>();
68+
debug_assert!(self.len() >= size);
69+
let slice = &mut self[0..size];
70+
let (prefix, shorts, _suffix) = unsafe { slice.align_to_mut::<PublicValues<T>>() };
71+
debug_assert!(prefix.is_empty(), "Alignment should match");
72+
debug_assert_eq!(shorts.len(), 1);
73+
&mut shorts[0]
74+
}
75+
}

crates/stark/src/debug.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use std::{
44
process::exit,
55
};
66

7-
use p3_air::{Air, AirBuilder, ExtensionBuilder, PairBuilder, PermutationAirBuilder};
7+
use p3_air::{
8+
Air, AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder,
9+
PermutationAirBuilder,
10+
};
811
use p3_field::{ExtensionField, Field, FieldAlgebra, PrimeField32};
912
use p3_matrix::{
1013
dense::{RowMajorMatrix, RowMajorMatrixView},
@@ -27,6 +30,7 @@ pub fn debug_constraints<SC, A>(
2730
main: &RowMajorMatrix<Val<SC>>,
2831
perm: &RowMajorMatrix<SC::Challenge>,
2932
perm_challenges: &[SC::Challenge],
33+
public_values: &[Val<SC>],
3034
cumulative_sum: &SC::Challenge,
3135
) where
3236
SC: StarkGenericConfig,
@@ -84,6 +88,7 @@ pub fn debug_constraints<SC, A>(
8488
is_first_row: Val::<SC>::ZERO,
8589
is_last_row: Val::<SC>::ZERO,
8690
is_transition: Val::<SC>::ONE,
91+
public_values,
8792
};
8893
if i == 0 {
8994
builder.is_first_row = Val::<SC>::ONE;
@@ -130,7 +135,7 @@ pub struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField<F>> {
130135
pub(crate) is_first_row: F,
131136
pub(crate) is_last_row: F,
132137
pub(crate) is_transition: F,
133-
// pub(crate) public_values: &'a [F],
138+
pub(crate) public_values: &'a [F],
134139
}
135140

136141
impl<F, EF> ExtensionBuilder for DebugConstraintBuilder<'_, F, EF>
@@ -260,3 +265,13 @@ where
260265
}
261266

262267
impl<F: Field, EF: ExtensionField<F>> EmptyMessageBuilder for DebugConstraintBuilder<'_, F, EF> {}
268+
269+
impl<F: Field, EF: ExtensionField<F>> AirBuilderWithPublicValues
270+
for DebugConstraintBuilder<'_, F, EF>
271+
{
272+
type PublicVar = F;
273+
274+
fn public_values(&self) -> &[Self::PublicVar] {
275+
self.public_values
276+
}
277+
}

0 commit comments

Comments
 (0)