Skip to content

Commit 1c999cb

Browse files
fix ecdsa pai issue, add conditional_select fn and tests.
Co-authored-by: federicobarbacovi <[email protected]>
1 parent 5ab003f commit 1c999cb

File tree

5 files changed

+148
-2
lines changed

5 files changed

+148
-2
lines changed

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,25 @@ template <class Builder_, class Fq, class Fr, class NativeGroup> class element {
218218
return result;
219219
}
220220

221+
element conditional_select(const element& other, const bool_ct& predicate) const
222+
{
223+
// If predicate is constant, we can select out of circuit
224+
if (predicate.is_constant()) {
225+
return predicate.get_value() ? other : *this;
226+
}
227+
228+
// Get the builder context
229+
Builder* ctx = get_context(other) ? get_context(other) : predicate.get_context();
230+
BB_ASSERT_NEQ(ctx, nullptr, "biggroup::conditional_select must have a context");
231+
232+
element result(*this);
233+
result.x = result.x.conditional_select(other.x, predicate);
234+
result.y = result.y.conditional_select(other.y, predicate);
235+
result._is_infinity =
236+
bool_ct::conditional_assign(predicate, other.is_point_at_infinity(), result.is_point_at_infinity());
237+
return result;
238+
}
239+
221240
element normalize() const
222241
{
223242
element result(*this);

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup.test.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,57 @@ template <typename TestType> class stdlib_biggroup : public testing::Test {
343343
EXPECT_CIRCUIT_CORRECTNESS(builder);
344344
}
345345

346+
static void test_conditional_negate()
347+
{
348+
Builder builder;
349+
size_t num_repetitions = 10;
350+
for (size_t i = 0; i < num_repetitions; ++i) {
351+
affine_element input_a(element::random_element());
352+
bool negate = (engine.get_random_uint32() % 2) == 1;
353+
bool_ct negate_ct = bool_ct(witness_ct(&builder, negate ? 1 : 0));
354+
element_ct a = element_ct::from_witness(&builder, input_a);
355+
a.set_origin_tag(submitted_value_origin_tag);
356+
357+
element_ct c = a.conditional_negate(negate_ct);
358+
359+
// Check the resulting tag is preserved
360+
EXPECT_EQ(c.get_origin_tag(), submitted_value_origin_tag);
361+
362+
affine_element c_expected = negate ? affine_element(-element(input_a)) : input_a;
363+
EXPECT_EQ(c.get_value(), c_expected);
364+
}
365+
366+
EXPECT_CIRCUIT_CORRECTNESS(builder);
367+
}
368+
369+
static void test_conditional_select()
370+
{
371+
Builder builder;
372+
size_t num_repetitions = 10;
373+
for (size_t i = 0; i < num_repetitions; ++i) {
374+
affine_element input_a(element::random_element());
375+
affine_element input_b(element::random_element());
376+
bool select_a = (engine.get_random_uint32() % 2) == 1;
377+
bool_ct select_a_ct = bool_ct(witness_ct(&builder, select_a ? 1 : 0));
378+
element_ct a = element_ct::from_witness(&builder, input_a);
379+
element_ct b = element_ct::from_witness(&builder, input_b);
380+
381+
// Set different tags in a and b
382+
a.set_origin_tag(submitted_value_origin_tag);
383+
b.set_origin_tag(challenge_origin_tag);
384+
385+
element_ct c = a.conditional_select(b, select_a_ct);
386+
387+
// Check that the resulting tag is the union of inputs' tags
388+
EXPECT_EQ(c.get_origin_tag(), first_two_merged_tag);
389+
390+
affine_element c_expected = select_a ? input_b : input_a;
391+
EXPECT_EQ(c.get_value(), c_expected);
392+
}
393+
394+
EXPECT_CIRCUIT_CORRECTNESS(builder);
395+
}
396+
346397
static void test_montgomery_ladder()
347398
{
348399
Builder builder;
@@ -1676,6 +1727,14 @@ TYPED_TEST(stdlib_biggroup, dbl)
16761727
{
16771728
TestFixture::test_dbl();
16781729
}
1730+
TYPED_TEST(stdlib_biggroup, conditional_negate)
1731+
{
1732+
TestFixture::test_conditional_negate();
1733+
}
1734+
TYPED_TEST(stdlib_biggroup, conditional_select)
1735+
{
1736+
TestFixture::test_conditional_select();
1737+
}
16791738
TYPED_TEST(stdlib_biggroup, montgomery_ladder)
16801739
{
16811740
if constexpr (HasGoblinBuilder<TypeParam>) {

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_goblin.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,16 @@ template <class Builder_, class Fq, class Fr, class NativeGroup> class goblin_el
235235
return result;
236236
}
237237

238+
goblin_element conditional_select(const goblin_element& other, const bool_ct& predicate) const
239+
{
240+
goblin_element result(*this);
241+
result.x = Fq::conditional_assign(predicate, other.x, result.x);
242+
result.y = Fq::conditional_assign(predicate, other.y, result.y);
243+
result._is_infinity =
244+
bool_ct::conditional_assign(predicate, other.is_point_at_infinity(), result.is_point_at_infinity());
245+
return result;
246+
}
247+
238248
goblin_element normalize() const
239249
{
240250
// no need to normalize, all goblin eccvm operations are returned normalized

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,7 @@ element<C, Fq, Fr, G> element<C, Fq, Fr, G>::secp256k1_ecdsa_mul(const element&
130130

131131
// when computing the wNAF we have already validated that positive_skew and negative_skew cannot both be true
132132
bool_ct skew_combined = positive_skew_bool ^ negative_skew_bool;
133-
result.x = accumulator.x.conditional_select(result.x, skew_combined);
134-
result.y = accumulator.y.conditional_select(result.y, skew_combined);
133+
result = accumulator.conditional_select(result, skew_combined);
135134
return result;
136135
};
137136

barretenberg/cpp/src/barretenberg/stdlib/primitives/biggroup/biggroup_secp256k1.test.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,61 @@ template <typename Curve> class stdlibBiggroupSecp256k1 : public testing::Test {
111111

112112
EXPECT_CIRCUIT_CORRECTNESS(builder);
113113
}
114+
115+
static void test_secp256k1_ecdsa_mul_skew_handling_regression()
116+
{
117+
// The scalars s1, u1, u2 are chosen such that:
118+
// Public key: P = (s1 * G)
119+
//
120+
// u1 * G + u2 * (s1 * G) = ø
121+
//
122+
// where ø is the point at infinity.
123+
//
124+
// The issue with such input was that we were not setting the point at infinity correctly
125+
// while adding the skew points. For the cases when we reach the point at infinity and still have
126+
// skew to add, we did not correctly set the flag _is_point_at_infinity. For this example, we have:
127+
//
128+
// u1_low skew: 0
129+
// u1_high skew: 1
130+
// u2_low skew: 1
131+
// u2_high skew: 0
132+
//
133+
// After adding the u2_low skew (i.e., its base point), we get the point at infinity. Then we handle the
134+
// u2 high skew as follows:
135+
// result = acc ± u1_high_base_point
136+
// result.x = u2_high_skew ? result.x : acc.x;
137+
// result.y = u2_high_skew ? result.y : acc.y;
138+
//
139+
// However, we did not set the flag _is_point_at_infinity for result. We must copy the flag from the
140+
// accumulator in this case, i.e., we must do:
141+
// result.x = u2_high_skew ? result.x : acc.x;
142+
// result.y = u2_high_skew ? result.y : acc.y;
143+
// result._is_point_at_infinity = u2_high_skew ? result._is_point_at_infinity : acc._is_point_at_infinity;
144+
//
145+
// We define a new function `conditional_select` that does this operation and use it to handle the skew
146+
// addition.
147+
const uint256_t scalar_s1("0x66ad81e84534c20431c795de922fb592c3d8c68edcacfc6c5b52ab7ad10e47d3");
148+
const uint256_t scalar_u1("0x37e0ba2e9c4dd42077fd751a7426a8484a8ff2928a6c85a651e4470b461c6215");
149+
const uint256_t scalar_u2("0xdefbb9bbabde5b9f8d7175946e75babc2f11203a8bfb71beaeec1d7a2bff17dd");
150+
151+
// Check the assumptions
152+
ASSERT(scalar_s1 < fr::modulus);
153+
ASSERT(scalar_u1 < fr::modulus);
154+
ASSERT(scalar_u2 < fr::modulus);
155+
ASSERT((fr(scalar_s1) * fr(scalar_u2) + fr(scalar_u1)).is_zero());
156+
ASSERT((g1::one * fr(scalar_u1) + (g1::one * fr(scalar_s1)) * fr(scalar_u2)).is_point_at_infinity());
157+
158+
Builder builder = Builder();
159+
element_ct P_a = element_ct::from_witness(&builder, g1::one * fr(scalar_s1));
160+
scalar_ct u1 = scalar_ct::from_witness(&builder, fr(scalar_u1));
161+
scalar_ct u2 = scalar_ct::from_witness(&builder, fr(scalar_u2));
162+
auto output = element_ct::secp256k1_ecdsa_mul(P_a, u1, u2);
163+
164+
// Check that the output is the point at infinity
165+
EXPECT_EQ(output.is_point_at_infinity().get_value(), true);
166+
167+
EXPECT_CIRCUIT_CORRECTNESS(builder);
168+
}
114169
};
115170

116171
// Then define the test types
@@ -133,3 +188,7 @@ TYPED_TEST(stdlibBiggroupSecp256k1, EcdsaMulSecp256k1)
133188
{
134189
TestFixture::test_ecdsa_mul_secp256k1();
135190
}
191+
TYPED_TEST(stdlibBiggroupSecp256k1, EcdsaMulSecp256k1SkewHandlingRegression)
192+
{
193+
TestFixture::test_secp256k1_ecdsa_mul_skew_handling_regression();
194+
}

0 commit comments

Comments
 (0)