Skip to content

Commit a0aa176

Browse files
committed
better AIR for execution
1 parent 4c07ce7 commit a0aa176

File tree

2 files changed

+54
-59
lines changed

2 files changed

+54
-59
lines changed

crates/lean_prover/vm_air/src/dot_product_air.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,7 @@ impl<AB: AirBuilder> Air<AB> for DotProductAir {
5959
value_b_up,
6060
res_up,
6161
computation_up,
62-
] = up
63-
.iter()
64-
.map(|v| v.clone().into())
65-
.collect::<Vec<AB::Expr>>()
66-
.try_into()
67-
.unwrap();
62+
] = up.to_vec().try_into().ok().unwrap();
6863
let [
6964
start_flag_down,
7065
len_down,
@@ -75,12 +70,9 @@ impl<AB: AirBuilder> Air<AB> for DotProductAir {
7570
_value_b_down,
7671
_res_down,
7772
computation_down,
78-
] = down
79-
.iter()
80-
.map(|v| v.clone().into())
81-
.collect::<Vec<AB::Expr>>()
82-
.try_into()
83-
.unwrap();
73+
] = down.to_vec().try_into().ok().unwrap();
74+
75+
// TODO we could some some of the following computation in the base field
8476

8577
builder.assert_bool(start_flag_down.clone());
8678

crates/lean_prover/vm_air/src/execution_air.rs

