Skip to content

Commit ce96ec5

Browse files
committed
Fix bug in searchspace.py causing it be incompatible with numpy 1.24
1 parent d6acee3 commit ce96ec5

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

kernel_tuner/searchspace.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,24 +192,22 @@ def __get_neighbors_indices_adjacent(self, param_config_index: int = None, param
192192
matching_indices = np.setdiff1d(matching_indices, [param_config_index], assume_unique=False)
193193
return matching_indices
194194

195-
def __build_neighbors_index(self, neighbor_method) -> np.ndarray:
195+
def __build_neighbors_index(self, neighbor_method) -> List[List[int]]:
196196
"""build an index of the neighbors for each parameter configuration"""
197197
# for Hamming no preperation is necessary, find the neighboring parameter configurations
198198
if neighbor_method == "Hamming":
199-
return np.array(list(self.__get_neighbors_indices_hamming(param_config) for param_config in self.list))
199+
return list(self.__get_neighbors_indices_hamming(param_config) for param_config in self.list)
200200

201201
# for each parameter configuration, find the neighboring parameter configurations
202202
if self.params_values_indices is None:
203203
self.__prepare_neighbors_index()
204204
if neighbor_method == "strictly-adjacent":
205-
return np.array(
206-
list(
207-
self.__get_neighbors_indices_strictlyadjacent(param_config_index, param_config)
208-
for param_config_index, param_config in enumerate(self.list)))
205+
return list(self.__get_neighbors_indices_strictlyadjacent(param_config_index, param_config) for param_config_index, param_config in enumerate(self.list))
206+
209207
if neighbor_method == "adjacent":
210-
return np.array(
211-
list(self.__get_neighbors_indices_adjacent(param_config_index, param_config) for param_config_index, param_config in enumerate(self.list)))
212-
raise NotImplementedError()
208+
return list(self.__get_neighbors_indices_adjacent(param_config_index, param_config) for param_config_index, param_config in enumerate(self.list))
209+
210+
raise NotImplementedError(f"The neighbor method {neighbor_method} is not implemented")
213211

214212
def get_random_sample_indices(self, num_samples: int) -> np.ndarray:
215213
"""Get the list indices for a random, non-conflicting sample"""

0 commit comments

Comments
 (0)