Skip to content

Commit 20d238d

Browse files
committed
fix: incorrect calc of surfeit related value
1 parent c8c51a3 commit 20d238d

File tree

4 files changed

+215
-26
lines changed

4 files changed

+215
-26
lines changed

src/fields/emulated_fp/allocated_field_var.rs

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,17 @@ impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedEmulatedFpVar<TargetF, Bas
242242
}
243243
}
244244

245+
let padding_bit_len = {
246+
let mut one = BaseF::ONE.into_bigint();
247+
one <<= surfeit as u32;
248+
BaseF::from(one)
249+
};
245250
let result = AllocatedEmulatedFpVar::<TargetF, BaseF> {
246251
cs: self.cs(),
247252
limbs,
248-
num_of_additions_over_normal_form: self.num_of_additions_over_normal_form
249-
+ (other.num_of_additions_over_normal_form + BaseF::one())
250-
+ (other.num_of_additions_over_normal_form + BaseF::one()),
253+
num_of_additions_over_normal_form: self.num_of_additions_over_normal_form // this_limb
254+
+ padding_bit_len // pad_non_top_limb / pad_top_limb
255+
+ BaseF::one(), // pad_to_kp_limb
251256
is_in_the_normal_form: false,
252257
target_phantom: PhantomData,
253258
};
@@ -428,9 +433,19 @@ impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedEmulatedFpVar<TargetF, Bas
428433
Ok(AllocatedMulResultVar {
429434
cs: self.cs(),
430435
limbs: prod_limbs,
436+
// New number is upper bounded by:
437+
//
438+
// (a+1)2^{bits_per_limb} * (b+1)2^{bits_per_limb} * m = (ab+a+b+1)*m*2^{2*bits_per_limb}
439+
//
440+
// where `a = self_reduced.num_of_additions_over_normal_form` and
441+
// `b = other_reduced.num_of_additions_over_normal_form`
442+
// - why m pair: at cell m, there are m possible pairs (one limb from each var) that can add to cell m
443+
//
444+
// In theory, we can let `prod_of_num_of_additions = (m(ab+a+b+1)-1)`. But below, we use an overestimation.
431445
prod_of_num_of_additions: (self_reduced.num_of_additions_over_normal_form
432446
+ BaseF::one())
433-
* (other_reduced.num_of_additions_over_normal_form + BaseF::one()),
447+
* (other_reduced.num_of_additions_over_normal_form + BaseF::one())
448+
* BaseF::from((params.num_limbs) as u32),
434449
target_phantom: PhantomData,
435450
})
436451
}
@@ -464,13 +479,6 @@ impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedEmulatedFpVar<TargetF, Bas
464479
for limb in p_representations.iter() {
465480
p_gadget_limbs.push(FpVar::<BaseF>::Constant(*limb));
466481
}
467-
let p_gadget = AllocatedEmulatedFpVar::<TargetF, BaseF> {
468-
cs: self.cs(),
469-
limbs: p_gadget_limbs,
470-
num_of_additions_over_normal_form: BaseF::one(),
471-
is_in_the_normal_form: false,
472-
target_phantom: PhantomData,
473-
};
474482

475483
// Get delta = self - other
476484
let cs = self.cs().or(other.cs()).or(should_enforce.cs());
@@ -494,7 +502,7 @@ impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedEmulatedFpVar<TargetF, Bas
494502

