Skip to content

Commit fc02b95

Browse files
committed
bool_t from witness constructor supports range constraints
1 parent 28211f5 commit fc02b95

File tree

8 files changed

+90
-51
lines changed

8 files changed

+90
-51
lines changed

barretenberg/cpp/src/barretenberg/stdlib/honk_verifier/ultra_recursive_verifier.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ template <typename RecursiveFlavor> class RecursiveVerifierTest : public testing
291291
}
292292
// Check the size of the recursive verifier
293293
if constexpr (std::same_as<RecursiveFlavor, MegaZKRecursiveFlavor_<UltraCircuitBuilder>>) {
294-
uint32_t NUM_GATES_EXPECTED = 797234;
294+
uint32_t NUM_GATES_EXPECTED = 797170;
295295
ASSERT_EQ(static_cast<uint32_t>(outer_circuit.get_num_finalized_gates()), NUM_GATES_EXPECTED)
296296
<< "MegaZKHonk Recursive verifier changed in Ultra gate count! Update this value if you "
297297
"are sure this is expected.";

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,13 @@ template <class Builder_, class Fq, class Fr, class NativeGroup> class element {
370370
}
371371

372372
bool_ct is_point_at_infinity() const { return _is_infinity; }
373-
void set_point_at_infinity(const bool_ct& is_infinity) { _is_infinity = is_infinity; }
373+
void set_point_at_infinity(const bool_ct& is_infinity, const bool& add_to_used_witnesses = false)
374+
{
375+
_is_infinity = is_infinity.normalize();
376+
if (add_to_used_witnesses) {
377+
_is_infinity.get_context()->update_used_witnesses(_is_infinity.get_normalized_witness_index());
378+
};
379+
}
374380
element get_standard_form() const;
375381

376382
void set_origin_tag(OriginTag tag) const

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_impl.hpp

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator+(const element& other) con
8989
const bool_ct double_predicate = (x_coordinates_match && y_coordinates_match);
9090
const bool_ct lhs_infinity = is_point_at_infinity();
9191
const bool_ct rhs_infinity = other.is_point_at_infinity();
92+
const bool_ct has_infinity_input = lhs_infinity || rhs_infinity;
9293

9394
// Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1
9495
const Fq add_lambda_numerator = other.y - y;
@@ -103,8 +104,8 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator+(const element& other) con
103104
// divide by zero error.
104105
// Note: if either inputs are points at infinity we will not use the result of this computation.
105106
Fq safe_edgecase_denominator = Fq(1);
106-
lambda_denominator = Fq::conditional_assign(
107-
lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator);
107+
lambda_denominator =
108+
Fq::conditional_assign(has_infinity_input || infinity_predicate, safe_edgecase_denominator, lambda_denominator);
108109
const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator);
109110

110111
const Fq x3 = lambda.sqradd({ -other.x, -x });
@@ -122,15 +123,8 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator+(const element& other) con
122123
// yes = infinity_predicate && !lhs_infinity && !rhs_infinity
123124
// yes = lhs_infinity && rhs_infinity
124125
// n.b. can likely optimize this
125-
bool_ct result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity);
126-
if constexpr (IsUltraBuilder<C>) {
127-
result_is_infinity.get_context()->update_used_witnesses(result_is_infinity.get_normalized_witness_index());
128-
}
129-
result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity);
130-
if constexpr (IsUltraBuilder<C>) {
131-
result_is_infinity.get_context()->update_used_witnesses(result_is_infinity.get_normalized_witness_index());
132-
}
133-
result.set_point_at_infinity(result_is_infinity);
126+
bool_ct result_is_infinity = (infinity_predicate && !has_infinity_input) || (lhs_infinity && rhs_infinity);
127+
result.set_point_at_infinity(result_is_infinity, /* add_to_used_witnesses */ true);
134128

