Skip to content

Commit 7ed3934

Browse files
committed
Allow inversion of the learnt function
1 parent 61e6e5c commit 7ed3934

File tree

4 files changed

+6
-1
lines changed

4 files changed

+6
-1
lines changed

src/arjun.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,7 @@ class Arjun
13861386
int force_bw_equal = 1;
13871387
int bva_xor_vars = 0;
13881388
int silent_var_update = 1;
1389+
int inv_learnt = 0;
13891390
uint32_t max_repairs = std::numeric_limits<uint32_t>::max();
13901391
};
13911392

src/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ void add_arjun_options() {
166166
myopt("--mbve", mconf.manthan_bve, fc_int,"Use BVE with constants instead of training");
167167
myopt("--monflyorder", mconf.manthan_on_the_fly_order, fc_int,"Use on-the-fly training order and post-training topological order");
168168
myopt("--moneperloop", mconf.one_repair_per_loop, fc_int,"One repair per CEX loop");
169+
myopt("--minvertlearn", mconf.inv_learnt, fc_int,"Invert learnt functions");
169170

170171
// repairing on vars
171172
myopt("--bwequal", mconf.force_bw_equal, fc_int,"Force BW vars' indicators to be TRUE -- prevents repairing with them, but faster to repair");

src/manthan.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2066,7 +2066,7 @@ double Manthan::train(const vector<sample>& orig_samples, const uint32_t v) {
20662066
<< setw(6) << (samples.size() - num_ones) << " zeros");
20672067
double train_error;
20682068
if (samples.empty()) {
2069-
var_to_formula[v] = fh->constant_formula(true);
2069+
var_to_formula[v] = fh->constant_formula(!mconf.inv_learnt);
20702070
train_error = 0.0;
20712071
} else {
20722072
// Create the RandomForest object and train it on the training data.
@@ -2115,6 +2115,8 @@ double Manthan::train(const vector<sample>& orig_samples, const uint32_t v) {
21152115
assert(var_to_formula.count(v) == 0);
21162116
uint32_t max_depth = 0;
21172117
var_to_formula[v] = recur(&r, v, used_vars, 0, max_depth);
2118+
if (mconf.inv_learnt)
2119+
var_to_formula[v] = fh->neg(var_to_formula[v]);
21182120
verb_print(1, "Training error: " << setprecision(2) << setw(6) << train_error << "%."
21192121
<< " depth: " << setw(6) << max_depth
21202122
<< " ones: " << setprecision(0) << fixed << setw(5) << (double)num_ones/samples.size()*100.0 << "%"

src/synth.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ const std::map<string, ParamDef> param_table = {
7676
{"force_bw_equal", {PT::Int, [](MC& c, const string& v) { c.force_bw_equal = parse_val<int>(v); }}},
7777
{"bva_xor_vars", {PT::Int, [](MC& c, const string& v) { c.bva_xor_vars = parse_val<int>(v); }}},
7878
{"silent_var_update", {PT::Int, [](MC& c, const string& v) { c.silent_var_update = parse_val<int>(v); }}},
79+
{"inv_learnt", {PT::Int, [](MC& c, const string& v) { c.inv_learnt = parse_val<int>(v); }}},
7980
};
8081
} // namespace
8182

0 commit comments

Comments
 (0)