@@ -9,6 +9,8 @@ namespace alu;
99pol commit sel;
1010
1111pol commit sel_op_add;
12+ pol commit sel_op_sub;
13+ pol commit sel_op_mul;
1214pol commit sel_op_eq;
1315pol commit sel_op_lt;
1416pol commit sel_op_lte;
@@ -37,24 +39,29 @@ pol commit cf;
3739pol commit helper1;
3840
3941// maximum bits the number can hold (i.e. 8 for a u8):
40- // TODO(MW): Now unused since we redirect the LT/LTE range checks to GT gadget - remove?
4142pol commit max_bits;
4243// maximum value the number can hold (i.e. 255 for a u8), we 'mod' by max_value + 1
4344pol commit max_value;
4445// we need a selector to conditionally lookup ff_gt when inputs a, b are fields:
4546pol commit sel_is_ff;
47+ // we need a selector to conditionally perform u128 multiplication:
48+ pol commit sel_is_u128;
4649
4750pol IS_NOT_FF = 1 - sel_is_ff;
51+ pol IS_NOT_U128 = 1 - sel_is_u128;
4852
4953sel * (1 - sel) = 0;
5054cf * (1 - cf) = 0;
5155sel_is_ff * (1 - sel_is_ff) = 0;
56+ sel_is_u128 * (1 - sel_is_u128) = 0;
5257
5358// TODO: Consider to gate with (1 - sel_tag_err) for op_id. This might help us remove the (1 - sel_tag_err)
5459// in various operation relations below.
5560// Note that the op_ids below represent a binary decomposition (see constants_gen.pil):
5661#[OP_ID_CHECK]
5762op_id = sel_op_add * constants.AVM_EXEC_OP_ID_ALU_ADD
63+ + sel_op_sub * constants.AVM_EXEC_OP_ID_ALU_SUB
64+ + sel_op_mul * constants.AVM_EXEC_OP_ID_ALU_MUL
5865 + sel_op_eq * constants.AVM_EXEC_OP_ID_ALU_EQ
5966 + sel_op_lt * constants.AVM_EXEC_OP_ID_ALU_LT
6067 + sel_op_lte * constants.AVM_EXEC_OP_ID_ALU_LTE
@@ -78,18 +85,28 @@ execution.sel_execute_alu {
7885
7986// IS_FF CHECKING
8087
81- // TODO(MW): remove this and check for all (i.e. replace with sel), just being lazy for now. For add, we don't care, for lt we need to differentiate.
8288pol CHECK_TAG_FF = sel_op_lt + sel_op_lte + sel_op_not;
8389// We prove that sel_is_ff == 1 <==> ia_tag == MEM_TAG_FF
8490pol TAG_FF_DIFF = ia_tag - constants.MEM_TAG_FF;
8591pol commit tag_ff_diff_inv;
8692#[TAG_IS_FF]
8793CHECK_TAG_FF * (TAG_FF_DIFF * (sel_is_ff * (1 - tag_ff_diff_inv) + tag_ff_diff_inv) + sel_is_ff - 1) = 0;
8894
95+ // IS_U128 CHECKING
96+
97+ pol CHECK_TAG_U128 = sel_op_mul;
98+ // We prove that sel_is_u128 == 1 <==> ia_tag == MEM_TAG_U128
99+ pol TAG_U128_DIFF = ia_tag - constants.MEM_TAG_U128;
100+ pol commit tag_u128_diff_inv;
101+ #[TAG_IS_U128]
102+ CHECK_TAG_U128 * (TAG_U128_DIFF * (sel_is_u128 * (1 - tag_u128_diff_inv) + tag_u128_diff_inv) + sel_is_u128 - 1) = 0;
103+
104+ // Note: if we never need sel_is_ff and sel_is_u128 in the same op, can combine the above checks into one
105+
89106// TAG CHECKING
90107
91108// Will become e.g. sel_op_add * ia_tag + (comparison ops) * MEM_TAG_U1 + ....
92- pol EXPECTED_C_TAG = (sel_op_add + sel_op_truncate) * ia_tag + (sel_op_eq + sel_op_lt + sel_op_lte) * constants.MEM_TAG_U1;
109+ pol EXPECTED_C_TAG = (sel_op_add + sel_op_sub + sel_op_truncate + sel_op_mul ) * ia_tag + (sel_op_eq + sel_op_lt + sel_op_lte) * constants.MEM_TAG_U1;
93110
94111// The tag of c is generated by the opcode and is never wrong.
95112// Gating with (1 - sel_tag_err) is necessary because when an error occurs, we have to set the tag to 0,
@@ -120,8 +137,91 @@ sel { ia_tag, max_bits, max_value } in precomputed.sel_tag_parameters { precompu
120137
121138sel_op_add * (1 - sel_op_add) = 0;
122139
123- #[ALU_ADD]
124- sel_op_add * (1 - sel_tag_err) * (ia + ib - ic - cf * (max_value + 1)) = 0;
140+ // SUB
141+
142+ sel_op_sub * (1 - sel_op_sub) = 0;
143+
144+ // ADD & SUB - Shared relation:
145+
146+ // For add, sel_op_add - sel_op_sub = 1 => check a + b - cf * carry = c
147+ // For sub, sel_op_add - sel_op_sub = -1 => check a - b + cf * carry = c
148+ #[ALU_ADD_SUB]
149+ (sel_op_add + sel_op_sub) * (1 - sel_tag_err) * (ia - ic + (sel_op_add - sel_op_sub) * (ib - cf * (max_value + 1))) = 0;
150+
151+ // MUL
152+
153+ sel_op_mul * (1 - sel_op_mul) = 0;
154+
155+ pol commit c_hi;
156+
157+ // MUL - non u128
158+
159+ #[ALU_MUL_NON_U128]
160+ sel_op_mul * IS_NOT_U128 * (1 - sel_tag_err)
161+ * (
162+ ia * ib
163+ - ic
164+ - (max_value + 1) * c_hi
165+ ) = 0;
166+
167+ // MUL - u128
168+
169+ pol commit sel_mul_u128;
170+ // sel_op_mul & sel_is_u128:
171+ sel_mul_u128 - sel_is_u128 * sel_op_mul = 0;
172+
173+ // Taken from vm1:
174+ // We express a, b in 64-bit slices: a = a_l + a_h * 2^64
175+ // b = b_l + b_h * 2^64
176+ // => a * b = a_l * b_l + (a_h * b_l + a_l * b_h) * 2^64 + (a_h * b_h) * 2^128 = c_hi_full * 2^128 + c
177+ // => the 'top bits' are given by (c_hi_full - (a_h * b_h)) * 2^128
178+ // We can show for a 64 bit c_hi = c_hi_full - (a_h * b_h) % 2^64 that:
179+ // a_l * b_l + (a_h * b_l + a_l * b_h) * 2^64 = c_hi * 2^128 + c
180+ // Equivalently (cf = 0 if a_h & b_h = 0):
181+ // a * b_l + a_l * b_h * 2^64 = (cf * 2^64 + c_hi) * 2^128 + c
182+ // => no need for a_h in final relation
183+
184+ pol commit a_lo;
185+ pol commit a_hi;
186+ pol commit b_lo;
187+ pol commit b_hi;
188+ pol TWO_POW_64 = 2 ** 64;
189+
190+ #[A_MUL_DECOMPOSITION]
191+ sel_mul_u128 * (ia - (a_lo + TWO_POW_64 * a_hi)) = 0;
192+ #[B_MUL_DECOMPOSITION]
193+ sel_mul_u128 * (ib - (b_lo + TWO_POW_64 * b_hi)) = 0;
194+
195+ #[ALU_MUL_U128]
196+ sel_mul_u128 * (1 - sel_tag_err)
197+ * (
198+ ia * b_lo + a_lo * b_hi * TWO_POW_64 // a * b without the hi bits
199+ - ic // c_lo
200+ - (max_value + 1) * (cf * TWO_POW_64 + c_hi) // c_hi * 2^128 + (cf ? 2^192 : 0)
201+ ) = 0;
202+
203+ // TODO: Once lookups support expression in tuple, we can inline constant_64 into the lookup.
204+ // Note: only used for MUL, so gated by sel_op_mul
205+ pol commit constant_64;
206+ sel_op_mul * (64 - constant_64) = 0;
207+
208+ #[RANGE_CHECK_MUL_U128_A_LO]
209+ sel_mul_u128 { a_lo, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
210+
211+ #[RANGE_CHECK_MUL_U128_A_HI]
212+ sel_mul_u128 { a_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
213+
214+ #[RANGE_CHECK_MUL_U128_B_LO]
215+ sel_mul_u128 { b_lo, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
216+
217+ #[RANGE_CHECK_MUL_U128_B_HI]
218+ sel_mul_u128 { b_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
219+
220+ // No need to range_check c_hi for cases other than u128 because we know a and b's size from the tags and have looked
221+ // up max_value. i.e. we cannot provide a malicious c, c_hi such that a + b - c_hi * 2^n = c passes for n < 128.
222+ // No need to range_check c_lo = ic because the memory write will ensure ic <= max_value.
223+ #[RANGE_CHECK_MUL_U128_C_HI]
224+ sel_mul_u128 { c_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
125225
126226// EQ
127227
@@ -280,21 +380,21 @@ sel_op_truncate = sel_trunc_non_trivial + sel_trunc_trivial;
280380#[TRUNC_TRIVIAL_CASE]
281381sel_trunc_trivial * (ia - ic) = 0;
282382
283- pol commit lo_128; // 128-bit low limb of ia.
284- pol commit hi_128; // 128-bit high limb of ia.
383+ // NOTE: reusing a_lo and a_hi columns from MUL in TRUNC:
384+ // For truncate, a_lo = 128-bit low limb of ia and a_hi = 128-bit high limb of ia.
285385pol commit mid;
286386
287387#[LARGE_TRUNC_CANONICAL_DEC]
288- sel_trunc_gte_128 { ia, lo_128, hi_128 }
388+ sel_trunc_gte_128 { ia, a_lo, a_hi }
289389in
290390ff_gt.sel_dec { ff_gt.a, ff_gt.a_lo, ff_gt.a_hi };
291391
292392#[SMALL_TRUNC_VAL_IS_LO]
293- sel_trunc_lt_128 * (lo_128 - ia) = 0;
393+ sel_trunc_lt_128 * (a_lo - ia) = 0;
294394
295- // lo_128 = ic + mid * 2^ia_tag_bits, where 2^ia_tag_bits is max_value + 1.
395+ // a_lo = ic + mid * 2^ia_tag_bits, where 2^ia_tag_bits is max_value + 1.
296396#[TRUNC_LO_128_DECOMPOSITION]
297- sel_trunc_non_trivial * (ic + mid * (max_value + 1) - lo_128 ) = 0;
397+ sel_trunc_non_trivial * (ic + mid * (max_value + 1) - a_lo ) = 0;
298398
299399// TODO: Once lookups support expression in tuple, we can inline mid_bits into the lookup.
300400pol commit mid_bits;
@@ -305,4 +405,4 @@ mid_bits = sel_trunc_non_trivial * (128 - max_bits);
305405// is supported by our range_check gadget.
306406// No need to range_check ic because the memory write will ensure ic <= max_value.
307407#[RANGE_CHECK_TRUNC_MID]
308- sel_trunc_non_trivial {mid, mid_bits} in range_check.sel { range_check.value, range_check.rng_chk_bits };
408+ sel_trunc_non_trivial { mid, mid_bits } in range_check.sel { range_check.value, range_check.rng_chk_bits };
0 commit comments