135129
result.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag()));
136130
return result;
@@ -167,6 +161,7 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator-(const element& other) con
167161
const bool_ct double_predicate = (x_coordinates_match && !y_coordinates_match);
168162
const bool_ct lhs_infinity = is_point_at_infinity();
169163
const bool_ct rhs_infinity = other.is_point_at_infinity();
164+
const bool_ct has_infinity_input = lhs_infinity || rhs_infinity;
170165

171166
// Compute the gradient `lambda`. If we add, `lambda = (y2 - y1)/(x2 - x1)`, else `lambda = 3x1*x1/2y1
172167
const Fq add_lambda_numerator = -other.y - y;
@@ -181,8 +176,8 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator-(const element& other) con
181176
// divide by zero error.
182177
// (if either inputs are points at infinity we will not use the result of this computation)
183178
Fq safe_edgecase_denominator = Fq(1);
184-
lambda_denominator = Fq::conditional_assign(
185-
lhs_infinity || rhs_infinity || infinity_predicate, safe_edgecase_denominator, lambda_denominator);
179+
lambda_denominator =
180+
Fq::conditional_assign(has_infinity_input || infinity_predicate, safe_edgecase_denominator, lambda_denominator);
186181
const Fq lambda = Fq::div_without_denominator_check({ lambda_numerator }, lambda_denominator);
187182