495503
// Compute k * p
496504
let mut kp_gadget_limbs = Vec::new();
497-
for limb in p_gadget.limbs.iter() {
505+
for limb in p_gadget_limbs.iter() {
498506
kp_gadget_limbs.push(limb * &k_gadget);
499507
}
500508

@@ -916,3 +924,68 @@ impl<TargetF: PrimeField, BaseF: PrimeField> Clone for AllocatedEmulatedFpVar<Ta
916924
}
917925
}
918926
}
927+
928+
#[cfg(test)]
929+
mod test {
930+
use ark_ec::{bls12::Bls12Config, pairing::Pairing};
931+
use ark_relations::r1cs::ConstraintSystem;
932+
933+
use crate::{
934+
alloc::AllocVar,
935+
fields::{
936+
emulated_fp::{test::check_constraint, AllocatedEmulatedFpVar},
937+
fp::FpVar,
938+
},
939+
};
940+
941+
#[test]
942+
fn pr_157_sub() {
943+
type TargetF = <ark_bls12_381::Config as Bls12Config>::Fp;
944+
type BaseF = <ark_bls12_377::Bls12_377 as Pairing>::ScalarField;
945+
946+
let self_limb_values = [
947+
100, 2618, 1428, 2152, 2602, 1242, 2823, 511, 1752, 2058, 3599, 1113, 3207, 3601, 2736,
948+
435, 1108, 2965, 2685, 1705, 1016, 1343, 1760, 2039, 1355, 1767, 2355, 1945, 3594,
949+
4066, 1913, 2646,
950+
];
951+
let self_num_of_additions_over_normal_form = 1;
952+
let self_is_in_the_normal_form = false;
953+
let other_limb_values = [
954+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
955+
0, 0, 4,
956+
];
957+
let other_num_of_additions_over_normal_form = 1;
958+
let other_is_in_the_normal_form = false;
959+
960+
let cs = ConstraintSystem::new_ref();
961+
962+
let left_limb = self_limb_values
963+
.iter()
964+
.map(|v| FpVar::new_input(cs.clone(), || Ok(BaseF::from(*v))).unwrap())
965+
.collect();
966+
let left: AllocatedEmulatedFpVar<TargetF, BaseF> = AllocatedEmulatedFpVar {
967+
cs: cs.clone(),
968+
limbs: left_limb,
969+
num_of_additions_over_normal_form: BaseF::from(self_num_of_additions_over_normal_form),
970+
is_in_the_normal_form: self_is_in_the_normal_form,
971+
target_phantom: std::marker::PhantomData,
972+
};
973+
974+
let other_limb = other_limb_values
975+
.iter()
976+
.map(|v| FpVar::new_input(cs.clone(), || Ok(BaseF::from(*v))).unwrap())
977+
.collect();
978+
let right: AllocatedEmulatedFpVar<TargetF, BaseF> = AllocatedEmulatedFpVar {
979+
cs: cs.clone(),
980+
limbs: other_limb,
981+
num_of_additions_over_normal_form: BaseF::from(other_num_of_additions_over_normal_form),
982+
is_in_the_normal_form: other_is_in_the_normal_form,
983+
target_phantom: std::marker::PhantomData,
984+
};
985+
986+
let result = left.sub_without_reduce(&right).unwrap();
987+
assert!(check_constraint(&left));
988+
assert!(check_constraint(&right));
989+
assert!(check_constraint(&result));
990+
}
991+
}

src/fields/emulated_fp/allocated_mul_result.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedMulResultVar<TargetF, Base
113113
};
114114

115115
// Step 2: compute surfeit
116-
let surfeit = overhead!(self.prod_of_num_of_additions + BaseF::one()) + 1 + 1;
116+
let surfeit = overhead!(self.prod_of_num_of_additions + BaseF::one());
117117

118118
// Step 3: allocate k
119119
let k_bits = {
@@ -284,3 +284,42 @@ impl<TargetF: PrimeField, BaseF: PrimeField> AllocatedMulResultVar<TargetF, Base
284284
}
285285
}
286286
}
287+
288+
#[cfg(test)]
289+
mod test {
290+
use ark_ec::{bls12::Bls12Config, pairing::Pairing};
291+
use ark_ff::PrimeField;
292+
use ark_relations::r1cs::ConstraintSystem;
293+
294+
use crate::{
295+
alloc::AllocVar,
296+
fields::emulated_fp::{
297+
test::{check_constraint, check_mulres_constraint},
298+
AllocatedEmulatedFpVar,
299+
},
300+
};
301+
302+
#[test]
303+
fn pr_157_mul() {
304+
type TargetF = <ark_bls12_381::Config as Bls12Config>::Fp;
305+
type BaseF = <ark_bls12_377::Bls12_377 as Pairing>::ScalarField;
306+
307+
let cs = ConstraintSystem::new_ref();
308+
309+
let left: AllocatedEmulatedFpVar<TargetF, BaseF> =
310+
AllocatedEmulatedFpVar::new_input(cs.clone(), || {
311+
Ok(TargetF::from(
312+
TargetF::from(1).into_bigint()
313+
<< (<TargetF as PrimeField>::MODULUS_BIT_SIZE - 1),
314+
) + TargetF::from(-1))
315+
})
316+
.unwrap();
317+
318+
let right: AllocatedEmulatedFpVar<TargetF, BaseF> = left.clone();
319+
320+
let result = left.mul_without_reduce(&right).unwrap();
321+
assert!(check_constraint(&left));
322+
assert!(check_constraint(&right));
323+
assert!(check_mulres_constraint(&result));
324+
}
325+
}

