Skip to content

Commit deab9d8

Browse files
committed
Temporarily disabled ParallelSolver
1 parent 2eda1ba commit deab9d8

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

kernel_tuner/searchspace.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,18 @@
1212
MaxProdConstraint,
1313
MinConflictsSolver,
1414
OptimizedBacktrackingSolver,
15-
ParallelSolver,
15+
# ParallelSolver,
1616
Problem,
1717
RecursiveBacktrackingSolver,
1818
Solver,
1919
)
2020

2121
from kernel_tuner.util import check_restrictions as check_instance_restrictions
22-
from kernel_tuner.util import compile_restrictions, default_block_size_names, get_interval
22+
from kernel_tuner.util import (
23+
compile_restrictions,
24+
default_block_size_names,
25+
get_interval,
26+
)
2327

2428
supported_neighbor_methods = ["strictly-adjacent", "adjacent", "Hamming"]
2529

@@ -69,7 +73,9 @@ def __init__(
6973
if (
7074
len(restrictions) > 0
7175
and any(isinstance(restriction, str) for restriction in restrictions)
72-
and not (framework_l == "pysmt" or framework_l == "bruteforce" or solver_method.lower() == "pc_parallelsolver")
76+
and not (
77+
framework_l == "pysmt" or framework_l == "bruteforce" or solver_method.lower() == "pc_parallelsolver"
78+
)
7379
):
7480
self.restrictions = compile_restrictions(
7581
restrictions,
@@ -101,7 +107,8 @@ def __init__(
101107
elif solver_method.lower() == "pc_optimizedbacktrackingsolver":
102108
solver = OptimizedBacktrackingSolver(forwardcheck=False)
103109
elif solver_method.lower() == "pc_parallelsolver":
104-
solver = ParallelSolver()
110+
raise NotImplementedError("ParallelSolver is not yet implemented")
111+
# solver = ParallelSolver()
105112
elif solver_method.lower() == "pc_recursivebacktrackingsolver":
106113
solver = RecursiveBacktrackingSolver()
107114
elif solver_method.lower() == "pc_minconflictssolver":
@@ -266,7 +273,7 @@ def all_smt(formula, keys) -> list:
266273

267274
def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, solver: Solver):
268275
"""Builds the searchspace using pyATF."""
269-
from pyatf import TP, Set, Interval, Tuner
276+
from pyatf import TP, Interval, Set, Tuner
270277
from pyatf.cost_functions.generic import CostFunction
271278
from pyatf.search_techniques import Exhaustive
272279

@@ -282,7 +289,9 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
282289
# adding the default blocksize restriction requires recompilation because pyATF requires combined restrictions for the same parameter
283290
max_block_size_product = f"{' * '.join(valid_block_size_names)} <= {max_threads}"
284291
restrictions = self._modified_restrictions.copy() + [max_block_size_product]
285-
self.restrictions = compile_restrictions(restrictions, self.tune_params, format="pyatf", try_to_constraint=False)
292+
self.restrictions = compile_restrictions(
293+
restrictions, self.tune_params, format="pyatf", try_to_constraint=False
294+
)
286295

287296
# build a dictionary of the restrictions, combined based on last parameter
288297
res_dict = dict()
@@ -295,7 +304,9 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so
295304
continue
296305
if all(p in registered_params for p in params):
297306
if param in res_dict:
298-
raise KeyError(f"`{param}` is already in res_dict with `{res_dict[param][1]}`, can't add `{source}`")
307+
raise KeyError(
308+
f"`{param}` is already in res_dict with `{res_dict[param][1]}`, can't add `{source}`"
309+
)
299310
res_dict[param] = (res, source)
300311
print(source, res, param, params)
301312
registered_restrictions.append(index)
@@ -305,7 +316,9 @@ def get_params():
305316
params = list()
306317
for index, (key, values) in enumerate(self.tune_params.items()):
307318
vi = get_interval(values)
308-
vals = Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten())
319+
vals = (
320+
Interval(vi[0], vi[1], vi[2]) if vi is not None and vi[2] != 0 else Set(*np.array(values).flatten())
321+
)
309322
constraint = res_dict.get(key, None)
310323
constraint_source = None
311324
if constraint is not None:

0 commit comments

Comments
 (0)