Skip to content

Commit 0b5ddae

Browse files
committed
feat: add predicate to msm and ecadd
1 parent 9f79db5 commit 0b5ddae

File tree

5 files changed

+48
-5
lines changed

5 files changed

+48
-5
lines changed

barretenberg/cpp/src/barretenberg/dsl/acir_format/ec_operations.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,25 @@ struct EcAdd {
1818
WitnessOrConstant<bb::fr> input2_x;
1919
WitnessOrConstant<bb::fr> input2_y;
2020
WitnessOrConstant<bb::fr> input2_infinite;
21+
// Predicate indicating whether the constraint should be disabled:
22+
// - true: the constraint is valid
23+
// - false: the constraint is disabled, i.e it must not fail and can return whatever.
24+
WitnessOrConstant<bb::fr> predicate;
2125
uint32_t result_x;
2226
uint32_t result_y;
2327
uint32_t result_infinite;
2428

2529
// for serialization, update with any new fields
26-
MSGPACK_FIELDS(
27-
input1_x, input1_y, input1_infinite, input2_x, input2_y, input2_infinite, result_x, result_y, result_infinite);
30+
MSGPACK_FIELDS(input1_x,
31+
input1_y,
32+
input1_infinite,
33+
input2_x,
34+
input2_y,
35+
input2_infinite,
36+
predicate,
37+
result_x,
38+
result_y,
39+
result_infinite);
2840
friend bool operator==(EcAdd const& lhs, EcAdd const& rhs) = default;
2941
};
3042

barretenberg/cpp/src/barretenberg/dsl/acir_format/ecdsa_constraints.test.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "acir_format_mocks.hpp"
44
#include "barretenberg/crypto/ecdsa/ecdsa.hpp"
55
#include "barretenberg/dsl/acir_format/utils.hpp"
6+
#include "barretenberg/dsl/acir_format/witness_constant.hpp"
67
#include "barretenberg/stdlib/primitives/curves/secp256k1.hpp"
78
#include "barretenberg/stdlib/primitives/curves/secp256r1.hpp"
89

@@ -86,6 +87,7 @@ template <class Curve> class EcdsaConstraintsTest : public ::testing::Test {
8687
.signature = signature_indices,
8788
.pub_x_indices = pub_x_indices,
8889
.pub_y_indices = pub_y_indices,
90+
.predicate = WitnessOrConstant<bb::fr>::from_constant(bb::fr::one()),
8991
.result = result_index };
9092

9193
return num_variables;

