Skip to content

Commit 4020fbc

Browse files
authored
Add convenience impls for common types (#137)
* Work * Tweak * Fmt * Work * Format + typo fixes * `no-std` fix
1 parent a124995 commit 4020fbc

File tree

13 files changed

+341
-121
lines changed

13 files changed

+341
-121
lines changed

src/alloc.rs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ impl AllocationMode {
3737

3838
/// Specifies how variables of type `Self` should be allocated in a
3939
/// `ConstraintSystem`.
40-
pub trait AllocVar<V, F: Field>
41-
where
42-
Self: Sized,
43-
V: ?Sized,
44-
{
40+
pub trait AllocVar<V: ?Sized, F: Field>: Sized {
4541
/// Allocates a new variable of type `Self` in the `ConstraintSystem` `cs`.
4642
/// The mode of allocation is decided by `mode`.
4743
fn new_variable<T: Borrow<V>>(
@@ -92,10 +88,56 @@ impl<I, F: Field, A: AllocVar<I, F>> AllocVar<[I], F> for Vec<A> {
9288
) -> Result<Self, SynthesisError> {
9389
let ns = cs.into();
9490
let cs = ns.cs();
95-
let mut vec = Vec::new();
96-
for value in f()?.borrow().iter() {
97-
vec.push(A::new_variable(cs.clone(), || Ok(value), mode)?);
98-
}
99-
Ok(vec)
91+
f().and_then(|v| {
92+
v.borrow()
93+
.iter()
94+
.map(|e| A::new_variable(cs.clone(), || Ok(e), mode))
95+
.collect()
96+
})
97+
}
98+
}
99+
100+
/// Dummy impl for `()`.
101+
impl<F: Field> AllocVar<(), F> for () {
102+
fn new_variable<T: Borrow<()>>(
103+
_cs: impl Into<Namespace<F>>,
104+
_f: impl FnOnce() -> Result<T, SynthesisError>,
105+
_mode: AllocationMode,
106+
) -> Result<Self, SynthesisError> {
107+
Ok(())
108+
}
109+
}
110+
111+
/// This blanket implementation just allocates variables in `Self`
112+
/// element by element.
113+
impl<I, F: Field, A: AllocVar<I, F>, const N: usize> AllocVar<[I; N], F> for [A; N] {
114+
fn new_variable<T: Borrow<[I; N]>>(
115+
cs: impl Into<Namespace<F>>,
116+
f: impl FnOnce() -> Result<T, SynthesisError>,
117+
mode: AllocationMode,
118+
) -> Result<Self, SynthesisError> {
119+
let ns = cs.into();
120+
let cs = ns.cs();
121+
f().map(|v| {
122+
let v = v.borrow();
123+
core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap())
124+
})
125+
}
126+
}
127+
128+
/// This blanket implementation just allocates variables in `Self`
129+
/// element by element.
130+
impl<I, F: Field, A: AllocVar<I, F>, const N: usize> AllocVar<[I], F> for [A; N] {
131+
fn new_variable<T: Borrow<[I]>>(
132+
cs: impl Into<Namespace<F>>,
133+
f: impl FnOnce() -> Result<T, SynthesisError>,
134+
mode: AllocationMode,
135+
) -> Result<Self, SynthesisError> {
136+
let ns = cs.into();
137+
let cs = ns.cs();
138+
f().map(|v| {
139+
let v = v.borrow();
140+
core::array::from_fn(|i| A::new_variable(cs.clone(), || Ok(&v[i]), mode).unwrap())
141+
})
100142
}
101143
}

src/boolean/cmp.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl<F: PrimeField> Boolean<F> {
4848
let mut bits_iter = bits.iter().rev(); // Iterate in big-endian
4949

5050
// Runs of ones in r
51-
let mut last_run = Boolean::constant(true);
51+
let mut last_run = Boolean::TRUE;
5252
let mut current_run = vec![];
5353

5454
let mut element_num_bits = 0;
@@ -57,12 +57,12 @@ impl<F: PrimeField> Boolean<F> {
5757
}
5858

5959
if bits.len() > element_num_bits {
60-
let mut or_result = Boolean::constant(false);
60+
let mut or_result = Boolean::FALSE;
6161
for should_be_zero in &bits[element_num_bits..] {
6262
or_result |= should_be_zero;
6363
let _ = bits_iter.next().unwrap();
6464
}
65-
or_result.enforce_equal(&Boolean::constant(false))?;
65+
or_result.enforce_equal(&Boolean::FALSE)?;
6666
}
6767

6868
for (b, a) in BitIteratorBE::without_leading_zeros(b).zip(bits_iter.by_ref()) {

src/boolean/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl<F: Field> Boolean<F> {
100100
/// let true_var = Boolean::<Fr>::TRUE;
101101
/// let false_var = Boolean::<Fr>::FALSE;
102102
///
103-
/// true_var.enforce_equal(&Boolean::constant(true))?;
103+
/// true_var.enforce_equal(&Boolean::TRUE)?;
104104
/// false_var.enforce_equal(&Boolean::constant(false))?;
105105
/// # Ok(())
106106
/// # }

src/cmp.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use ark_ff::Field;
1+
use ark_ff::{Field, PrimeField};
22
use ark_relations::r1cs::SynthesisError;
33

4-
use crate::{boolean::Boolean, R1CSVar};
4+
use crate::{boolean::Boolean, eq::EqGadget, R1CSVar};
55

66
/// Specifies how to generate constraints for comparing two variables.
7-
pub trait CmpGadget<F: Field>: R1CSVar<F> {
7+
pub trait CmpGadget<F: Field>: R1CSVar<F> + EqGadget<F> {
88
fn is_gt(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
99
other.is_lt(self)
1010
}
@@ -19,3 +19,35 @@ pub trait CmpGadget<F: Field>: R1CSVar<F> {
1919
other.is_ge(self)
2020
}
2121
}
22+
23+
/// Mimics the behavior of `std::cmp::PartialOrd` for `()`.
24+
impl<F: Field> CmpGadget<F> for () {
25+
fn is_gt(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
26+
Ok(Boolean::FALSE)
27+
}
28+
29+
fn is_ge(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
30+
Ok(Boolean::TRUE)
31+
}
32+
33+
fn is_lt(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
34+
Ok(Boolean::FALSE)
35+
}
36+
37+
fn is_le(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
38+
Ok(Boolean::TRUE)
39+
}
40+
}
41+
42+
/// Mimics the lexicographic comparison behavior of `std::cmp::PartialOrd` for `[T]`.
43+
impl<T: CmpGadget<F>, F: PrimeField> CmpGadget<F> for [T] {
44+
fn is_ge(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
45+
let mut result = Boolean::TRUE;
46+
let mut all_equal_so_far = Boolean::TRUE;
47+
for (a, b) in self.iter().zip(other) {
48+
all_equal_so_far &= a.is_eq(b)?;
49+
result &= a.is_gt(b)? | &all_equal_so_far;
50+
}
51+
Ok(result)
52+
}
53+
}

src/convert.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,60 @@ impl<'a, F: Field, T: 'a + ToBytesGadget<F>> ToBytesGadget<F> for &'a T {
9090
}
9191
}
9292

93+
impl<T: ToBytesGadget<F>, F: Field> ToBytesGadget<F> for [T] {
94+
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
95+
let mut bytes = Vec::new();
96+
for elem in self {
97+
let elem = elem.to_bytes_le()?;
98+
bytes.extend_from_slice(&elem);
99+
// Make sure that there's enough capacity to avoid reallocations.
100+
bytes.reserve(elem.len() * (self.len() - 1));
101+
}
102+
Ok(bytes)
103+
}
104+
105+
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
106+
let mut bytes = Vec::new();
107+
for elem in self {
108+
let elem = elem.to_non_unique_bytes_le()?;
109+
bytes.extend_from_slice(&elem);
110+
// Make sure that there's enough capacity to avoid reallocations.
111+
bytes.reserve(elem.len() * (self.len() - 1));
112+
}
113+
Ok(bytes)
114+
}
115+
}
116+
117+
impl<T: ToBytesGadget<F>, F: Field> ToBytesGadget<F> for Vec<T> {
118+
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
119+
self.as_slice().to_bytes_le()
120+
}
121+
122+
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
123+
self.as_slice().to_non_unique_bytes_le()
124+
}
125+
}
126+
127+
impl<T: ToBytesGadget<F>, F: Field, const N: usize> ToBytesGadget<F> for [T; N] {
128+
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
129+
self.as_slice().to_bytes_le()
130+
}
131+
132+
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
133+
self.as_slice().to_non_unique_bytes_le()
134+
}
135+
}
136+
137+
impl<F: Field> ToBytesGadget<F> for () {
138+
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
139+
Ok(Vec::new())
140+
}
141+
142+
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
143+
Ok(Vec::new())
144+
}
145+
}
146+
93147
/// Specifies how to convert a variable of type `Self` to variables of
94148
/// type `FpVar<ConstraintF>`
95149
pub trait ToConstraintFieldGadget<ConstraintF: ark_ff::PrimeField> {

src/eq.rs

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub trait EqGadget<F: Field> {
3333
should_enforce: &Boolean<F>,
3434
) -> Result<(), SynthesisError> {
3535
self.is_eq(&other)?
36-
.conditional_enforce_equal(&Boolean::constant(true), should_enforce)
36+
.conditional_enforce_equal(&Boolean::TRUE, should_enforce)
3737
}
3838

3939
/// Enforce that `self` and `other` are equal.
@@ -46,7 +46,7 @@ pub trait EqGadget<F: Field> {
4646
/// are encouraged to carefully analyze the efficiency and safety of these.
4747
#[tracing::instrument(target = "r1cs", skip(self, other))]
4848
fn enforce_equal(&self, other: &Self) -> Result<(), SynthesisError> {
49-
self.conditional_enforce_equal(other, &Boolean::constant(true))
49+
self.conditional_enforce_equal(other, &Boolean::TRUE)
5050
}
5151

5252
/// If `should_enforce == true`, enforce that `self` and `other` are *not*
@@ -65,7 +65,7 @@ pub trait EqGadget<F: Field> {
6565
should_enforce: &Boolean<F>,
6666
) -> Result<(), SynthesisError> {
6767
self.is_neq(&other)?
68-
.conditional_enforce_equal(&Boolean::constant(true), should_enforce)
68+
.conditional_enforce_equal(&Boolean::TRUE, should_enforce)
6969
}
7070

7171
/// Enforce that `self` and `other` are *not* equal.
@@ -78,20 +78,23 @@ pub trait EqGadget<F: Field> {
7878
/// are encouraged to carefully analyze the efficiency and safety of these.
7979
#[tracing::instrument(target = "r1cs", skip(self, other))]
8080
fn enforce_not_equal(&self, other: &Self) -> Result<(), SynthesisError> {
81-
self.conditional_enforce_not_equal(other, &Boolean::constant(true))
81+
self.conditional_enforce_not_equal(other, &Boolean::TRUE)
8282
}
8383
}
8484

8585
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
8686
#[tracing::instrument(target = "r1cs", skip(self, other))]
8787
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
8888
assert_eq!(self.len(), other.len());
89-
assert!(!self.is_empty());
90-
let mut results = Vec::with_capacity(self.len());
91-
for (a, b) in self.iter().zip(other) {
92-
results.push(a.is_eq(b)?);
89+
if self.is_empty() & other.is_empty() {
90+
Ok(Boolean::TRUE)
91+
} else {
92+
let mut results = Vec::with_capacity(self.len());
93+
for (a, b) in self.iter().zip(other) {
94+
results.push(a.is_eq(b)?);
95+
}
96+
Boolean::kary_and(&results)
9397
}
94-
Boolean::kary_and(&results)
9598
}
9699

97100
#[tracing::instrument(target = "r1cs", skip(self, other))]
@@ -128,3 +131,91 @@ impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for [T] {
128131
}
129132
}
130133
}
134+
135+
/// This blanket implementation just forwards to the impl on [`[T]`].
136+
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField> EqGadget<F> for Vec<T> {
137+
#[tracing::instrument(target = "r1cs", skip(self, other))]
138+
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
139+
self.as_slice().is_eq(other.as_slice())
140+
}
141+
142+
#[tracing::instrument(target = "r1cs", skip(self, other))]
143+
fn conditional_enforce_equal(
144+
&self,
145+
other: &Self,
146+
condition: &Boolean<F>,
147+
) -> Result<(), SynthesisError> {
148+
self.as_slice()
149+
.conditional_enforce_equal(other.as_slice(), condition)
150+
}
151+
152+
#[tracing::instrument(target = "r1cs", skip(self, other))]
153+
fn conditional_enforce_not_equal(
154+
&self,
155+
other: &Self,
156+
should_enforce: &Boolean<F>,
157+
) -> Result<(), SynthesisError> {
158+
self.as_slice()
159+
.conditional_enforce_not_equal(other.as_slice(), should_enforce)
160+
}
161+
}
162+
163+
/// Dummy impl for `()`.
164+
impl<F: Field> EqGadget<F> for () {
165+
/// Output a `Boolean` value representing whether `self.value() ==
166+
/// other.value()`.
167+
#[inline]
168+
fn is_eq(&self, _other: &Self) -> Result<Boolean<F>, SynthesisError> {
169+
Ok(Boolean::TRUE)
170+
}
171+
172+
/// If `should_enforce == true`, enforce that `self` and `other` are equal;
173+
/// else, enforce a vacuously true statement.
174+
///
175+
/// This is a no-op as `self.is_eq(other)?` is always `true`.
176+
#[tracing::instrument(target = "r1cs", skip(self, _other))]
177+
fn conditional_enforce_equal(
178+
&self,
179+
_other: &Self,
180+
_should_enforce: &Boolean<F>,
181+
) -> Result<(), SynthesisError> {
182+
Ok(())
183+
}
184+
185+
/// Enforce that `self` and `other` are equal.
186+
///
187+
/// This does not generate any constraints as `self.is_eq(other)?` is always
188+
/// `true`.
189+
#[tracing::instrument(target = "r1cs", skip(self, _other))]
190+
fn enforce_equal(&self, _other: &Self) -> Result<(), SynthesisError> {
191+
Ok(())
192+
}
193+
}
194+
195+
/// This blanket implementation just forwards to the impl on [`[T]`].
196+
impl<T: EqGadget<F> + R1CSVar<F>, F: PrimeField, const N: usize> EqGadget<F> for [T; N] {
197+
#[tracing::instrument(target = "r1cs", skip(self, other))]
198+
fn is_eq(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
199+
self.as_slice().is_eq(other.as_slice())
200+
}
201+
202+
#[tracing::instrument(target = "r1cs", skip(self, other))]
203+
fn conditional_enforce_equal(
204+
&self,
205+
other: &Self,
206+
condition: &Boolean<F>,
207+
) -> Result<(), SynthesisError> {
208+
self.as_slice()
209+
.conditional_enforce_equal(other.as_slice(), condition)
210+
}
211+
212+
#[tracing::instrument(target = "r1cs", skip(self, other))]
213+
fn conditional_enforce_not_equal(
214+
&self,
215+
other: &Self,
216+
should_enforce: &Boolean<F>,
217+
) -> Result<(), SynthesisError> {
218+
self.as_slice()
219+
.conditional_enforce_not_equal(other.as_slice(), should_enforce)
220+
}
221+
}

