Skip to content

Commit a457999

Browse files
Authored by Ryan Kim
Merge pull request #471 from kroma-network/perf/optimize-zkey-parsing
perf(circom): optimize zkey parsing
2 parents f9a40b0 + 0526c10 commit a457999

File tree

9 files changed: +106 -155 lines changed

vendors/circom/benchmark/circom_benchmark.cc

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -91,20 +91,15 @@ int RealMain(int argc, char** argv) {
9191
witness_loader.Set("in", Uint8ToBitVector<F>(in));
9292
witness_loader.Load();
9393

94-
const zk::r1cs::ConstraintMatrices<F>& constraint_matrices =
95-
tachyon_runner->constraint_matrices();
96-
97-
domain = Domain::Create(constraint_matrices.num_constraints +
98-
constraint_matrices.num_instance_variables);
94+
domain = Domain::Create(tachyon_runner->GetDomainSize());
9995

96+
size_t num_instance_variables = tachyon_runner->GetNumInstanceVariables();
10097
full_assignments = base::CreateVector(
101-
constraint_matrices.num_instance_variables +
102-
constraint_matrices.num_witness_variables,
98+
num_instance_variables + tachyon_runner->GetNumWitnessVariables(),
10399
[&witness_loader](size_t i) { return witness_loader.Get(i); });
104100

105-
public_inputs =
106-
absl::MakeConstSpan(full_assignments)
107-
.subspan(1, constraint_matrices.num_instance_variables - 1);
101+
public_inputs = absl::MakeConstSpan(full_assignments)
102+
.subspan(1, num_instance_variables - 1);
108103
CheckPublicInput(in, public_inputs);
109104
}
110105

vendors/circom/benchmark/tachyon_runner.h

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,16 +33,22 @@ class TachyonRunner : public Runner<Curve, MaxDegree> {
3333
return proving_key_;
3434
}
3535

