Commit 435b56b

Fixed torch import error due to Tensor type hint
1 parent 539aed3 · commit 435b56b

kernel_tuner/searchspace.py

Lines changed: 39 additions & 26 deletions
@@ -21,6 +21,7 @@
 try:
     import torch
     from torch import Tensor
+
     torch_available = True
 except ImportError:
     torch_available = False
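The later hunks remove the module-level `Tensor` type hint, which is what actually broke imports when torch is not installed: an annotation on a `def` is evaluated when the module is imported (unless postponed evaluation is enabled), so a name that only exists after a successful `import torch` raises a `NameError`. A minimal standalone sketch of that failure mode and the guarded pattern (hypothetical names, not the kernel_tuner source):

```python
# Minimal sketch (hypothetical module, not the kernel_tuner source) of why a
# module-level `Tensor` annotation breaks when torch is absent, and the guarded
# runtime check that keeps working either way.
from typing import Union

try:
    import torch
    from torch import Tensor

    torch_available = True
except ImportError:
    torch_available = False
    Tensor = None  # assumed fallback so the name exists; the commit instead drops the hint


def get_index(param_config: Union[tuple, any]):
    # With `Union[tuple, Tensor]` here, the annotation is evaluated at import time,
    # so a missing torch would raise NameError before the function could ever run.
    if torch_available and isinstance(param_config, Tensor):
        param_config = tuple(param_config.tolist())
    return param_config
```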
@@ -42,7 +43,7 @@ def __init__(
         block_size_names=default_block_size_names,
         build_neighbors_index=False,
         neighbor_method=None,
-        from_cache: dict=None,
+        from_cache: dict = None,
         framework="PythonConstraint",
         solver_method="PC_OptimizedBacktrackingSolver",
         path_to_ATF_cache: Path = None,
@@ -58,10 +59,14 @@ def __init__(
         """
         # check the arguments
         if from_cache is not None:
-            assert tune_params is None and restrictions is None and max_threads is None, "When `from_cache` is used, the positional arguments must be set to None."
+            assert (
+                tune_params is None and restrictions is None and max_threads is None
+            ), "When `from_cache` is used, the positional arguments must be set to None."
             tune_params = from_cache["tune_params"]
         if from_cache is None:
-            assert tune_params is not None and restrictions is not None and max_threads is not None, "Must specify positional arugments ."
+            assert (
+                tune_params is not None and restrictions is not None and max_threads is not None
+            ), "Must specify positional arugments ."

         # set the object attributes using the arguments
         framework_l = framework.lower()
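The reworked asserts only reformat existing checks that make the two construction modes mutually exclusive. A hedged usage sketch (argument names taken from the diff, values made up, the actual `Searchspace` calls left commented out):

```python
# Illustrative only: mirrors the argument checks in the hunk above.
tune_params = {"block_size_x": [32, 64, 128], "use_shmem": [0, 1]}
restrictions = ["block_size_x <= 128"]
max_threads = 1024

# Normal construction: all three positional arguments must be given.
# space = Searchspace(tune_params, restrictions, max_threads)

# Construction from a cache: the positional arguments must be None and
# tune_params is read from the cache dict instead.
cache = {"tune_params": tune_params}
# space = Searchspace(None, None, None, from_cache=cache)
```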
@@ -77,9 +82,9 @@
         self._tensorspace_param_config_structure = []
         self._map_tensor_to_param = {}
         self._map_param_to_tensor = {}
-        self.restrictions = restrictions.copy() if hasattr(restrictions, 'copy') else restrictions
+        self.restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
         # the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
-        self._modified_restrictions = restrictions.copy() if hasattr(restrictions, 'copy') else restrictions
+        self._modified_restrictions = restrictions.copy() if hasattr(restrictions, "copy") else restrictions
         self.param_names = list(self.tune_params.keys())
         self.params_values = tuple(tuple(param_vals) for param_vals in self.tune_params.values())
         self.params_values_indices = None
@@ -93,8 +98,12 @@
         restrictions = [restrictions] if not isinstance(restrictions, list) else restrictions
         if (
             len(restrictions) > 0
-            and (any(isinstance(restriction, str) for restriction in restrictions)
-            or any(isinstance(restriction[0], str) for restriction in restrictions if isinstance(restriction, tuple)))
+            and (
+                any(isinstance(restriction, str) for restriction in restrictions)
+                or any(
+                    isinstance(restriction[0], str) for restriction in restrictions if isinstance(restriction, tuple)
+                )
+            )
             and not (framework_l == "pysmt" or framework_l == "bruteforce")
         ):
             self.restrictions = compile_restrictions(
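The enlarged condition is a pure reformat of existing logic: restrictions given as strings (or as tuples whose first element is a string) are routed to `compile_restrictions`, unless the pysmt or bruteforce framework is used. A small hedged example of the string detection itself (restriction values made up; the tuple form is an assumption for illustration):

```python
# Illustrative only: the same string-detection the condition above performs.
restrictions = [
    "block_size_x * block_size_y <= 1024",                            # plain string restriction
    ("tile_size_x == tile_size_y", ["tile_size_x", "tile_size_y"]),   # assumed (string, params) tuple form
]

needs_compiling = any(isinstance(r, str) for r in restrictions) or any(
    isinstance(r[0], str) for r in restrictions if isinstance(r, tuple)
)
print(needs_compiling)  # True: these would be passed to compile_restrictions
```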
@@ -609,14 +618,14 @@ def get_param_configs_at_indices(self, indices: List[int]) -> List[tuple]:
         # map(get) is ~40% faster than numpy[indices] (average based on six searchspaces with 10000, 100000 and 1000000 configs and 10 or 100 random indices)
         return list(map(self.list.__getitem__, indices))

-    def get_param_config_index(self, param_config: Union[tuple, Tensor]):
+    def get_param_config_index(self, param_config: Union[tuple, any]):
         """Lookup the index for a parameter configuration, returns None if not found."""
         if torch_available and isinstance(param_config, Tensor):
             param_config = self.tensor_to_param_config(param_config)
         # constant time O(1) access - much faster than any other method, but needs a shadow dict of the search space
         return self.__dict.get(param_config, None)
-
-    def initialize_tensorspace(self, dtype = None, device = None):
+
+    def initialize_tensorspace(self, dtype=None, device=None):
         """Encode the searchspace in a Tensor. Save the mapping. Call this function directly to control the precision or device used."""
         assert self._tensorspace is None, "Tensorspace is already initialized"
         skipped_count = 0
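The constant-time lookup mentioned in the comment relies on a dict that mirrors the list of configurations. A minimal sketch of that idea (hypothetical names and data, not the internal `__dict` of the class):

```python
# Minimal sketch of a shadow dict for O(1) index lookups (hypothetical data).
configs = [(32, 1), (64, 1), (64, 2)]                        # the search space as a list
index_of = {config: i for i, config in enumerate(configs)}   # shadow dict: config -> index

print(index_of.get((64, 2)))   # 2, constant-time lookup
print(index_of.get((128, 4)))  # None when the configuration is not in the search space
```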
@@ -642,16 +651,16 @@ def initialize_tensorspace(self, dtype = None, device = None):
             if all(isinstance(v, numbers.Real) for v in param_values):
                 tensor_values = torch.tensor(param_values, dtype=self.tensor_dtype)
             else:
-                self._tensorspace_categorical_dimensions.append(index-skipped_count)
+                self._tensorspace_categorical_dimensions.append(index - skipped_count)
                 # tensor_values = np.arange(len(param_values))
                 tensor_values = torch.arange(len(param_values), dtype=self.tensor_dtype)

             # write the mappings to the object
-            self._map_param_to_tensor[index] = (dict(zip(param_values, tensor_values.tolist())))
-            self._map_tensor_to_param[index] = (dict(zip(tensor_values.tolist(), param_values)))
+            self._map_param_to_tensor[index] = dict(zip(param_values, tensor_values.tolist()))
+            self._map_tensor_to_param[index] = dict(zip(tensor_values.tolist(), param_values))
             bounds.append((tensor_values.min(), tensor_values.max()))
             if tensor_values.min() < tensor_values.max():
-                self._tensorspace_bounds_indices.append(index-skipped_count)
+                self._tensorspace_bounds_indices.append(index - skipped_count)

         # do some checks
         assert len(self.params_values) == len(self._tensorspace_param_config_structure)
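The cleaned-up lines build one mapping per parameter: numeric parameters keep their own values, while non-numeric (categorical) parameters are encoded as consecutive integer codes, kept in both directions. A hedged sketch with made-up values (requires torch):

```python
import torch

# Hedged sketch of the per-parameter mapping built above (made-up parameter values).
param_values = ("GEMM", "GEMV", "conv")  # non-numeric, so treated as a categorical dimension
tensor_values = torch.arange(len(param_values), dtype=torch.float64)

map_param_to_tensor = dict(zip(param_values, tensor_values.tolist()))
map_tensor_to_param = dict(zip(tensor_values.tolist(), param_values))

print(map_param_to_tensor)        # {'GEMM': 0.0, 'GEMV': 1.0, 'conv': 2.0}
print(map_tensor_to_param[1.0])   # 'GEMV'
print((tensor_values.min(), tensor_values.max()))  # the (min, max) bounds recorded for this dimension
```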
@@ -666,26 +675,26 @@

         # set the bounds in the correct format (one array for the min, one for the max)
         bounds = torch.tensor(bounds, **self.tensor_kwargs)
-        self._tensorspace_bounds = torch.cat([bounds[:,0], bounds[:,1]]).reshape((2, bounds.shape[0]))
-
+        self._tensorspace_bounds = torch.cat([bounds[:, 0], bounds[:, 1]]).reshape((2, bounds.shape[0]))
+
     def get_tensorspace(self):
         """Get the searchspace encoded in a Tensor. To use a non-default dtype or device, call `initialize_tensorspace` first."""
         if self._tensorspace is None:
             self.initialize_tensorspace()
         return self._tensorspace
-
+
     def get_tensorspace_categorical_dimensions(self):
         """Get the a list of the categorical dimensions in the tensorspace."""
         return self._tensorspace_categorical_dimensions
-
+
     def param_config_to_tensor(self, param_config: tuple):
         """Convert from a parameter configuration to a Tensor."""
         if len(self._map_param_to_tensor) == 0:
             self.initialize_tensorspace()
         array = []
         for i, param in enumerate(param_config):
             if self._tensorspace_param_config_structure[i] is not None:
-                continue # skip over parameters not in the tensorspace
+                continue  # skip over parameters not in the tensorspace
             mapping = self._map_param_to_tensor[i]
             conversions = [None, str, float, int, bool]
             for c in conversions:
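The spacing fix to the bounds line does not change behaviour: the collected (min, max) pairs are stacked into a 2 x d tensor, minima in row 0 and maxima in row 1. A small hedged example with made-up bounds (requires torch):

```python
import torch

# Hedged sketch of the bounds construction above (made-up (min, max) pairs).
bounds = torch.tensor([(0.0, 2.0), (1.0, 8.0), (0.0, 1.0)])
stacked = torch.cat([bounds[:, 0], bounds[:, 1]]).reshape((2, bounds.shape[0]))
print(stacked)
# tensor([[0., 1., 0.],
#         [2., 8., 1.]])
```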
@@ -697,7 +706,7 @@ def param_config_to_tensor(self, param_config: tuple):
                 if c == conversions[-1]:
                     raise KeyError(f"No variant of {param} could be found in {mapping}") from e
         return torch.tensor(array, **self.tensor_kwargs)
-
+
     def tensor_to_param_config(self, tensor: Tensor):
         """Convert from a Tensor to a parameter configuration."""
         assert tensor.dim() == 1, f"Parameter configuration tensor must be 1-dimensional, is {tensor.dim()} ({tensor})"
@@ -709,10 +718,10 @@ def tensor_to_param_config(self, tensor: Tensor):
             if param is not None:
                 skip_counter += 1
             else:
-                value = tensor[i-skip_counter].item()
+                value = tensor[i - skip_counter].item()
                 config[i] = self._map_tensor_to_param[i][value]
         return tuple(config)
-
+
     def get_tensorspace_bounds(self):
         """Get the bounds to the tensorspace parameters, returned as a 2 x d dimensional tensor, and the indices of the parameters."""
         if self._tensorspace is None:
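The indexing fix (`i - skip_counter`) is purely cosmetic; the surrounding loop reconstructs a full parameter configuration by skipping parameters that are fixed in the configuration structure and only consuming tensor entries for the remaining ones. A simplified hedged sketch (plain lists instead of tensors, and without the tensor-to-parameter mapping the real method applies):

```python
# Hedged sketch of the skip_counter logic above, with made-up data.
structure = ["fixed_value", None, None]  # hypothetical: param 0 is constant, params 1 and 2 live in the tensor
tensor = [7.0, 3.0]                      # tensor entries for params 1 and 2 only

config = list(structure)
skip_counter = 0
for i, param in enumerate(structure):
    if param is not None:
        skip_counter += 1          # constant parameter: nothing to read from the tensor
    else:
        config[i] = tensor[i - skip_counter]
print(tuple(config))  # ('fixed_value', 7.0, 3.0)
```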
@@ -929,7 +938,7 @@ def order_param_configs(
                 f"The number of ordered parameter configurations ({len(ordered_param_configs)}) differs from the original number of parameter configurations ({len(param_configs)})"
             )
         return ordered_param_configs
-
+
     def to_ax_searchspace(self):
         """Convert this searchspace to an Ax SearchSpace."""
         from ax import ChoiceParameter, FixedParameter, ParameterType, SearchSpace
@@ -943,12 +952,14 @@ def to_ax_searchspace(self):
                 continue

             # convert the types
-            assert all(isinstance(param_values[0], type(v)) for v in param_values), f"Parameter values of mixed types are not supported: {param_values}"
+            assert all(
+                isinstance(param_values[0], type(v)) for v in param_values
+            ), f"Parameter values of mixed types are not supported: {param_values}"
             param_type_mapping = {
                 str: ParameterType.STRING,
                 int: ParameterType.INT,
                 float: ParameterType.FLOAT,
-                bool: ParameterType.BOOL
+                bool: ParameterType.BOOL,
             }
             param_type = param_type_mapping[type(param_values[0])]
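The trailing comma added to the type mapping is stylistic; the dict itself resolves an Ax `ParameterType` from the Python type of the first parameter value. A hedged illustration that mimics the lookup with plain strings, so it runs without Ax installed:

```python
# Illustrative only: same lookup shape as the hunk above, with strings standing in
# for ax.ParameterType members.
param_type_mapping = {str: "STRING", int: "INT", float: "FLOAT", bool: "BOOL"}

param_values = (32, 64, 128)
param_type = param_type_mapping[type(param_values[0])]
print(param_type)  # INT
```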

@@ -959,6 +970,8 @@ def to_ax_searchspace(self):
                 ax_searchspace.add_parameter(ChoiceParameter(param_name, param_type, param_values))

         # add the constraints
-        raise NotImplementedError("Conversion to Ax SearchSpace has not been fully implemented as Ax Searchspaces can't capture full complexity.")
+        raise NotImplementedError(
+            "Conversion to Ax SearchSpace has not been fully implemented as Ax Searchspaces can't capture full complexity."
+        )

         return ax_searchspace
