Skip to content

Commit f3b9991

Browse files
committed
Added parallel solver support
1 parent 5b6a0be commit f3b9991

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

kernel_tuner/searchspace.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
MaxProdConstraint,
1313
MinConflictsSolver,
1414
OptimizedBacktrackingSolver,
15+
ParallelSolver,
1516
Problem,
1617
RecursiveBacktrackingSolver,
1718
Solver,
@@ -57,6 +58,7 @@ def __init__(
5758
self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
5859
self.params_values_indices = None
5960
self.build_neighbors_index = build_neighbors_index
61+
self.solver_method = solver_method
6062
self.__neighbor_cache = dict()
6163
self.neighbor_method = neighbor_method
6264
if (neighbor_method is not None or build_neighbors_index) and neighbor_method not in supported_neighbor_methods:
@@ -67,7 +69,7 @@ def __init__(
6769
if (
6870
len(restrictions) > 0
6971
and any(isinstance(restriction, str) for restriction in restrictions)
70-
and not (framework_l == "pysmt" or framework_l == "bruteforce")
72+
and not (framework_l == "pysmt" or framework_l == "bruteforce" or solver_method.lower() == "pc_parallelsolver")
7173
):
7274
self.restrictions = compile_restrictions(
7375
restrictions,
@@ -98,6 +100,8 @@ def __init__(
98100
solver = BacktrackingSolver()
99101
elif solver_method.lower() == "pc_optimizedbacktrackingsolver":
100102
solver = OptimizedBacktrackingSolver(forwardcheck=False)
103+
elif solver_method.lower() == "pc_parallelsolver":
104+
solver = ParallelSolver()
101105
elif solver_method.lower() == "pc_recursivebacktrackingsolver":
102106
solver = RecursiveBacktrackingSolver()
103107
elif solver_method.lower() == "pc_minconflictssolver":
@@ -407,6 +411,8 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
407411
elif isinstance(restriction, Constraint):
408412
all_params_required = all(param_name in required_params for param_name in self.param_names)
409413
parameter_space.addConstraint(restriction, None if all_params_required else required_params)
414+
elif isinstance(restriction, str) and self.solver_method.lower() == "pc_parallelsolver":
415+
parameter_space.addConstraint(restriction)
410416
else:
411417
raise ValueError(f"Unrecognized restriction {restriction}")
412418

0 commit comments

Comments
 (0)