diff --git a/Cargo.toml b/Cargo.toml index 582ecff..2d5906c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ ark-relations = { version = "0.5.0", default-features = false } educe = "0.6.0" tracing = { version = "^0.1.0", default-features = false, features = ["attributes"] } +itertools = { version = "0.14.0", default-features = false, features = [ "use_alloc" ] } num-bigint = { version = "0.4", default-features = false } num-traits = { version = "0.2", default-features = false } num-integer = { version = "0.1.44", default-features = false } @@ -45,7 +46,7 @@ tracing-subscriber = { version = "0.3", default-features = true } [features] default = ["std"] -std = ["ark-ff/std", "ark-relations/std", "ark-std/std", "num-bigint/std"] +std = ["ark-ff/std", "ark-relations/std", "ark-std/std", "num-bigint/std", "itertools/use_std" ] parallel = ["std", "ark-ff/parallel", "ark-std/parallel"] [[bench]] diff --git a/src/fields/cubic_extension.rs b/src/fields/cubic_extension.rs index 5bf92b1..db0012f 100644 --- a/src/fields/cubic_extension.rs +++ b/src/fields/cubic_extension.rs @@ -9,7 +9,7 @@ use ark_ff::{ CubicExtConfig, Zero, }; use ark_relations::gr1cs::{ConstraintSystemRef, Namespace, SynthesisError}; -use core::{borrow::Borrow, marker::PhantomData}; +use core::{borrow::Borrow, iter::Sum, marker::PhantomData}; use educe::Educe; /// This struct is the `R1CS` equivalent of the cubic extension field type @@ -576,3 +576,41 @@ where Ok(Self::new(c0, c1, c2)) } } + +impl Sum for CubicExtVar +where + BF: FieldVar, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, + P: CubicExtVarConfig, +{ + #[inline] + #[tracing::instrument(target = "gr1cs", skip(iter))] + fn sum>(iter: I) -> Self { + let (c0s, c1s, c2s): (Vec<_>, Vec<_>, Vec<_>) = + itertools::multiunzip(iter.map(|x| (x.c0, x.c1, x.c2))); + let c0 = c0s.into_iter().sum(); + let c1 = c1s.into_iter().sum(); + let c2 = c2s.into_iter().sum(); + + Self::new(c0, c1, c2) + } +} + +impl<'a, BF, P> Sum<&'a Self> for CubicExtVar +where + BF: FieldVar, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, + P: CubicExtVarConfig, +{ + #[inline] + #[tracing::instrument(target = "gr1cs", skip(iter))] + fn sum>(iter: I) -> Self { + let (c0s, c1s, c2s): (Vec<_>, Vec<_>, Vec<_>) = + itertools::multiunzip(iter.map(|x| (&x.c0, &x.c1, &x.c2))); + let c0 = c0s.into_iter().sum(); + let c1 = c1s.into_iter().sum(); + let c2 = c2s.into_iter().sum(); + + Self::new(c0, c1, c2) + } +} diff --git a/src/fields/emulated_fp/field_var.rs b/src/fields/emulated_fp/field_var.rs index ee2230c..b525fec 100644 --- a/src/fields/emulated_fp/field_var.rs +++ b/src/fields/emulated_fp/field_var.rs @@ -1,3 +1,5 @@ +use core::iter::Sum; + use super::{params::OptimizationType, AllocatedEmulatedFpVar, MulResultVar}; use crate::{ boolean::Boolean, @@ -471,3 +473,17 @@ impl EmulatedFpVar { } } } + +impl Sum for EmulatedFpVar { + #[tracing::instrument(target = "gr1cs", skip(iter))] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), |acc, x| acc + x) + } +} + +impl<'a, TargetF: PrimeField, BaseF: PrimeField> Sum<&'a Self> for EmulatedFpVar { + #[tracing::instrument(target = "gr1cs", skip(iter))] + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), |acc, x| acc + x) + } +} diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index f108bfa..e9e8c14 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -2,11 +2,10 @@ use ark_ff::{BigInteger, PrimeField}; use ark_relations::gr1cs::{ ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, }; - -use core::borrow::Borrow; +use ark_std::{borrow::Borrow, iter::Sum, vec::Vec}; +use itertools::zip_eq; use crate::{boolean::AllocatedBool, convert::ToConstraintFieldGadget, prelude::*, Assignment}; -use ark_std::{iter::Sum, vec::Vec}; mod cmp; @@ -156,10 +155,8 @@ impl AllocatedFp { /// This does not create any constraints and only creates one linear /// combination. /// - /// # Panics - /// - /// Panics if you pass an empty iterator. - pub fn add_many, I: Iterator>(iter: I) -> Self { + /// Returns `None` if you pass an empty iterator. + pub fn add_many, I: IntoIterator>(iter: I) -> Option { let mut cs = ConstraintSystemRef::None; let mut has_value = true; let mut value = F::zero(); @@ -179,14 +176,128 @@ impl AllocatedFp { new_lc = new_lc + variable.variable; num_iters += 1; } - assert_ne!(num_iters, 0); + if num_iters == 0 { + return None; // No elements to add + } + + let variable = cs.new_lc(new_lc).unwrap(); + + if has_value { + Some(AllocatedFp::new(Some(value), variable, cs)) + } else { + Some(AllocatedFp::new(None, variable, cs)) + } + } + + /// Computes the inner product of two iterators of `AllocatedFp` elements. + /// + /// + /// This does not create any constraints and only creates one linear + /// combination. + /// + /// # Panics + /// + /// Panics if the iterators are of different lengths. + pub fn linear_combination(this: I1, other: I2) -> Option + where + B1: Borrow, + B2: Borrow, + I1: IntoIterator, + I2: IntoIterator, + { + let mut cs = ConstraintSystemRef::None; + let mut has_value = true; + let mut value = F::zero(); + let mut new_lc = lc!(); + + let mut num_iters = 0; + for (coeff, variable) in zip_eq(this, other) { + let coeff = *coeff.borrow(); + let variable = variable.borrow(); + if !variable.cs.is_none() { + cs = cs.or(variable.cs.clone()); + } + if variable.value.is_none() { + has_value = false; + } else { + value += coeff * variable.value.unwrap(); + } + new_lc += (coeff, variable.variable); + num_iters += 1; + } + if num_iters == 0 { + return None; // No elements to add + } + + let variable = cs.new_lc(new_lc).unwrap(); + + if has_value { + Some(AllocatedFp::new(Some(value), variable, cs)) + } else { + Some(AllocatedFp::new(None, variable, cs)) + } + } + + /// Computes the inner product of two iterators of `AllocatedFp` elements. + /// + /// + /// This does not create any constraints and only creates one linear + /// combination. + /// + /// # Panics + /// + /// Panics if the iterators are of different lengths. + pub fn inner_product(this: I1, other: I2) -> Option + where + B1: Borrow, + B2: Borrow, + I1: IntoIterator, + I2: IntoIterator, + { + let mut cs = ConstraintSystemRef::None; + let mut has_value = true; + let mut value = F::zero(); + let mut new_lc = lc!(); + let mut num_iters = 0; + for (v1, v2) in zip_eq(this, other) { + let v1 = v1.borrow(); + let v2 = v2.borrow(); + cs = cs.or(v1.cs.clone()).or(v2.cs.clone()); + match (v1.value, v2.value) { + (Some(val1), Some(val2)) => value += val1 * val2, + (..) => has_value = false, + } + if v1.cs.is_none() && v2.cs.is_none() { + // both v1 and v2 should be constants + let v1 = v1.value?; + let v2 = v2.value?; + let product = v1 * v2; + new_lc += (product, Variable::One); + } + if v1.cs.is_none() { + // v1 should be a constant + let v1 = v1.value?; + new_lc += (v1, v2.variable); + } else if v2.cs.is_none() { + // v2 should be a constant + let v2 = v2.value?; + new_lc += (v2, v1.variable); + } else { + let product = v1.mul(v2); + new_lc += (F::ONE, product.variable); + } + num_iters += 1; + } + if num_iters == 0 { + return None; // No elements to compute the inner product + } let variable = cs.new_lc(new_lc).unwrap(); if has_value { - AllocatedFp::new(Some(value), variable, cs) + Some(AllocatedFp::new(Some(value), variable, cs)) } else { - AllocatedFp::new(None, variable, cs) + Some(AllocatedFp::new(None, variable, cs)) } } @@ -800,6 +911,49 @@ impl FieldVar for FpVar { } } + /// Computes the inner product of two slices of `FpVar`. + /// This is faster for the `ConstraintSystem` to process as it directly creates + /// the minimal number of linear combinations. + #[tracing::instrument(target = "gr1cs")] + fn inner_product(this: &[Self], other: &[Self]) -> Result { + if this.len() != other.len() { + return Err(SynthesisError::Unsatisfiable); + } + + let mut lc_vars = vec![]; + let mut lc_coeffs = vec![]; + let mut sum_constants = F::zero(); + // constants, linear_combinations, and variables separately + let (vars_left, vars_right): (Vec<_>, Vec<_>) = this + .iter() + .zip(other) + .filter_map(|(x, y)| match (x, y) { + (FpVar::Constant(x), FpVar::Constant(y)) => { + // If both are constants, we can sum them directly + sum_constants += *x * y; + None + }, + (FpVar::Constant(x), FpVar::Var(y)) | (FpVar::Var(y), FpVar::Constant(x)) => { + // If one is a constant, we can treat it as a linear combination + lc_vars.push(y); + lc_coeffs.push(*x); + None + }, + // If both are variables, we keep them for the inner product + (FpVar::Var(x), FpVar::Var(y)) => Some((x, y)), + }) + .unzip(); + let sum_constants = FpVar::Constant(sum_constants); + let sum_lc = AllocatedFp::linear_combination(lc_coeffs, lc_vars).map(FpVar::Var); + let sum_variables = AllocatedFp::inner_product(vars_left, vars_right).map(FpVar::Var); + + match (sum_lc, sum_variables) { + (Some(a), Some(b)) => Ok(a + b + sum_constants), + (Some(a), None) | (None, Some(a)) => Ok(a + sum_constants), + (None, None) => Ok(sum_constants), + } + } + #[tracing::instrument(target = "gr1cs")] fn frobenius_map(&self, power: usize) -> Result { match self { @@ -1107,8 +1261,9 @@ impl<'a, F: PrimeField> Sum<&'a FpVar> for FpVar { if variables.is_empty() { return FpVar::Constant(sum_constants); } - let sum_variables = FpVar::Var(AllocatedFp::::add_many(variables.into_iter())); - sum_variables + sum_constants + AllocatedFp::add_many(variables).map_or(FpVar::Constant(sum_constants), |sum_vars| { + FpVar::Var(sum_vars) + sum_constants + }) } } @@ -1128,17 +1283,19 @@ impl<'a, F: PrimeField> Sum> for FpVar { if variables.is_empty() { return FpVar::Constant(sum_constants); } - let sum_variables = FpVar::Var(AllocatedFp::::add_many(variables.into_iter())); - sum_variables + sum_constants + AllocatedFp::add_many(variables).map_or(FpVar::Constant(sum_constants), |sum_vars| { + FpVar::Var(sum_vars) + sum_constants + }) } } #[cfg(test)] mod test { use crate::{ - alloc::{AllocVar, AllocationMode}, + alloc::AllocVar, eq::EqGadget, - fields::fp::FpVar, + fields::{fp::FpVar, FieldVar}, + test_utils::{combination, modes}, GR1CSVar, }; use ark_relations::gr1cs::ConstraintSystem; @@ -1146,33 +1303,56 @@ mod test { use ark_test_curves::bls12_381::Fr; #[test] - fn test_sum_fpvar() { + fn test_inner_product() { let mut rng = ark_std::test_rng(); let cs = ConstraintSystem::new_ref(); - let mut sum_expected = Fr::zero(); + for (a_mode, b_mode) in combination(modes()) { + let a = (0..10) + .map(|_| FpVar::new_variable(cs.clone(), || Ok(Fr::rand(&mut rng)), a_mode).ok()) + .collect::>>() + .unwrap(); + let b = (0..10) + .map(|_| FpVar::new_variable(cs.clone(), || Ok(Fr::rand(&mut rng)), b_mode).ok()) + .collect::>>() + .unwrap(); + let a = [a, b].concat(); + let b = a.iter().rev().cloned().collect::>(); + let inner_product: FpVar = FpVar::inner_product(&a, &b).unwrap(); + let mut expected = Fr::zero(); + for (x, y) in a.iter().zip(b) { + expected += x.value().unwrap() * y.value().unwrap(); + } + inner_product + .enforce_equal(&FpVar::Constant(expected)) + .unwrap(); - let mut v = Vec::new(); - for _ in 0..10 { - let a = Fr::rand(&mut rng); - sum_expected += &a; - v.push( - FpVar::::new_variable(cs.clone(), || Ok(a), AllocationMode::Constant).unwrap(), - ); - } - for _ in 0..10 { - let a = Fr::rand(&mut rng); - sum_expected += &a; - v.push( - FpVar::::new_variable(cs.clone(), || Ok(a), AllocationMode::Witness).unwrap(), - ); + assert!(cs.is_satisfied().unwrap()); } + } - let sum: FpVar = v.iter().sum(); + #[test] + fn test_sum_fpvar() { + let mut rng = ark_std::test_rng(); + let cs = ConstraintSystem::new_ref(); - sum.enforce_equal(&FpVar::Constant(sum_expected)).unwrap(); + for (a_mode, b_mode) in combination(modes()) { + let a = (0..10) + .map(|_| FpVar::new_variable(cs.clone(), || Ok(Fr::rand(&mut rng)), a_mode).ok()) + .collect::>>() + .unwrap(); + let b = (0..10) + .map(|_| FpVar::new_variable(cs.clone(), || Ok(Fr::rand(&mut rng)), b_mode).ok()) + .collect::>>() + .unwrap(); + let v = [a, b].concat(); + let sum: FpVar = v.iter().sum(); - assert!(cs.is_satisfied().unwrap()); - assert_eq!(sum.value().unwrap(), sum_expected); + let sum_expected = v.iter().map(|x| x.value().unwrap()).sum(); + sum.enforce_equal(&FpVar::Constant(sum_expected)).unwrap(); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(sum.value().unwrap(), sum_expected); + } } } diff --git a/src/fields/mod.rs b/src/fields/mod.rs index 525830e..1cb0274 100644 --- a/src/fields/mod.rs +++ b/src/fields/mod.rs @@ -2,6 +2,7 @@ use ark_ff::{prelude::*, BitIteratorBE}; use ark_relations::gr1cs::{ConstraintSystemRef, SynthesisError}; use core::{ fmt::Debug, + iter::Sum, ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; @@ -90,6 +91,8 @@ pub trait FieldVar: + AddAssign + SubAssign + MulAssign + + Sum + + for<'a> Sum<&'a Self> + Debug { /// Returns the constant `F::zero()`. @@ -199,6 +202,14 @@ pub trait FieldVar: } } + /// Computes the inner product of `this` and `other`. + fn inner_product(this: &[Self], other: &[Self]) -> Result { + if this.len() != other.len() { + return Err(SynthesisError::Unsatisfiable); + } + Ok(this.iter().zip(other).map(|(a, b)| a.clone() * b).sum()) + } + /// Computes the frobenius map over `self`. fn frobenius_map(&self, power: usize) -> Result; diff --git a/src/fields/quadratic_extension.rs b/src/fields/quadratic_extension.rs index f29b7fb..4c5dd81 100644 --- a/src/fields/quadratic_extension.rs +++ b/src/fields/quadratic_extension.rs @@ -9,7 +9,7 @@ use ark_ff::{ Zero, }; use ark_relations::gr1cs::{ConstraintSystemRef, Namespace, SynthesisError}; -use core::{borrow::Borrow, marker::PhantomData}; +use core::{borrow::Borrow, iter::Sum, marker::PhantomData}; use educe::Educe; /// This struct is the `R1CS` equivalent of the quadratic extension field type @@ -558,3 +558,37 @@ where Ok(Self::new(c0, c1)) } } + +impl Sum for QuadExtVar +where + BF: FieldVar, + for<'a> &'a BF: FieldOpsBounds<'a, P::BaseField, BF>, + P: QuadExtVarConfig, +{ + #[inline] + #[tracing::instrument(target = "gr1cs", skip(iter))] + fn sum>(iter: I) -> Self { + let (c0s, c1s): (Vec<_>, Vec<_>) = itertools::multiunzip(iter.map(|x| (x.c0, x.c1))); + let c0 = c0s.into_iter().sum(); + let c1 = c1s.into_iter().sum(); + + Self::new(c0, c1) + } +} + +impl<'a, BF, P> Sum<&'a Self> for QuadExtVar +where + BF: FieldVar, + for<'b> &'b BF: FieldOpsBounds<'b, P::BaseField, BF>, + P: QuadExtVarConfig, +{ + #[inline] + #[tracing::instrument(target = "gr1cs", skip(iter))] + fn sum>(iter: I) -> Self { + let (c0s, c1s): (Vec<_>, Vec<_>) = itertools::multiunzip(iter.map(|x| (&x.c0, &x.c1))); + let c0 = c0s.into_iter().sum(); + let c1 = c1s.into_iter().sum(); + + Self::new(c0, c1) + } +}