Skip to content

Commit f883324

Browse files
authored
Merge pull request #245 from alwilson/dist_retry_fallback_243
swizzler: add dist constraint retry and fallback
2 parents add71d8 + 7f16613 commit f883324

File tree

1 file changed

+61
-38
lines changed

1 file changed

+61
-38
lines changed

src/vsc/model/solvegroup_swizzler_partsel.py

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor):
7171
if len(field_l) > 0:
7272
field_idx = self.randstate.randint(0, len(field_l)-1)
7373
f = field_l.pop(field_idx)
74-
e_l = self.swizzle_field(f, rs, bound_m)
74+
e_l = self.swizzle_field(f, rs, bound_m, btor)
7575
if e_l is not None:
7676
for e in e_l:
7777
swizzle_node_l.append(e.build(btor))
@@ -104,52 +104,75 @@ def swizzle_field_l(self, field_l, rs : RandSet, bound_m, btor):
104104
def swizzle_field(self,
105105
f : FieldScalarModel,
106106
rs : RandSet,
107-
bound_m : VariableBoundModel)->ExprModel:
107+
bound_m : VariableBoundModel,
108+
btor
109+
) -> list[ExprModel]:
108110
ret = None
109-
111+
110112
if self.debug > 0:
111113
print("Swizzling field %s" % f.name)
112-
114+
113115
if f in rs.dist_field_m.keys():
116+
max_dist_samples = 4
117+
for _ in range(max_dist_samples):
118+
e = self.sample_dist_weights(f, rs)
119+
n = e.build(btor)
120+
btor.Assume(n)
121+
if self.solve_info is not None:
122+
self.solve_info.n_sat_calls += 1
123+
if btor.Sat() == btor.SAT:
124+
if self.debug > 0: print(" Dist constraint SAT")
125+
btor.Assert(n)
126+
return ret
127+
else:
128+
if self.debug > 0: print(" Dist constraint UNSAT")
129+
if self.debug > 0: print(" max_dist_samples exceeded, falling back to rand domain")
130+
131+
if f in bound_m.keys():
132+
f_bound = bound_m[f]
133+
if not f_bound.isEmpty():
134+
ret = self.create_rand_domain_constraint(f, f_bound)
135+
136+
return ret
137+
138+
def sample_dist_weights(self,
139+
f : FieldScalarModel,
140+
rs : RandSet,
141+
) -> ExprModel:
142+
if self.debug > 0:
143+
print("Note: field %s is in dist map" % f.name)
144+
for d in rs.dist_field_m[f]:
145+
print(" Weight list %s" % d.weight_list)
146+
147+
if len(rs.dist_field_m[f]) > 1:
148+
target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1)
149+
dist_scope_c = rs.dist_field_m[f][target_d]
150+
else:
151+
dist_scope_c = rs.dist_field_m[f][0]
152+
153+
target_range = dist_scope_c.next_target_range(self.randstate)
154+
target_w = dist_scope_c.dist_c.weights[target_range]
155+
if target_w.rng_rhs is not None:
156+
# Dual-bound range
157+
val_l = target_w.rng_lhs.val()
158+
val_r = target_w.rng_rhs.val()
159+
val = self.randstate.randint(val_l, val_r)
114160
if self.debug > 0:
115-
print("Note: field %s is in dist map" % f.name)
116-
for d in rs.dist_field_m[f]:
117-
print(" Weight list %s" % d.weight_list)
118-
if len(rs.dist_field_m[f]) > 1:
119-
target_d = self.randstate.randint(0, len(rs.dist_field_m[f])-1)
120-
dist_scope_c = rs.dist_field_m[f][target_d]
121-
else:
122-
dist_scope_c = rs.dist_field_m[f][0]
123-
124-
target_range = dist_scope_c.next_target_range(self.randstate)
125-
target_w = dist_scope_c.dist_c.weights[target_range]
126-
if target_w.rng_rhs is not None:
127-
# Dual-bound range
128-
val_l = target_w.rng_lhs.val()
129-
val_r = target_w.rng_rhs.val()
130-
val = self.randstate.randint(val_l, val_r)
131-
if self.debug > 0:
132-
print("Select dist-weight range: %d..%d ; specific value %d" % (
133-
int(val_l), int(val_r), int(val)))
134-
ret = [ExprBinModel(
161+
print("Select dist-weight range: %d..%d ; specific value %d" % (
162+
int(val_l), int(val_r), int(val)))
163+
ret = ExprBinModel(
135164
ExprFieldRefModel(f),
136165
BinExprType.Eq,
137-
ExprLiteralModel(val, f.is_signed, f.width))]
138-
else:
139-
# Single value
140-
val = target_w.rng_lhs.val()
141-
if self.debug > 0:
142-
print("Select dist-weight value %d" % (int(val)))
143-
ret = [ExprBinModel(
166+
ExprLiteralModel(val, f.is_signed, f.width))
167+
else:
168+
# Single value
169+
val = target_w.rng_lhs.val()
170+
if self.debug > 0:
171+
print("Select dist-weight value %d" % (int(val)))
172+
ret = ExprBinModel(
144173
ExprFieldRefModel(f),
145174
BinExprType.Eq,
146-
ExprLiteralModel(int(val), f.is_signed, f.width))]
147-
else:
148-
if f in bound_m.keys():
149-
f_bound = bound_m[f]
150-
if not f_bound.isEmpty():
151-
ret = self.create_rand_domain_constraint(f, f_bound)
152-
175+
ExprLiteralModel(int(val), f.is_signed, f.width))
153176
return ret
154177

155178
def create_rand_domain_constraint(self,

0 commit comments

Comments
 (0)