Lines changed: 50 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -60,70 +60,73 @@ impl<AB: AirBuilder> Air<AB> for VMAir {
6060
let down: &[AB::Var] = (*down).borrow();
6161
assert_eq!(down.len(), N_EXEC_AIR_COLUMNS);
6262

63-
let (operand_a, operand_b, operand_c): (AB::Expr, AB::Expr, AB::Expr) = (
64-
up[COL_INDEX_OPERAND_A].clone().into(),
65-
up[COL_INDEX_OPERAND_B].clone().into(),
66-
up[COL_INDEX_OPERAND_C].clone().into(),
63+
let (operand_a, operand_b, operand_c) = (
64+
up[COL_INDEX_OPERAND_A].clone(),
65+
up[COL_INDEX_OPERAND_B].clone(),
66+
up[COL_INDEX_OPERAND_C].clone(),
6767
);
68-
let (flag_a, flag_b, flag_c): (AB::Expr, AB::Expr, AB::Expr) = (
69-
up[COL_INDEX_FLAG_A].clone().into(),
70-
up[COL_INDEX_FLAG_B].clone().into(),
71-
up[COL_INDEX_FLAG_C].clone().into(),
68+
let (flag_a, flag_b, flag_c) = (
69+
up[COL_INDEX_FLAG_A].clone(),
70+
up[COL_INDEX_FLAG_B].clone(),
71+
up[COL_INDEX_FLAG_C].clone(),
7272
);
73-
let add: AB::Expr = up[COL_INDEX_ADD].clone().into();
74-
let mul: AB::Expr = up[COL_INDEX_MUL].clone().into();
75-
let deref: AB::Expr = up[COL_INDEX_DEREF].clone().into();
76-
let jump: AB::Expr = up[COL_INDEX_JUMP].clone().into();
77-
let aux: AB::Expr = up[COL_INDEX_AUX].clone().into();
78-
79-
let (value_a, value_b, value_c): (AB::Expr, AB::Expr, AB::Expr) = (
80-
up[COL_INDEX_MEM_VALUE_A.index_in_air()].clone().into(),
81-
up[COL_INDEX_MEM_VALUE_B.index_in_air()].clone().into(),
82-
up[COL_INDEX_MEM_VALUE_C.index_in_air()].clone().into(),
73+
let add = up[COL_INDEX_ADD].clone();
74+
let mul = up[COL_INDEX_MUL].clone();
75+
let deref = up[COL_INDEX_DEREF].clone();
76+
let jump = up[COL_INDEX_JUMP].clone();
77+
let aux = up[COL_INDEX_AUX].clone();
78+
79+
let (value_a, value_b, value_c) = (
80+
up[COL_INDEX_MEM_VALUE_A.index_in_air()].clone(),
81+
up[COL_INDEX_MEM_VALUE_B.index_in_air()].clone(),
82+
up[COL_INDEX_MEM_VALUE_C.index_in_air()].clone(),
8383
);
84-
let (pc, next_pc): (AB::Expr, AB::Expr) = (
85-
up[COL_INDEX_PC.index_in_air()].clone().into(),
86-
down[COL_INDEX_PC.index_in_air()].clone().into(),
84+
let (pc, next_pc) = (
85+
up[COL_INDEX_PC.index_in_air()].clone(),
86+
down[COL_INDEX_PC.index_in_air()].clone(),
8787
);
88-
let (fp, next_fp): (AB::Expr, AB::Expr) = (
89-
up[COL_INDEX_FP.index_in_air()].clone().into(),
90-
down[COL_INDEX_FP.index_in_air()].clone().into(),
88+
let (fp, next_fp) = (
89+
up[COL_INDEX_FP.index_in_air()].clone(),
90+
down[COL_INDEX_FP.index_in_air()].clone(),
9191
);
92-
let (addr_a, addr_b, addr_c): (AB::Expr, AB::Expr, AB::Expr) = (
93-
up[COL_INDEX_MEM_ADDRESS_A.index_in_air()].clone().into(),
94-
up[COL_INDEX_MEM_ADDRESS_B.index_in_air()].clone().into(),
95-
up[COL_INDEX_MEM_ADDRESS_C.index_in_air()].clone().into(),
92+
let (addr_a, addr_b, addr_c) = (
93+
up[COL_INDEX_MEM_ADDRESS_A.index_in_air()].clone(),
94+
up[COL_INDEX_MEM_ADDRESS_B.index_in_air()].clone(),
95+
up[COL_INDEX_MEM_ADDRESS_C.index_in_air()].clone(),
9696
);
9797

98-
let nu_a =
99-
flag_a.clone() * operand_a.clone() + value_a.clone() * (AB::Expr::ONE - flag_a.clone());
100-
let nu_b = flag_b.clone() * operand_b.clone() + value_b * (AB::Expr::ONE - flag_b.clone());
101-
let nu_c = flag_c.clone() * fp.clone() + value_c.clone() * (AB::Expr::ONE - flag_c.clone());
98+
let flag_a_minus_one = flag_a.clone() - AB::F::ONE;
99+
let flag_b_minus_one = flag_b.clone() - AB::F::ONE;
100+
let flag_c_minus_one = flag_c.clone() - AB::F::ONE;
102101

103-
builder.assert_zero((AB::Expr::ONE - flag_a) * (addr_a - (fp.clone() + operand_a)));
104-
builder.assert_zero((AB::Expr::ONE - flag_b) * (addr_b - (fp.clone() + operand_b)));
105-
builder.assert_zero(
106-
(AB::Expr::ONE - flag_c) * (addr_c.clone() - (fp.clone() + operand_c.clone())),
107-
);
102+
let nu_a = flag_a.clone() * operand_a.clone() + value_a.clone() * -flag_a_minus_one.clone();
103+
let nu_b = flag_b.clone() * operand_b.clone() + value_b * -flag_b_minus_one.clone();
104+
let nu_c = flag_c.clone() * fp.clone() + value_c.clone() * -flag_c_minus_one.clone();
105+
106+
let fp_plus_operand_a = fp.clone() + operand_a;
107+
let fp_plus_operand_b = fp.clone() + operand_b;
108+
let fp_plus_operand_c = fp.clone() + operand_c.clone();
109+
let pc_plus_one = pc.clone() + AB::F::ONE;
110+
let nu_a_minus_one = nu_a.clone() - AB::F::ONE;
111+
112+
builder.assert_zero(flag_a_minus_one * (addr_a - fp_plus_operand_a));
113+
builder.assert_zero(flag_b_minus_one * (addr_b - fp_plus_operand_b));
114+
builder.assert_zero(flag_c_minus_one * (addr_c.clone() - fp_plus_operand_c));
108115

109116
builder.assert_zero(add * (nu_b.clone() - (nu_a.clone() + nu_c.clone())));
110117
builder.assert_zero(mul * (nu_b.clone() - nu_a.clone() * nu_c.clone()));
111118

112119
builder.assert_zero(deref.clone() * (addr_c - (value_a + operand_c)));
113120
builder.assert_zero(deref.clone() * aux.clone() * (value_c.clone() - nu_b.clone()));
114-
builder.assert_zero(deref * (AB::Expr::ONE - aux) * (value_c - fp.clone()));
121+
builder.assert_zero(deref * (aux - AB::F::ONE) * (value_c - fp.clone()));
115122

116-
builder.assert_zero(
117-
(AB::Expr::ONE - jump.clone()) * (next_pc.clone() - (pc.clone() + AB::Expr::ONE)),
118-
);
119-
builder.assert_zero((AB::Expr::ONE - jump.clone()) * (next_fp.clone() - fp.clone()));
123+
builder.assert_zero((jump.clone() - AB::F::ONE) * (next_pc.clone() - pc_plus_one.clone()));
124+
builder.assert_zero((jump.clone() - AB::F::ONE) * (next_fp.clone() - fp.clone()));
120125

121-
builder.assert_zero(jump.clone() * nu_a.clone() * (AB::Expr::ONE - nu_a.clone()));
126+
builder.assert_zero(jump.clone() * nu_a.clone() * nu_a_minus_one.clone());
122127
builder.assert_zero(jump.clone() * nu_a.clone() * (next_pc.clone() - nu_b));
123128
builder.assert_zero(jump.clone() * nu_a.clone() * (next_fp.clone() - nu_c));
124-
builder.assert_zero(
125-
jump.clone() * (AB::Expr::ONE - nu_a.clone()) * (next_pc - (pc + AB::Expr::ONE)),
126-
);
127-
builder.assert_zero(jump * (AB::Expr::ONE - nu_a) * (next_fp - fp));
129+
builder.assert_zero(jump.clone() * nu_a_minus_one.clone() * (next_pc - pc_plus_one));
130+
builder.assert_zero(jump * nu_a_minus_one * (next_fp - fp));
128131
}
129132
}

0 commit comments

Comments
 (0)