Skip to content

Commit a616cfb

Browse files
committed
balanced AIG construction to avoid segfault due to deep recursion
1 parent 31dd021 commit a616cfb

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

src/interpolant.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,19 @@ void MyTracer::add_derived_clause(uint64_t id, bool /*red*/, const std::vector<i
4848

4949
const uint64_t id1 = rantec[0];
5050
auto aig = fs_clid[id1];
51+
release_assert(aig != nullptr);
5152
set<Lit> resolvent(cls[id1].begin(),cls[id1].end());
53+
std::vector<aig_ptr> same_op_terms;
54+
bool same_op_is_and = false;
55+
bool same_op_started = false;
56+
57+
auto flush_terms = [&]() {
58+
if (!same_op_started) return;
59+
aig = combine_balanced(std::move(same_op_terms), same_op_is_and);
60+
same_op_terms.clear();
61+
same_op_started = false;
62+
};
63+
5264
for(uint32_t i = 1; i < rantec.size(); i++) {
5365
if (conf.verb >= 4) {
5466
cout << "resolvent: "; for(const auto& l: resolvent) cout << l << " "; cout << endl;
@@ -70,9 +82,24 @@ void MyTracer::add_derived_clause(uint64_t id, bool /*red*/, const std::vector<i
7082
}
7183
assert(res_lit != lit_Undef);
7284
bool input_or_copy = input.count(res_lit.var()) || res_lit.var() >= (uint32_t)orig_num_vars;
73-
if (input_or_copy) aig = AIG::new_and(aig, fs_clid[id2]);
74-
else aig = AIG::new_or(aig, fs_clid[id2]);
85+
auto rhs = fs_clid[id2];
86+
release_assert(rhs != nullptr);
87+
if (!same_op_started) {
88+
same_op_started = true;
89+
same_op_is_and = input_or_copy;
90+
same_op_terms.push_back(aig);
91+
same_op_terms.push_back(rhs);
92+
} else if (same_op_is_and == input_or_copy) {
93+
same_op_terms.push_back(rhs);
94+
} else {
95+
flush_terms();
96+
same_op_started = true;
97+
same_op_is_and = input_or_copy;
98+
same_op_terms.push_back(aig);
99+
same_op_terms.push_back(rhs);
100+
}
75101
}
102+
flush_terms();
76103
fs_clid[id] = aig;
77104
verb_print(5, "intermediate formula: " << fs_clid[id]);
78105
if (clause.empty()) {
@@ -179,8 +206,8 @@ void Interpolant::generate_interpolant(
179206
}
180207

181208
// CaDiCaL on the core only
182-
auto cdcl = std::make_unique<Solver>();
183209
MyTracer t(orig_num_vars, input_vars, conf, lit_to_aig, cnf.get_aig_mng());
210+
auto cdcl = std::make_unique<Solver>();
184211

185212
cdcl->connect_proof_tracer(&t, true);
186213
/* std::stringstream name; */
@@ -204,6 +231,7 @@ void Interpolant::generate_interpolant(
204231
}
205232
release_assert(pret == Status::UNSATISFIABLE);
206233
cdcl->disconnect_proof_tracer(&t);
234+
cdcl.reset();
207235

208236
defs[test_var] = t.out;
209237
verb_print(5, "definition of var: " << test_var+1 << " is: " << t.out);

src/interpolant.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ extern "C" {
3434
#include "formula.h"
3535
#include <vector>
3636
#include <map>
37+
#include <utility>
3738
#include <cstdint>
3839
#include <cadical.hpp>
3940
#include <tracer.hpp>
@@ -63,6 +64,21 @@ struct MyTracer : public CaDiCaL::Tracer {
6364
// AIG cache
6465
map<Lit, aig_ptr>& lit_to_aig;
6566

67+
static aig_ptr combine_balanced(std::vector<aig_ptr> terms, bool use_and) {
68+
release_assert(!terms.empty());
69+
while (terms.size() > 1) {
70+
std::vector<aig_ptr> next;
71+
next.reserve((terms.size()+1)/2);
72+
for(uint32_t i = 0; i < terms.size(); i += 2) {
73+
if (i+1 >= terms.size()) next.push_back(terms[i]);
74+
else if (use_and) next.push_back(AIG::new_and(terms[i], terms[i+1]));
75+
else next.push_back(AIG::new_or(terms[i], terms[i+1]));
76+
}
77+
terms = std::move(next);
78+
}
79+
return terms[0];
80+
}
81+
6682
aig_ptr get_aig(const Lit l) {
6783
if (lit_to_aig.count(l)) return lit_to_aig.at(l);
6884
aig_ptr aig = AIG::new_lit(l);
@@ -73,13 +89,12 @@ struct MyTracer : public CaDiCaL::Tracer {
7389
aig_ptr get_aig(const vector<Lit>& unsorted_cl) {
7490
vector<Lit> cl = unsorted_cl;
7591
std::sort(cl.begin(), cl.end());
76-
aig_ptr aig = nullptr;
77-
for(const auto& l: cl) {
78-
if (aig == nullptr) aig = get_aig(l);
79-
else aig = AIG::new_or(aig, get_aig(l));
80-
}
81-
if (aig == nullptr) aig = aig_mng.new_const(false);
82-
return aig;
92+
if (cl.empty()) return aig_mng.new_const(false);
93+
94+
std::vector<aig_ptr> leaves;
95+
leaves.reserve(cl.size());
96+
for(const auto& l: cl) leaves.push_back(get_aig(l));
97+
return combine_balanced(std::move(leaves), false);
8398
};
8499

85100
void add_derived_clause (uint64_t id, bool red, const std::vector<int> & clause,
@@ -124,4 +139,3 @@ class Interpolant {
124139
vector<uint32_t> var_to_indic; //maps an ORIG VAR to an INDICATOR VAR
125140
vector<aig_ptr> defs; //definition of variables in terms of AIG. ORIGINAL number space
126141
};
127-

0 commit comments

Comments
 (0)