Skip to content

Commit 7b5cd29

Browse files
committed
Improved passing of restrictions
1 parent 05caa5f commit 7b5cd29

File tree

3 files changed

+20
-18
lines changed

3 files changed

+20
-18
lines changed

kernel_tuner/interface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,6 @@ def tune_kernel(
646646

647647
# ensure strategy_options is an Options object
648648
tuning_options.strategy_options = Options(strategy_options or {})
649-
650649
# if no strategy selected
651650
else:
652651
strategy = brute_force

kernel_tuner/searchspace.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from random import choice, shuffle
66
from typing import List, Union
77
from warnings import warn
8+
from copy import deepcopy
89

910
import numpy as np
1011
from constraint import (
@@ -92,10 +93,11 @@ def __init__(
9293
self._tensorspace_param_config_structure = []
9394
self._map_tensor_to_param = {}
9495
self._map_param_to_tensor = {}
95-
self.restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
96-
self.original_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
96+
restrictions = list(restrictions) if not isinstance(restrictions, (list, tuple)) else restrictions
97+
self.restrictions = deepcopy(restrictions)
98+
self.original_restrictions = deepcopy(restrictions) # keep the original restrictions, so that the searchspace can be modified later
9799
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
98-
self._modified_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
100+
self._modified_restrictions = deepcopy(restrictions)
99101
self.param_names = list(self.tune_params.keys())
100102
self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
101103
self.params_values_indices = None
@@ -479,8 +481,9 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
479481

480482
def __add_restrictions(self, parameter_space: Problem) -> Problem:
481483
"""Add the user-specified restrictions as constraints on the parameter space."""
482-
if isinstance(self.restrictions, list):
483-
for restriction in self.restrictions:
484+
restrictions = deepcopy(self.restrictions)
485+
if isinstance(restrictions, list):
486+
for restriction in restrictions:
484487
required_params = self.param_names
485488

486489
# (un)wrap where necessary
@@ -510,14 +513,14 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
510513
raise ValueError(f"Unrecognized restriction type {type(restriction)} ({restriction})")
511514

512515
# if the restrictions are the old monolithic function, apply them directly (only for backwards compatibility, likely slower than well-specified constraints!)
513-
elif callable(self.restrictions):
516+
elif callable(restrictions):
514517

515518
def restrictions_wrapper(*args):
516-
return check_instance_restrictions(self.restrictions, dict(zip(self.param_names, args)), False)
519+
return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False)
517520

518521
parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names)
519-
elif self.restrictions is not None:
520-
raise ValueError(f"The restrictions are of unsupported type {type(self.restrictions)}")
522+
elif restrictions is not None:
523+
raise ValueError(f"The restrictions are of unsupported type {type(restrictions)}")
521524
return parameter_space
522525

523526
def __parse_restrictions_pysmt(self, restrictions: list, tune_params: dict, symbols: dict):

test/strategies/test_strategies.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ def test_strategies(vector_add, strategy):
6767
if strategy != "brute_force":
6868
filter_options["max_fevals"] = 10
6969

70-
restrictions = ["test_string == 'alg_2'", "test_bool == True", "test_mixed == 2.45"]
70+
restrictions = [
71+
"test_string == 'alg_2'",
72+
"test_bool == True",
73+
"test_mixed == 2.45"
74+
]
7175

7276
# pyATF can't handle non-number tune parameters, so we filter them out
7377
cache_filename_local = cache_filename
@@ -81,6 +85,7 @@ def test_strategies(vector_add, strategy):
8185

8286
# run the tuning in simulation mode
8387
assert cache_filename_local.exists()
88+
assert restrictions is not None
8489
results, _ = kernel_tuner.tune_kernel(*vector_add, restrictions=restrictions, strategy=strategy, strategy_options=filter_options,
8590
verbose=False, cache=cache_filename_local, simulation_mode=True)
8691

@@ -123,15 +128,10 @@ def test_strategies(vector_add, strategy):
123128
x0 = [256]
124129
filter_options["x0"] = x0
125130
if not strategy in ["brute_force", "random_sample", "bayes_opt"]:
126-
results, _ = kernel_tuner.tune_kernel(*vector_add, strategy=strategy, strategy_options=filter_options,
131+
results, _ = kernel_tuner.tune_kernel(*vector_add, restrictions=restrictions, strategy=strategy, strategy_options=filter_options,
127132
verbose=False, cache=cache_filename, simulation_mode=True)
128133
assert results[0]["block_size_x"] == x0[0]
129134
else:
130135
with pytest.raises(ValueError):
131-
results, _ = kernel_tuner.tune_kernel(*vector_add, strategy=strategy, strategy_options=filter_options,
136+
results, _ = kernel_tuner.tune_kernel(*vector_add, restrictions=restrictions, strategy=strategy, strategy_options=filter_options,
132137
verbose=False, cache=cache_filename, simulation_mode=True)
133-
134-
135-
136-
137-

0 commit comments

Comments
 (0)