Skip to content

Commit 8da11a7

Browse files
committed
Searchspace object improvements in checking for tensorspace and error messaging
1 parent 7e4f38c commit 8da11a7

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

kernel_tuner/searchspace.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(
9393
self._tensorspace_param_config_structure = []
9494
self._map_tensor_to_param = {}
9595
self._map_param_to_tensor = {}
96-
restrictions = list(restrictions) if not isinstance(restrictions, (list, tuple)) else restrictions
96+
restrictions = [restrictions] if not isinstance(restrictions, (list, tuple)) else restrictions
9797
self.restrictions = deepcopy(restrictions)
9898
self.original_restrictions = deepcopy(restrictions) # keep the original restrictions, so that the searchspace can be modified later
9999
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
@@ -687,7 +687,15 @@ def get_list_numpy(self) -> np.ndarray:
687687

688688
def get_param_indices(self, param_config: tuple) -> tuple:
689689
"""For each parameter value in the param config, find the index in the tunable parameters."""
690-
return tuple(self.params_values[index].index(param_value) for index, param_value in enumerate(param_config))
690+
try:
691+
return tuple(self.params_values[index].index(param_value) for index, param_value in enumerate(param_config))
692+
except ValueError as e:
693+
for index, param_value in enumerate(param_config):
694+
if param_value not in self.params_values[index]:
695+
# if the parameter value is not in the list of values for that parameter, raise an error
696+
raise ValueError(
697+
f"Parameter value {param_value} ({type(param_value)}) is not in the list of values {self.params_values[index]}"
698+
) from e
691699

692700
def get_param_configs_at_indices(self, indices: List[int]) -> List[tuple]:
693701
"""Get the param configs at the given indices."""
@@ -753,9 +761,13 @@ def initialize_tensorspace(self, dtype=None, device=None):
753761
bounds = torch.tensor(bounds, **self.tensor_kwargs)
754762
self._tensorspace_bounds = torch.cat([bounds[:, 0], bounds[:, 1]]).reshape((2, bounds.shape[0]))
755763

764+
def has_tensorspace(self) -> bool:
765+
"""Check if the tensorspace has been initialized."""
766+
return self._tensorspace is not None
767+
756768
def get_tensorspace(self):
757769
"""Get the searchspace encoded in a Tensor. To use a non-default dtype or device, call `initialize_tensorspace` first."""
758-
if self._tensorspace is None:
770+
if not self.has_tensorspace():
759771
self.initialize_tensorspace()
760772
return self._tensorspace
761773

@@ -800,7 +812,7 @@ def tensor_to_param_config(self, tensor):
800812

801813
def get_tensorspace_bounds(self):
802814
"""Get the bounds to the tensorspace parameters, returned as a 2 x d dimensional tensor, and the indices of the parameters."""
803-
if self._tensorspace is None:
815+
if not self.has_tensorspace():
804816
self.initialize_tensorspace()
805817
return self._tensorspace_bounds, self._tensorspace_bounds_indices
806818

0 commit comments

Comments
 (0)