188183
const Fq x3 = lambda.sqradd({ -other.x, -x });
@@ -200,15 +195,9 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::operator-(const element& other) con
200195
// yes = infinity_predicate && !lhs_infinity && !rhs_infinity
201196
// yes = lhs_infinity && rhs_infinity
202197
// n.b. can likely optimize this
203-
bool_ct result_is_infinity = infinity_predicate && (!lhs_infinity && !rhs_infinity);
204-
if constexpr (IsUltraBuilder<C>) {
205-
result_is_infinity.get_context()->update_used_witnesses(result_is_infinity.get_normalized_witness_index());
206-
}
207-
result_is_infinity = result_is_infinity || (lhs_infinity && rhs_infinity);
208-
if constexpr (IsUltraBuilder<C>) {
209-
result_is_infinity.get_context()->update_used_witnesses(result_is_infinity.get_normalized_witness_index());
210-
}
211-
result.set_point_at_infinity(result_is_infinity);
198+
bool_ct result_is_infinity = (infinity_predicate && !has_infinity_input) || (lhs_infinity && rhs_infinity);
199+
200+
result.set_point_at_infinity(result_is_infinity, /* add_to_used_witnesses */ true);
212201
result.set_origin_tag(OriginTag(get_origin_tag(), other.get_origin_tag()));
213202
return result;
214203
}

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_nafs.hpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,8 @@ std::vector<bool_t<C>> element<C, Fq, Fr, G>::compute_naf(const Fr& scalar, cons
480480
{
481481
// We are not handling the case of odd bit lengths here.
482482
BB_ASSERT_EQ(max_num_bits % 2, 0U);
483+
// Apply range constraint gates instead of arithmetic gates to constrain boolean witnesses
484+
static constexpr bool use_bool_range_constraint = true;
483485

484486
C* ctx = scalar.context;
485487
uint512_t scalar_multiplier_512 = uint512_t(uint256_t(scalar.get_value()) % Fr::modulus);
@@ -495,22 +497,18 @@ std::vector<bool_t<C>> element<C, Fq, Fr, G>::compute_naf(const Fr& scalar, cons
495497
// if boolean is false => do NOT flip y
496498
// if boolean is true => DO flip y
497499
// first entry is skew. i.e. do we subtract one from the final result or not
498-
if (scalar_multiplier.get_bit(0) == false) {
499-
// add skew
500-
naf_entries[num_rounds] = bool_ct(witness_t(ctx, true));
501-
scalar_multiplier += uint256_t(1);
502-
} else {
503-
naf_entries[num_rounds] = bool_ct(witness_t(ctx, false));
504-
}
500+
naf_entries[num_rounds] = bool_ct(witness_t(ctx, !scalar_multiplier.get_bit(0)), use_bool_range_constraint);
501+
scalar_multiplier += uint256_t(!scalar_multiplier.get_bit(0));
502+
505503
// We need to manually propagate the origin tag
506504
naf_entries[num_rounds].set_origin_tag(scalar.get_origin_tag());
507505

508506
for (size_t i = 0; i < num_rounds - 1; ++i) {
509507
bool next_entry = scalar_multiplier.get_bit(i + 1);
510508
// if the next entry is false, we need to flip the sign of the current entry. i.e. make negative
511-
// This is a VERY hacky workaround to ensure that UltraBuilder will apply a basic
512-
// range constraint per bool, and not a full 1-bit range gate
513-
bool_ct bit(witness_t<C>(ctx, !next_entry));
509+
// Apply a basic range constraint per bool, and not a full 1-bit range gate. Results in ~`num_rounds`/4 gates
510+
// per scalar.
511+
bool_ct bit(witness_t<C>(ctx, !next_entry), use_bool_range_constraint);
514512

515513
naf_entries[num_rounds - i - 1] = bit;
516514

barretenberg/cpp/src/barretenberg/stdlib/primitives/bool/bool.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,25 @@ bool_t<Builder>::bool_t(Builder* parent_context)
3232
/**
3333
* @brief Construct a `bool_t` object from a witness, note that the value stored at `witness_index` is constrained to be
3434
* 0 or 1.
35+
* @param value A witness, which is constrained to be boolean inside of this constructor.
36+
* @param use_range_constraint In case we need to create `bool_t` in a loop, it is more efficient to apply the range
37+
* constraint gates instead of creating arithmetic gates.
3538
*/
3639
template <typename Builder>
37-
bool_t<Builder>::bool_t(const witness_t<Builder>& value)
40+
bool_t<Builder>::bool_t(const witness_t<Builder>& value, const bool& use_range_constraint)
3841
: context(value.context)
3942
{
4043
ASSERT((value.witness == bb::fr::zero()) || (value.witness == bb::fr::one()),
4144
"bool_t: witness value is not 0 or 1");
4245
witness_index = value.witness_index;
43-
// Constrain x := other.witness by the relation x^2 = x
44-
context->create_bool_gate(witness_index);
46+
47+
if (use_range_constraint) {
48+
// Create a range constraint gate
49+
context->create_new_range_constraint(witness_index, 3, "bool_t: witness value is not 0 or 1");
50+
} else {
51+
// Create an arithmetic gate to enforce the relation x^2 = x
52+
context->create_bool_gate(witness_index);
53+
}
4554
witness_bool = (value.witness == bb::fr::one());
4655
witness_inverted = false;
4756
set_free_witness_tag();
@@ -135,7 +144,7 @@ template <typename Builder> bool_t<Builder>& bool_t<Builder>::operator=(bool_t&&
135144
return *this;
136145
}
137146
/**
138-
* @brief Assigns a `witness_t` to a `bool_t`. As above, he value stored at `witness_index` is constrained to be
147+
* @brief Assigns a `witness_t` to a `bool_t`. As above, the value stored at `witness_index` is constrained to be
139148
* 0 or 1.
140149
*/
141150
template <typename Builder> bool_t<Builder>& bool_t<Builder>::operator=(const witness_t<Builder>& other)

barretenberg/cpp/src/barretenberg/stdlib/primitives/bool/bool.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ template <typename Builder> class bool_t {
6161
bool_t(const bool value = false);
6262
bool_t(Builder* parent_context);
6363
bool_t(Builder* parent_context, const bool value);
64-
bool_t(const witness_t<Builder>& value);
64+
bool_t(const witness_t<Builder>& value, const bool& use_range_constraint = false);
6565
bool_t(const bool_t& other);
6666
bool_t(bool_t&& other);
6767

barretenberg/cpp/src/barretenberg/stdlib/primitives/bool/bool.test.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,39 @@ template <class Builder_> class BoolTest : public ::testing::Test {
151151
"((other.witness == bb::fr::one()) || (other.witness == bb::fr::zero()))");
152152
};
153153
}
154+
155+
void test_construct_from_witness_range_constraint()
156+
{
157+
const bool use_range_constraint = true;
158+
159+
for (size_t num_inputs = 1; num_inputs < 50; num_inputs++) {
160+
Builder builder = Builder();
161+
size_t num_gates_start = builder.get_estimated_num_finalized_gates();
162+
163+
std::vector<uint32_t> indices;
164+
for (size_t idx = 0; idx < num_inputs; idx++) {
165+
indices.push_back(
166+
bool_ct(witness_ct(&builder, idx % 2), use_range_constraint).get_normalized_witness_index());
167+
}
168+
169+
const size_t sorted_list_size = num_inputs + 2;
170+
171+
// Pin down the gate numbers. The point is that it is more efficient to use this constructor to constrain a
172+
// batch of bool_t elements.
173+
size_t expected = (num_inputs == 1) ? 4 : (sorted_list_size / 4) + 3;
174+
175+
EXPECT_EQ(builder.get_estimated_num_finalized_gates() - num_gates_start, expected);
176+
177+
builder.create_dummy_constraints(indices);
178+
179+
EXPECT_TRUE(CircuitChecker::check(builder));
180+
}
181+
182+
// Failure test
183+
Builder builder = Builder();
184+
EXPECT_THROW_OR_ABORT(auto new_bool = bool_ct(witness_ct(&builder, 2), use_range_constraint),
185+
"bool_t: witness value is not 0 or 1");
186+
}
154187
void test_AND()
155188
{
156189
test_binary_op(
@@ -451,6 +484,10 @@ TYPED_TEST(BoolTest, ConstructFromWitness)
451484
{
452485
TestFixture::test_construct_from_witness();
453486
}
487+
TYPED_TEST(BoolTest, ConstructFromWitnessRangeConstraint)
488+
{
489+
TestFixture::test_construct_from_witness_range_constraint();
490+
}
454491

455492
TYPED_TEST(BoolTest, Normalization)
456493
{

barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/ultra_circuit_builder.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,17 +1131,17 @@ template <typename ExecutionTrace> void UltraCircuitBuilder_<ExecutionTrace>::pr
11311131
}
11321132

11331133
/*
1134-
Create range constraint:
1135-
* add variable index to a list of range constrained variables
1136-
* data structures: vector of lists, each list contains:
1137-
* - the range size
1138-
* - the list of variables in the range
1139-
* - a generalized permutation tag
1140-
*
1141-
* create range constraint parameters: variable index && range size
1142-
*
1143-
* std::map<uint64_t, RangeList> range_lists;
1144-
*/
1134+
* Create range constraint:
1135+
* add variable index to a list of range constrained variables
1136+
* data structures: vector of lists, each list contains:
1137+
* - the range size
1138+
* - the list of variables in the range
1139+
* - a generalized permutation tag
1140+
*
1141+
* create range constraint parameters: variable index && range size
1142+
*
1143+
* std::map<uint64_t, RangeList> range_lists;
1144+
*/
11451145
// Check for a sequence of variables that neighboring differences are at most 3 (used for batched range checkj)
11461146
template <typename ExecutionTrace>
11471147
void UltraCircuitBuilder_<ExecutionTrace>::create_sort_constraint(const std::vector<uint32_t>& variable_index)
@@ -1180,7 +1180,7 @@ void UltraCircuitBuilder_<ExecutionTrace>::create_sort_constraint(const std::vec
11801180
}
11811181

11821182
// useful to put variables in the witness that aren't already used - e.g. the dummy variables of the range constraint in
1183-
// multiples of three
1183+
// multiples of four
11841184
template <typename ExecutionTrace>
11851185
void UltraCircuitBuilder_<ExecutionTrace>::create_dummy_constraints(const std::vector<uint32_t>& variable_index)
11861186
{

0 commit comments

Comments
 (0)