36-
const zk::r1cs::ConstraintMatrices<F>& constraint_matrices() const {
37-
return constraint_matrices_;
36+
size_t GetDomainSize() const { return zkey_->GetDomainSize(); }
37+
38+
size_t GetNumInstanceVariables() const {
39+
return zkey_->GetNumInstanceVariables();
40+
}
41+
42+
size_t GetNumWitnessVariables() const {
43+
return zkey_->GetNumWitnessVariables();
3844
}
3945

4046
void LoadZkey(const base::FilePath& zkey_path) override {
4147
zkey_ = ParseZKey<Curve>(zkey_path);
4248
CHECK(zkey_);
4349

4450
proving_key_ = zkey_->GetProvingKey().ToNativeProvingKey();
45-
constraint_matrices_ = zkey_->GetConstraintMatrices();
51+
coefficients_ = zkey_->GetCoefficients();
4652
}
4753

4854
zk::r1cs::groth16::Proof<Curve> Run(const Domain* domain,
@@ -53,15 +59,16 @@ class TachyonRunner : public Runner<Curve, MaxDegree> {
5359

5460
std::vector<F> h_evals =
5561
QuadraticArithmeticProgram<F>::WitnessMapFromMatrices(
56-
domain, constraint_matrices_, full_assignments);
62+
domain, coefficients_, full_assignments);
5763

64+
size_t num_instance_variables = GetNumInstanceVariables();
5865
zk::r1cs::groth16::Proof<Curve> proof =
5966
zk::r1cs::groth16::CreateProofWithAssignmentNoZK(
6067
proving_key_, absl::MakeConstSpan(h_evals),
6168
absl::MakeConstSpan(full_assignments)
62-
.subspan(1, constraint_matrices_.num_instance_variables - 1),
69+
.subspan(1, num_instance_variables - 1),
6370
absl::MakeConstSpan(full_assignments)
64-
.subspan(constraint_matrices_.num_instance_variables),
71+
.subspan(num_instance_variables),
6572
absl::MakeConstSpan(full_assignments).subspan(1));
6673

6774
delta = base::TimeTicks::Now() - now;
@@ -80,7 +87,7 @@ class TachyonRunner : public Runner<Curve, MaxDegree> {
8087
WitnessLoader<F> witness_loader_;
8188
std::unique_ptr<ZKey<Curve>> zkey_;
8289
zk::r1cs::groth16::ProvingKey<Curve> proving_key_;
83-
zk::r1cs::ConstraintMatrices<F> constraint_matrices_;
90+
absl::Span<const Coefficient<F>> coefficients_;
8491
std::optional<zk::r1cs::groth16::PreparedVerifyingKey<Curve>>
8592
prepared_verifying_key_;
8693
};

vendors/circom/circomlib/circuit/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ tachyon_cc_library(
1616
name = "quadratic_arithmetic_program",
1717
hdrs = ["quadratic_arithmetic_program.h"],
1818
deps = [
19+
"//circomlib/zkey:coefficient",
1920
"@kroma_network_tachyon//tachyon/base:logging",
2021
"@kroma_network_tachyon//tachyon/zk/r1cs/constraint_system:quadratic_arithmetic_program",
2122
],

vendors/circom/circomlib/circuit/circuit_test.h

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -54,22 +54,18 @@ class CircuitTest : public testing::Test {
5454

5555
zk::r1cs::groth16::ProvingKey<Curve> pk =
5656
zkey.GetProvingKey().ToNativeProvingKey();
57-
zk::r1cs::ConstraintMatrices<F> constraint_matrices =
58-
zkey.GetConstraintMatrices();
57+
absl::Span<const Coefficient<F>> coefficients = zkey.GetCoefficients();
5958

60-
std::unique_ptr<Domain> domain =
61-
Domain::Create(constraint_matrices.num_constraints +
62-
constraint_matrices.num_instance_variables);
59+
std::unique_ptr<Domain> domain = Domain::Create(zkey.GetDomainSize());
6360
std::vector<F> h_evals = QAP::WitnessMapFromMatrices(
64-
domain.get(), constraint_matrices, full_assignments);
61+
domain.get(), coefficients, full_assignments);
6562

63+
size_t num_instance_variables = zkey.GetNumInstanceVariables();
6664
zk::r1cs::groth16::Proof<Curve> proof =
6765
zk::r1cs::groth16::CreateProofWithAssignmentZK(
6866
pk, absl::MakeConstSpan(h_evals),
69-
full_assignments.subspan(
70-
1, constraint_matrices.num_instance_variables - 1),
71-
full_assignments.subspan(
72-
constraint_matrices.num_instance_variables),
67+
full_assignments.subspan(1, num_instance_variables - 1),
68+
full_assignments.subspan(num_instance_variables),
7369
full_assignments.subspan(1));
7470
zk::r1cs::groth16::PreparedVerifyingKey<Curve> pvk =
7571
std::move(pk).TakeVerifyingKey().ToPreparedVerifyingKey();

vendors/circom/circomlib/circuit/quadratic_arithmetic_program.h

Lines changed: 30 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,10 @@
66
#ifndef VENDORS_CIRCOM_CIRCOMLIB_CIRCUIT_QUADRATIC_ARITHMETIC_PROGRAM_H_
77
#define VENDORS_CIRCOM_CIRCOMLIB_CIRCUIT_QUADRATIC_ARITHMETIC_PROGRAM_H_
88

9-
#include <memory>
109
#include <utility>
1110
#include <vector>
1211

12+
#include "circomlib/zkey/coefficient.h"
1313
#include "tachyon/base/logging.h"
1414
#include "tachyon/zk/r1cs/constraint_system/quadratic_arithmetic_program.h"
1515

@@ -20,55 +20,46 @@ class QuadraticArithmeticProgram {
2020
public:
2121
QuadraticArithmeticProgram() = delete;
2222

23-
template <typename Domain>
24-
static zk::r1cs::QAPInstanceMapResult<F> InstanceMap(
25-
const Domain* domain, const zk::r1cs::ConstraintSystem<F>& cs,
26-
const F& x) {
27-
return zk::r1cs::QuadraticArithmeticProgram<F>::InstanceMap(domain, cs, x);
28-
}
29-
3023
template <typename Domain>
3124
static std::vector<F> WitnessMapFromMatrices(
32-
const Domain* domain, const zk::r1cs::ConstraintMatrices<F>& matrices,
25+
const Domain* domain, absl::Span<const Coefficient<F>> coefficients,
3326
absl::Span<const F> full_assignments) {
3427
using Evals = typename Domain::Evals;
3528
using DensePoly = typename Domain::DensePoly;
3629

37-
CHECK_GE(domain->size(), matrices.num_constraints);
38-
3930
std::vector<F> a(domain->size());
4031
std::vector<F> b(domain->size());
4132
std::vector<F> c(domain->size());
4233

43-
// clang-format off
44-
// |a[i]| = Σⱼ₌₀..ₘ (xⱼ * Aᵢ,ⱼ) (if i < |num_constraints|)
45-
// = x[i - num_constraints] (otherwise)
46-
// |b[i]| = Σⱼ₌₀..ₘ (xⱼ * Bᵢ,ⱼ) (if i < |num_constraints|)
47-
// = 0 (otherwise)
48-
// |c[i]| = |a[i]|* |b[i]| (if i < |num_constraints|)
49-
// = 0 (otherwise)
50-
// where x is |full_assignments|.
51-
// clang-format on
52-
OMP_PARALLEL {
53-
OMP_FOR_NOWAIT
54-
for (size_t i = 0; i < matrices.num_constraints; ++i) {
55-
a[i] = zk::r1cs::EvaluateConstraint(matrices.a[i], full_assignments);
56-
}
57-
58-
OMP_FOR
59-
for (size_t i = 0; i < matrices.num_constraints; ++i) {
60-
b[i] = zk::r1cs::EvaluateConstraint(matrices.b[i], full_assignments);
61-
}
62-
63-
OMP_FOR
64-
for (size_t i = 0; i < matrices.num_constraints; ++i) {
65-
c[i] = a[i] * b[i];
34+
// See
35+
// https://github.com/iden3/rapidsnark/blob/b17e6fed08e9ceec3518edeffe4384313f91e9ad/src/groth16.cpp#L116-L156.
36+
#if defined(TACHYON_HAS_OPENMP)
37+
constexpr size_t kNumLocks = 1024;
38+
omp_lock_t locks[kNumLocks];
39+
for (size_t i = 0; i < kNumLocks; i++) omp_init_lock(&locks[i]);
40+
#endif
41+
OPENMP_PARALLEL_FOR(size_t i = 0; i < coefficients.size(); i++) {
42+
const Coefficient<F>& c = coefficients[i];
43+
std::vector<F>& ab = (c.matrix == 0) ? a : b;
44+
45+
#if defined(TACHYON_HAS_OPENMP)
46+
omp_set_lock(&locks[c.constraint % kNumLocks]);
47+
#endif
48+
if (c.value.IsOne()) {
49+
ab[c.constraint] += full_assignments[c.signal];
50+
} else {
51+
ab[c.constraint] += c.value * full_assignments[c.signal];
6652
}
53+
#if defined(TACHYON_HAS_OPENMP)
54+
omp_unset_lock(&locks[c.constraint % kNumLocks]);
55+
#endif
6756
}
57+
#if defined(TACHYON_HAS_OPENMP)
58+
for (size_t i = 0; i < kNumLocks; i++) omp_destroy_lock(&locks[i]);
59+
#endif
6860

69-
for (size_t i = matrices.num_constraints;
70-
i < matrices.num_constraints + matrices.num_instance_variables; ++i) {
71-
a[i] = full_assignments[i - matrices.num_constraints];
61+
OPENMP_PARALLEL_FOR(size_t i = 0; i < domain->size(); ++i) {
62+
c[i] = a[i] * b[i];
7263
}
7364

7465
Evals a_evals(std::move(a));
@@ -79,11 +70,8 @@ class QuadraticArithmeticProgram {
7970
DensePoly c_poly = domain->IFFT(std::move(c_evals));
8071

8172
F root_of_unity;
82-
{
83-
std::unique_ptr<Domain> extended_domain =
84-
Domain::Create(2 * domain->size());
85-
root_of_unity = extended_domain->GetElement(1);
86-
}
73+
CHECK(F::GetRootOfUnity(2 * domain->size(), &root_of_unity));
74+
8775
Domain::DistributePowers(a_poly, root_of_unity);
8876
Domain::DistributePowers(b_poly, root_of_unity);
8977
Domain::DistributePowers(c_poly, root_of_unity);
@@ -101,26 +89,6 @@ class QuadraticArithmeticProgram {
10189

10290
return std::move(a_evals).TakeEvaluations();
10391
}
104-
105-
template <typename Domain>
106-
static std::vector<F> ComputeHQuery(const Domain* domain, const F& t_x,
107-
const F& x, const F& delta_inverse) {
108-
using Evals = typename Domain::Evals;
109-
using DensePoly = typename Domain::DensePoly;
110-
111-
// The usual H query has domain - 1 powers. Z has domain powers. So HZ has
112-
// 2 * domain - 1 powers.
113-
std::unique_ptr<Domain> extended_domain =
114-
Domain::Create(domain->size() * 2 + 1);
115-
Evals evals(
116-
F::GetSuccessivePowers(extended_domain->size(), t_x, delta_inverse));
117-
DensePoly poly = extended_domain->IFFT(std::move(evals));
118-
std::vector<F> ret(domain->size());
119-
OPENMP_PARALLEL_FOR(size_t i = 0; i < domain->size(); ++i) {
120-
ret[i] = poly[2 * i + 1];
121-
}
122-
return ret;
123-
}
12492
};
12593

12694
} // namespace tachyon::circom

vendors/circom/circomlib/zkey/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ tachyon_cc_library(
4545
"@kroma_network_tachyon//tachyon/base/buffer:read_only_buffer",
4646
"@kroma_network_tachyon//tachyon/base/files:file_util",
4747
"@kroma_network_tachyon//tachyon/base/strings:string_util",
48-
"@kroma_network_tachyon//tachyon/zk/r1cs/constraint_system:constraint_matrices",
4948
],
5049
)
5150

vendors/circom/circomlib/zkey/zkey.h

Lines changed: 19 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@
2121
#include "tachyon/base/logging.h"
2222
#include "tachyon/base/openmp_util.h"
2323
#include "tachyon/base/strings/string_util.h"
24-
#include "tachyon/zk/r1cs/constraint_system/constraint_matrices.h"
2524

2625
namespace tachyon::circom {
2726
namespace v1 {
@@ -45,7 +44,10 @@ struct ZKey {
4544
virtual bool Read(const base::ReadOnlyBuffer& buffer) = 0;
4645

4746
virtual ProvingKey<Curve> GetProvingKey() const = 0;
48-
virtual zk::r1cs::ConstraintMatrices<F> GetConstraintMatrices() const = 0;
47+
virtual absl::Span<const Coefficient<F>> GetCoefficients() const = 0;
48+
virtual size_t GetDomainSize() const = 0;
49+
virtual size_t GetNumInstanceVariables() const = 0;
50+
virtual size_t GetNumWitnessVariables() const = 0;
4951

5052
std::vector<uint8_t> data;
5153
};
@@ -211,6 +213,12 @@ struct CoefficientsSection {
211213
Coefficient<F>* ptr;
212214
if (!buffer.ReadPtr(&ptr, num_coefficients)) return false;
213215
coefficients = {ptr, num_coefficients};
216+
217+
OPENMP_PARALLEL_FOR(size_t i = 0; i < coefficients.size(); ++i) {
218+
coefficients[i].value =
219+
F::FromMontgomery(coefficients[i].value.ToBigInt());
220+
}
221+
214222
return true;
215223
}
216224

@@ -294,50 +302,18 @@ struct ZKey : public circom::ZKey<Curve> {
294302
};
295303
}
296304

297-
zk::r1cs::ConstraintMatrices<F> GetConstraintMatrices() const override {
298-
std::vector<std::vector<zk::r1cs::Cell<F>>> a(header_groth.domain_size);
299-
std::vector<std::vector<zk::r1cs::Cell<F>>> b(header_groth.domain_size);
300-
301-
uint32_t max_constraint = 0;
302-
for (const Coefficient<F>& c : coefficients.coefficients) {
303-
max_constraint = std::max(c.constraint, max_constraint);
304-
if (c.matrix == 0) {
305-
a[c.constraint].push_back({std::move(c.value), c.signal});
306-
} else {
307-
b[c.constraint].push_back({std::move(c.value), c.signal});
308-
}
309-
}
310-
311-
// Need to divide by R, since snarkjs outputs the zkey with coefficients
312-
// multiplied by R².
313-
OPENMP_PARALLEL_FOR(size_t i = 0; i < max_constraint; ++i) {
314-
if (i < a.size()) {
315-
for (size_t j = 0; j < a[i].size(); ++j) {
316-
a[i][j].coefficient =
317-
F::FromMontgomery(a[i][j].coefficient.ToBigInt());
318-
}
319-
}
320-
if (i < b.size()) {
321-
for (size_t j = 0; j < b[i].size(); ++j) {
322-
b[i][j].coefficient =
323-
F::FromMontgomery(b[i][j].coefficient.ToBigInt());
324-
}
325-
}
326-
}
305+
absl::Span<const Coefficient<F>> GetCoefficients() const override {
306+
return coefficients.coefficients;
307+
}
327308

328-
return {
329-
header_groth.num_public_inputs + 1,
330-
header_groth.num_vars - header_groth.num_public_inputs - 1,
331-
max_constraint - header_groth.num_public_inputs,
309+
size_t GetDomainSize() const override { return header_groth.domain_size; }
332310

333-
0,
334-
0,
335-
0,
311+
size_t GetNumInstanceVariables() const override {
312+
return header_groth.num_public_inputs + 1;
313+
}
336314

337-
zk::r1cs::Matrix<F>(std::move(a)),
338-
zk::r1cs::Matrix<F>(std::move(b)),
339-
{},
340-
};
315+
size_t GetNumWitnessVariables() const override {
316+
return header_groth.num_vars - header_groth.num_public_inputs - 1;
341317
}
342318

343319
std::string ToString() const {

vendors/circom/circomlib/zkey/zkey_unittest.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -157,12 +157,12 @@ TEST_F(ZKeyTest, Parse) {
157157

158158
// clang-format off
159159
CoefficientData coefficient_datas[] = {
160-
{0, 0, 2, "15537367993719455909907449462855742678020201736855642022731641111541721333766"},
161-
{1, 0, 3, "6350874878119819312338956282401532410528162663560392320966563075034087161851"},
162-
{0, 1, 5, "15537367993719455909907449462855742678020201736855642022731641111541721333766"},
163-
{1, 1, 4, "6350874878119819312338956282401532410528162663560392320966563075034087161851"},
164-
{0, 2, 0, "6350874878119819312338956282401532410528162663560392320966563075034087161851"},
165-
{0, 3, 1, "6350874878119819312338956282401532410528162663560392320966563075034087161851"},
160+
{0, 0, 2, "21888242871839275222246405745257275088548364400416034343698204186575808495616"},
161+
{1, 0, 3, "1"},
162+
{0, 1, 5, "21888242871839275222246405745257275088548364400416034343698204186575808495616"},
163+
{1, 1, 4, "1"},
164+
{0, 2, 0, "1"},
165+
{0, 3, 1, "1"},
166166
};
167167
// clang-format on
168168

0 commit comments

Comments
 (0)