@@ -93,7 +93,7 @@ def __init__(
93
93
self ._tensorspace_param_config_structure = []
94
94
self ._map_tensor_to_param = {}
95
95
self ._map_param_to_tensor = {}
96
- restrictions = list ( restrictions ) if not isinstance (restrictions , (list , tuple )) else restrictions
96
+ restrictions = [ restrictions ] if not isinstance (restrictions , (list , tuple )) else restrictions
97
97
self .restrictions = deepcopy (restrictions )
98
98
self .original_restrictions = deepcopy (restrictions ) # keep the original restrictions, so that the searchspace can be modified later
99
99
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
@@ -687,7 +687,15 @@ def get_list_numpy(self) -> np.ndarray:
687
687
688
688
def get_param_indices (self , param_config : tuple ) -> tuple :
689
689
"""For each parameter value in the param config, find the index in the tunable parameters."""
690
- return tuple (self .params_values [index ].index (param_value ) for index , param_value in enumerate (param_config ))
690
+ try :
691
+ return tuple (self .params_values [index ].index (param_value ) for index , param_value in enumerate (param_config ))
692
+ except ValueError as e :
693
+ for index , param_value in enumerate (param_config ):
694
+ if param_value not in self .params_values [index ]:
695
+ # if the parameter value is not in the list of values for that parameter, raise an error
696
+ raise ValueError (
697
+ f"Parameter value { param_value } ({ type (param_value )} ) is not in the list of values { self .params_values [index ]} "
698
+ ) from e
691
699
692
700
def get_param_configs_at_indices (self , indices : List [int ]) -> List [tuple ]:
693
701
"""Get the param configs at the given indices."""
@@ -753,9 +761,13 @@ def initialize_tensorspace(self, dtype=None, device=None):
753
761
bounds = torch .tensor (bounds , ** self .tensor_kwargs )
754
762
self ._tensorspace_bounds = torch .cat ([bounds [:, 0 ], bounds [:, 1 ]]).reshape ((2 , bounds .shape [0 ]))
755
763
764
+ def has_tensorspace (self ) -> bool :
765
+ """Check if the tensorspace has been initialized."""
766
+ return self ._tensorspace is not None
767
+
756
768
def get_tensorspace (self ):
757
769
"""Get the searchspace encoded in a Tensor. To use a non-default dtype or device, call `initialize_tensorspace` first."""
758
- if self ._tensorspace is None :
770
+ if not self .has_tensorspace () :
759
771
self .initialize_tensorspace ()
760
772
return self ._tensorspace
761
773
@@ -800,7 +812,7 @@ def tensor_to_param_config(self, tensor):
800
812
801
813
def get_tensorspace_bounds (self ):
802
814
"""Get the bounds to the tensorspace parameters, returned as a 2 x d dimensional tensor, and the indices of the parameters."""
803
- if self ._tensorspace is None :
815
+ if not self .has_tensorspace () :
804
816
self .initialize_tensorspace ()
805
817
return self ._tensorspace_bounds , self ._tensorspace_bounds_indices
806
818
0 commit comments