src/fields/emulated_fp/allocated_field_var.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ impl<TargetF: PrimeField, BaseF: PrimeField> ToBytesGadget<BaseF>
686686

687687
let num_bits = TargetF::BigInt::NUM_LIMBS * 64;
688688
assert!(bits.len() <= num_bits);
689-
bits.resize_with(num_bits, || Boolean::constant(false));
689+
bits.resize_with(num_bits, || Boolean::FALSE);
690690

691691
let bytes = bits.chunks(8).map(UInt8::from_bits_le).collect();
692692
Ok(bytes)

src/fields/fp/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
555555
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
556556
let num_bits = F::BigInt::NUM_LIMBS * 64;
557557
let mut bits = self.to_bits_le()?;
558-
let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len());
558+
let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
559559
bits.extend(remainder);
560560
let bytes = bits
561561
.chunks(8)
@@ -568,7 +568,7 @@ impl<F: PrimeField> ToBytesGadget<F> for AllocatedFp<F> {
568568
fn to_non_unique_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
569569
let num_bits = F::BigInt::NUM_LIMBS * 64;
570570
let mut bits = self.to_non_unique_bits_le()?;
571-
let remainder = core::iter::repeat(Boolean::constant(false)).take(num_bits - bits.len());
571+
let remainder = core::iter::repeat(Boolean::FALSE).take(num_bits - bits.len());
572572
bits.extend(remainder);
573573
let bytes = bits
574574
.chunks(8)

0 commit comments

Comments
 (0)