src/fields/emulated_fp/mod.rs

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ macro_rules! overhead {
152152
use ark_ff::BigInteger;
153153
let num = $x;
154154
let num_bits = num.into_bigint().to_bits_be();
155+
155156
let mut skipped_bits = 0;
156157
for b in num_bits.iter() {
157158
if *b == false {
@@ -168,10 +169,13 @@ macro_rules! overhead {
168169
}
169170
}
170171

171-
if is_power_of_2 {
172-
num_bits.len() - skipped_bits
172+
// let log(0) = 0 in our case
173+
if num == BaseF::zero() {
174+
0
175+
} else if is_power_of_2 {
176+
num_bits.len() - skipped_bits - 1
173177
} else {
174-
num_bits.len() - skipped_bits + 1
178+
num_bits.len() - skipped_bits
175179
}
176180
}};
177181
}
@@ -200,3 +204,45 @@ pub use field_var::*;
200204

201205
mod mul_result;
202206
pub use mul_result::*;
207+
208+
#[cfg(test)]
209+
mod test {
210+
use ark_ff::PrimeField;
211+
212+
use crate::{
213+
fields::emulated_fp::{params::get_params, AllocatedEmulatedFpVar},
214+
R1CSVar,
215+
};
216+
217+
use super::AllocatedMulResultVar;
218+
219+
pub(crate) fn check_constraint<TargetF: PrimeField, BaseF: PrimeField>(
220+
emulated_fpvar: &AllocatedEmulatedFpVar<TargetF, BaseF>,
221+
) -> bool {
222+
let limb_values = emulated_fpvar.limbs.value().unwrap();
223+
let params = get_params(
224+
TargetF::MODULUS_BIT_SIZE as usize,
225+
BaseF::MODULUS_BIT_SIZE as usize,
226+
emulated_fpvar.get_optimization_type(),
227+
);
228+
let bits_per_limb = params.bits_per_limb;
229+
let upper_bound = (emulated_fpvar.num_of_additions_over_normal_form + BaseF::one())
230+
* (BaseF::from(BaseF::from(1).into_bigint() << bits_per_limb as u32) + BaseF::from(-1));
231+
return !limb_values.iter().any(|value| value > &upper_bound);
232+
}
233+
234+
pub(crate) fn check_mulres_constraint<TargetF: PrimeField, BaseF: PrimeField>(
235+
emulated_fpvar: &AllocatedMulResultVar<TargetF, BaseF>,
236+
) -> bool {
237+
let limb_values: Vec<_> = emulated_fpvar.limbs.value().unwrap();
238+
let params = get_params(
239+
TargetF::MODULUS_BIT_SIZE as usize,
240+
BaseF::MODULUS_BIT_SIZE as usize,
241+
emulated_fpvar.get_optimization_type(),
242+
);
243+
let bits_per_limb = params.bits_per_limb * 2;
244+
let upper_bound = (emulated_fpvar.prod_of_num_of_additions + BaseF::one())
245+
* (BaseF::from(BaseF::from(1).into_bigint() << bits_per_limb as u32) + BaseF::from(-1));
246+
return !limb_values.iter().any(|value| value > &upper_bound);
247+
}
248+
}

src/fields/emulated_fp/reduce.rs

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -155,24 +155,34 @@ impl<TargetF: PrimeField, BaseF: PrimeField> Reducer<TargetF, BaseF> {
155155
elem.get_optimization_type(),
156156
);
157157

158-
if 2 * params.bits_per_limb + ark_std::log2(params.num_limbs) as usize
159-
> BaseF::MODULUS_BIT_SIZE as usize - 1
158+
// `2 * params.bits_per_limb + ark_std::log2(params.num_limbs + 1)` needs to be `<= BaseF::MODULUS_BIT_SIZE as usize - 4`
159+
// - see `group_and_check_equality` for more details
160+
if 2 * params.bits_per_limb + ark_std::log2(params.num_limbs + 1) as usize
161+
>= BaseF::MODULUS_BIT_SIZE as usize - 3
160162
{
161163
panic!("The current limb parameters do not support multiplication.");
162164
}
163165

