diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b412830..be1c325d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Pending +* [\#93](https://github.com/arkworks-rs/r1cs-std/pull/93) Implement `EqGadget`, `CondSelectGadget`, and `R1CSVar` for `Vec` + ### Breaking changes - [\#86](https://github.com/arkworks-rs/r1cs-std/pull/86) Change the API for domains for coset. diff --git a/src/eq.rs b/src/eq.rs index f1184619..d104032d 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -128,3 +128,31 @@ impl + R1CSVar, F: Field> EqGadget for [T] { } } } + +// EqGadget for Vec just calls down to EqGadget for [T] +impl + R1CSVar, F: Field> EqGadget for Vec { + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn is_eq(&self, other: &Self) -> Result, SynthesisError> { + self.as_slice().is_eq(other.as_slice()) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_equal( + &self, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + self.as_slice() + .conditional_enforce_equal(other.as_slice(), condition) + } + + #[tracing::instrument(target = "r1cs", skip(self, other))] + fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> Result<(), SynthesisError> { + self.as_slice() + .conditional_enforce_not_equal(other.as_slice(), should_enforce) + } +} diff --git a/src/lib.rs b/src/lib.rs index 8ff44b0d..20cea7ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,19 @@ impl<'a, F: Field, T: 'a + R1CSVar> R1CSVar for &'a T { } } +// R1CSVar for Vec just calls down to R1CSVar for [T] +impl> R1CSVar for Vec { + type Value = Vec; + + fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef { + self.as_slice().cs() + } + + fn value(&self) -> Result { + self.as_slice().value() + } +} + /// A utility trait to convert `Self` to `Result pub trait Assignment { /// Converts `self` to `Result`. diff --git a/src/select.rs b/src/select.rs index bbc2c3c9..6fafcb9d 100644 --- a/src/select.rs +++ b/src/select.rs @@ -115,3 +115,24 @@ where constants: &[Self::TableConstant], ) -> Result; } + +impl CondSelectGadget for Vec +where + ConstraintF: Field, + T: CondSelectGadget, +{ + #[tracing::instrument(target = "r1cs", skip(true_value, false_value))] + fn conditionally_select( + cond: &Boolean, + true_value: &Vec, + false_value: &Vec, + ) -> Result, SynthesisError> { + assert_eq!(true_value.len(), false_value.len()); + + true_value + .iter() + .zip(false_value.iter()) + .map(|(t, f)| T::conditionally_select(cond, t, f)) + .collect() + } +}