@@ -79,6 +79,7 @@ def __init__(
79
79
framework_l = framework .lower ()
80
80
restrictions = restrictions if restrictions is not None else []
81
81
self .tune_params = tune_params
82
+ self .original_tune_params = tune_params .copy () if hasattr (tune_params , "copy" ) else tune_params
82
83
self .max_threads = max_threads
83
84
self .block_size_names = block_size_names
84
85
self ._tensorspace = None
@@ -92,6 +93,7 @@ def __init__(
92
93
self ._map_tensor_to_param = {}
93
94
self ._map_param_to_tensor = {}
94
95
self .restrictions = restrictions .copy () if hasattr (restrictions , "copy" ) else restrictions
96
+ self .original_restrictions = restrictions .copy () if hasattr (restrictions , "copy" ) else restrictions
95
97
# the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads)
96
98
self ._modified_restrictions = restrictions .copy () if hasattr (restrictions , "copy" ) else restrictions
97
99
self .param_names = list (self .tune_params .keys ())
@@ -100,6 +102,7 @@ def __init__(
100
102
self .build_neighbors_index = build_neighbors_index
101
103
self .solver_method = solver_method
102
104
self .__neighbor_cache = { method : dict () for method in supported_neighbor_methods }
105
+ self .neighbors_index = dict ()
103
106
self .neighbor_method = neighbor_method
104
107
if (neighbor_method is not None or build_neighbors_index ) and neighbor_method not in supported_neighbor_methods :
105
108
raise ValueError (f"Neighbor method is { neighbor_method } , must be one of { supported_neighbor_methods } " )
@@ -175,7 +178,7 @@ def __init__(
175
178
if neighbor_method is not None and neighbor_method != "Hamming" :
176
179
self .__prepare_neighbors_index ()
177
180
if build_neighbors_index :
178
- self .neighbors_index = self .__build_neighbors_index (neighbor_method )
181
+ self .neighbors_index [ neighbor_method ] = self .__build_neighbors_index (neighbor_method )
179
182
180
183
# def __build_searchspace_ortools(self, block_size_names: list, max_threads: int) -> Tuple[List[tuple], np.ndarray, dict, int]:
181
184
# # Based on https://developers.google.com/optimization/cp/cp_solver#python_2
@@ -452,7 +455,8 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
452
455
# add the user-specified restrictions as constraints on the parameter space
453
456
if not isinstance (self .restrictions , (list , tuple )):
454
457
self .restrictions = [self .restrictions ]
455
- self .restrictions = convert_constraint_lambdas (self .restrictions )
458
+ if any (not isinstance (restriction , (Constraint , FunctionConstraint , str )) for restriction in self .restrictions ):
459
+ self .restrictions = convert_constraint_lambdas (self .restrictions )
456
460
parameter_space = self .__add_restrictions (parameter_space )
457
461
458
462
# add the default blocksize threads restrictions last, because it is unlikely to reduce the parameter space by much
@@ -901,24 +905,25 @@ def get_random_sample(self, num_samples: int) -> List[tuple]:
901
905
num_samples = self .size
902
906
return self .get_param_configs_at_indices (self .get_random_sample_indices (num_samples ))
903
907
904
- def get_neighbors_indices_no_cache (self , param_config : tuple , neighbor_method = None ) -> List [int ]:
908
+ def get_neighbors_indices_no_cache (self , param_config : tuple , neighbor_method = None , build_full_cache = False ) -> List [int ]:
905
909
"""Get the neighbors indices for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
906
910
param_config_index = self .get_param_config_index (param_config )
907
911
908
- # this is the simplest case, just return the cached value
909
- if self .build_neighbors_index and param_config_index is not None :
910
- if neighbor_method is not None and neighbor_method != self .neighbor_method :
911
- raise ValueError (
912
- f"The neighbor method { neighbor_method } differs from the neighbor method { self .neighbor_method } initially used for indexing"
913
- )
914
- return self .neighbors_index [param_config_index ]
915
-
916
912
# check if there is a neighbor method to use
917
913
if neighbor_method is None :
918
914
if self .neighbor_method is None :
919
915
raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
920
916
neighbor_method = self .neighbor_method
921
917
918
+ # this is the simplest case, just return the cached value
919
+ if param_config_index is not None :
920
+ if neighbor_method in self .neighbors_index :
921
+ return self .neighbors_index [neighbor_method ][param_config_index ]
922
+ elif build_full_cache :
923
+ # build the neighbors index for the given neighbor method
924
+ self .neighbors_index [neighbor_method ] = self .__build_neighbors_index (neighbor_method )
925
+ return self .neighbors_index [neighbor_method ][param_config_index ]
926
+
922
927
if neighbor_method == "Hamming" :
923
928
return self .__get_neighbors_indices_hamming (param_config )
924
929
@@ -933,7 +938,7 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
933
938
return self .__get_neighbors_indices_adjacent (param_config_index , param_config )
934
939
raise ValueError (f"The neighbor method { neighbor_method } is not in { supported_neighbor_methods } " )
935
940
936
- def get_neighbors_indices (self , param_config : tuple , neighbor_method = None ) -> List [int ]:
941
+ def get_neighbors_indices (self , param_config : tuple , neighbor_method = None , build_full_cache = False ) -> List [int ]:
937
942
"""Get the neighbors indices for a parameter configuration, cached if requested before."""
938
943
if neighbor_method is None :
939
944
neighbor_method = self .neighbor_method
@@ -942,7 +947,7 @@ def get_neighbors_indices(self, param_config: tuple, neighbor_method=None) -> Li
942
947
neighbors = self .__neighbor_cache [neighbor_method ].get (param_config , None )
943
948
# if there are no cached neighbors, compute them
944
949
if neighbors is None :
945
- neighbors = self .get_neighbors_indices_no_cache (param_config , neighbor_method )
950
+ neighbors = self .get_neighbors_indices_no_cache (param_config , neighbor_method , build_full_cache )
946
951
self .__neighbor_cache [neighbor_method ][param_config ] = neighbors
947
952
return neighbors
948
953
@@ -958,9 +963,9 @@ def get_neighbors_no_cache(self, param_config: tuple, neighbor_method=None) -> L
958
963
"""Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
959
964
return self .get_param_configs_at_indices (self .get_neighbors_indices_no_cache (param_config , neighbor_method ))
960
965
961
- def get_neighbors (self , param_config : tuple , neighbor_method = None ) -> List [tuple ]:
966
+ def get_neighbors (self , param_config : tuple , neighbor_method = None , build_full_cache = False ) -> List [tuple ]:
962
967
"""Get the neighbors for a parameter configuration."""
963
- return self .get_param_configs_at_indices (self .get_neighbors_indices (param_config , neighbor_method ))
968
+ return self .get_param_configs_at_indices (self .get_neighbors_indices (param_config , neighbor_method , build_full_cache ))
964
969
965
970
def get_param_neighbors (self , param_config : tuple , index : int , neighbor_method : str , randomize : bool ) -> list :
966
971
"""Get the neighboring parameters at an index."""
0 commit comments