164166
loop {
167+
// this needs to be adjusted if we modify `prod_of_num_of_additions` for `AllocatedMulResultVar`
168+
// - see `mul_without_reduce` in `src/fields/emulated_fp/allocated_field_var.rs`
165169
let prod_of_num_of_additions = (elem.num_of_additions_over_normal_form + BaseF::one())
166170
* (elem_other.num_of_additions_over_normal_form + BaseF::one());
167-
let overhead_limb = overhead!(prod_of_num_of_additions.mul(
168-
&BaseF::from_bigint(<BaseF as PrimeField>::BigInt::from(
169-
(params.num_limbs) as u64
170-
))
171-
.unwrap()
172-
));
173-
let bits_per_mulresult_limb = 2 * (params.bits_per_limb + 1) + overhead_limb;
171+
let overhead_limb = overhead!(
172+
BaseF::one()
173+
+ prod_of_num_of_additions.mul(
174+
&BaseF::from_bigint(<BaseF as PrimeField>::BigInt::from(
175+
(params.num_limbs) as u64
176+
))
177+
.unwrap()
178+
)
179+
);
174180

175-
if bits_per_mulresult_limb < BaseF::MODULUS_BIT_SIZE as usize {
181+
let bits_per_mulresult_limb = 2 * params.bits_per_limb + overhead_limb;
182+
183+
// we need `bits_per_mulresult_limb <= MODULUS_BIT_SIZE - 4`
184+
// - see `group_and_check_equality` for more details
185+
if bits_per_mulresult_limb < (BaseF::MODULUS_BIT_SIZE - 3) as usize {
176186
break;
177187
}
178188

@@ -211,6 +221,22 @@ impl<TargetF: PrimeField, BaseF: PrimeField> Reducer<TargetF, BaseF> {
211221
let zero = FpVar::<BaseF>::zero();
212222

213223
let mut limb_pairs = Vec::<(FpVar<BaseF>, FpVar<BaseF>)>::new();
224+
225+
// this size is closely related to the grouped limb size, padding size, pre_mul_reduce and post_add_reduce
226+
//
227+
// it should be carefully chosen so that 1) no overflow/underflow can happen in this function and 2) num_limb_in_a_group
228+
// is always >=1.
229+
//
230+
// 1. for this function
231+
// - pad_limb has bit size BaseF::MODULUS_BIT_SIZE - 1
232+
// - left/right_total_limb has bit size BaseF::MODULUS_BIT_SIZE - 3
233+
// - carry has even smaller size
234+
// - so, their sum has bit size <= BaseF::MODULUS_BIT_SIZE - 1
235+
//
236+
// 2. for pre_mul_reduce
237+
// - it enforces `2 * bits_per_limb + surfeit <= BaseF::MODULUS_BIT_SIZE - 4`
238+
// - 2 * bits_per_limb in that function == 2 * (bits_per_limb - shift_per_limb) == shift_per_limb
239+
// - so, num_limb_in_a_group is >= 1 for mul
214240
let num_limb_in_a_group = (BaseF::MODULUS_BIT_SIZE as usize
215241
- 1
216242
- surfeit
@@ -240,6 +266,8 @@ impl<TargetF: PrimeField, BaseF: PrimeField> Reducer<TargetF, BaseF> {
240266
let mut groupped_limb_pairs = Vec::<(FpVar<BaseF>, FpVar<BaseF>, usize)>::new();
241267

242268
for limb_pairs_in_a_group in limb_pairs.chunks(num_limb_in_a_group) {
269+
// bit size = num_limb_in_a_group * shift_per_limb + bits_per_limb + true surfeit + 1
270+
// <= BaseF::MODULUS_BIT_SIZE - 3
243271
let mut left_total_limb = zero.clone();
244272
let mut right_total_limb = zero.clone();
245273

@@ -267,6 +295,9 @@ impl<TargetF: PrimeField, BaseF: PrimeField> Reducer<TargetF, BaseF> {
267295
{
268296
let mut pad_limb_repr = BaseF::ONE.into_bigint();
269297

298+
// use padding to avoid underflow
299+
//
300+
// bit size = BaseF::MODULUS_BIT_SIZE - 1
270301
pad_limb_repr <<= (surfeit
271302
+ (bits_per_limb - shift_per_limb)
272303
+ shift_per_limb * num_limb_in_this_group

0 commit comments

Comments
 (0)