barretenberg/cpp/src/barretenberg/dsl/acir_format/multi_scalar_mul.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ namespace acir_format {
1717
struct MultiScalarMul {
1818
std::vector<WitnessOrConstant<bb::fr>> points;
1919
std::vector<WitnessOrConstant<bb::fr>> scalars;
20+
// Predicate indicating whether the constraint should be disabled:
21+
// - true: the constraint is valid
22+
// - false: the constraint is disabled, i.e it must not fail and can return whatever.
23+
WitnessOrConstant<bb::fr> predicate;
2024

2125
uint32_t out_point_x;
2226
uint32_t out_point_y;
2327
uint32_t out_point_is_infinite;
2428

2529
// for serialization, update with any new fields
26-
MSGPACK_FIELDS(points, scalars, out_point_x, out_point_y, out_point_is_infinite);
30+
MSGPACK_FIELDS(points, scalars, predicate, out_point_x, out_point_y, out_point_is_infinite);
2731
friend bool operator==(MultiScalarMul const& lhs, MultiScalarMul const& rhs) = default;
2832
};
2933

barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2984,6 +2984,7 @@ struct BlackBoxFuncCall {
29842984
struct MultiScalarMul {
29852985
std::vector<Acir::FunctionInput> points;
29862986
std::vector<Acir::FunctionInput> scalars;
2987+
Acir::FunctionInput predicate;
29872988
std::shared_ptr<std::array<Acir::Witness, 3>> outputs;
29882989

29892990
friend bool operator==(const MultiScalarMul&, const MultiScalarMul&);
@@ -2992,9 +2993,10 @@ struct BlackBoxFuncCall {
29922993

29932994
void msgpack_pack(auto& packer) const
29942995
{
2995-
packer.pack_map(3);
2996+
packer.pack_map(4);
29962997
packer.pack(std::make_pair("points", points));
29972998
packer.pack(std::make_pair("scalars", scalars));
2999+
packer.pack(std::make_pair("predicate", predicate));
29983000
packer.pack(std::make_pair("outputs", outputs));
29993001
}
30003002

@@ -3004,13 +3006,15 @@ struct BlackBoxFuncCall {
30043006
auto kvmap = Helpers::make_kvmap(o, name);
30053007
Helpers::conv_fld_from_kvmap(kvmap, name, "points", points, false);
30063008
Helpers::conv_fld_from_kvmap(kvmap, name, "scalars", scalars, false);
3009+
Helpers::conv_fld_from_kvmap(kvmap, name, "predicate", predicate, false);
30073010
Helpers::conv_fld_from_kvmap(kvmap, name, "outputs", outputs, false);
30083011
}
30093012
};
30103013

30113014
struct EmbeddedCurveAdd {
30123015
std::shared_ptr<std::array<Acir::FunctionInput, 3>> input1;
30133016
std::shared_ptr<std::array<Acir::FunctionInput, 3>> input2;
3017+
Acir::FunctionInput predicate;
30143018
std::shared_ptr<std::array<Acir::Witness, 3>> outputs;
30153019

30163020
friend bool operator==(const EmbeddedCurveAdd&, const EmbeddedCurveAdd&);
@@ -3019,9 +3023,10 @@ struct BlackBoxFuncCall {
30193023

30203024
void msgpack_pack(auto& packer) const
30213025
{
3022-
packer.pack_map(3);
3026+
packer.pack_map(4);
30233027
packer.pack(std::make_pair("input1", input1));
30243028
packer.pack(std::make_pair("input2", input2));
3029+
packer.pack(std::make_pair("predicate", predicate));
30253030
packer.pack(std::make_pair("outputs", outputs));
30263031
}
30273032

@@ -3031,6 +3036,7 @@ struct BlackBoxFuncCall {
30313036
auto kvmap = Helpers::make_kvmap(o, name);
30323037
Helpers::conv_fld_from_kvmap(kvmap, name, "input1", input1, false);
30333038
Helpers::conv_fld_from_kvmap(kvmap, name, "input2", input2, false);
3039+
Helpers::conv_fld_from_kvmap(kvmap, name, "predicate", predicate, false);
30343040
Helpers::conv_fld_from_kvmap(kvmap, name, "outputs", outputs, false);
30353041
}
30363042
};
@@ -6331,6 +6337,9 @@ inline bool operator==(const BlackBoxFuncCall::MultiScalarMul& lhs, const BlackB
63316337
if (!(lhs.scalars == rhs.scalars)) {
63326338
return false;
63336339
}
6340+
if (!(lhs.predicate == rhs.predicate)) {
6341+
return false;
6342+
}
63346343
if (!(lhs.outputs == rhs.outputs)) {
63356344
return false;
63366345
}
@@ -6363,6 +6372,7 @@ void serde::Serializable<Acir::BlackBoxFuncCall::MultiScalarMul>::serialize(
63636372
{
63646373
serde::Serializable<decltype(obj.points)>::serialize(obj.points, serializer);
63656374
serde::Serializable<decltype(obj.scalars)>::serialize(obj.scalars, serializer);
6375+
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
63666376
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
63676377
}
63686378

@@ -6374,6 +6384,7 @@ Acir::BlackBoxFuncCall::MultiScalarMul serde::Deserializable<Acir::BlackBoxFuncC
63746384
Acir::BlackBoxFuncCall::MultiScalarMul obj;
63756385
obj.points = serde::Deserializable<decltype(obj.points)>::deserialize(deserializer);
63766386
obj.scalars = serde::Deserializable<decltype(obj.scalars)>::deserialize(deserializer);
6387+
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
63776388
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
63786389
return obj;
63796390
}
@@ -6388,6 +6399,9 @@ inline bool operator==(const BlackBoxFuncCall::EmbeddedCurveAdd& lhs, const Blac
63886399
if (!(lhs.input2 == rhs.input2)) {
63896400
return false;
63906401
}
6402+
if (!(lhs.predicate == rhs.predicate)) {
6403+
return false;
6404+
}
63916405
if (!(lhs.outputs == rhs.outputs)) {
63926406
return false;
63936407
}
@@ -6421,6 +6435,7 @@ void serde::Serializable<Acir::BlackBoxFuncCall::EmbeddedCurveAdd>::serialize(
64216435
{
64226436
serde::Serializable<decltype(obj.input1)>::serialize(obj.input1, serializer);
64236437
serde::Serializable<decltype(obj.input2)>::serialize(obj.input2, serializer);
6438+
serde::Serializable<decltype(obj.predicate)>::serialize(obj.predicate, serializer);
64246439
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
64256440
}
64266441

@@ -6432,6 +6447,7 @@ Acir::BlackBoxFuncCall::EmbeddedCurveAdd serde::Deserializable<Acir::BlackBoxFun
64326447
Acir::BlackBoxFuncCall::EmbeddedCurveAdd obj;
64336448
obj.input1 = serde::Deserializable<decltype(obj.input1)>::deserialize(deserializer);
64346449
obj.input2 = serde::Deserializable<decltype(obj.input2)>::deserialize(deserializer);
6450+
obj.predicate = serde::Deserializable<decltype(obj.predicate)>::deserialize(deserializer);
64356451
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
64366452
return obj;
64376453
}

barretenberg/cpp/src/barretenberg/dsl/acir_format/witness_constant.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ template <typename FF> struct WitnessOrConstant {
2525
.is_constant = false,
2626
};
2727
}
28+
29+
static WitnessOrConstant from_constant(FF value)
30+
{
31+
return WitnessOrConstant{
32+
.index = 0,
33+
.value = value,
34+
.is_constant = true,
35+
};
36+
}
2837
};
2938

3039
template <typename Builder, typename FF>

0 commit comments

Comments
 (0)