@@ -99,7 +99,7 @@ def __init__(
99
99
self .params_values_indices = None
100
100
self .build_neighbors_index = build_neighbors_index
101
101
self .solver_method = solver_method
102
- self .__neighbor_cache = dict ()
102
+ self .__neighbor_cache = { method : dict () for method in supported_neighbor_methods }
103
103
self .neighbor_method = neighbor_method
104
104
if (neighbor_method is not None or build_neighbors_index ) and neighbor_method not in supported_neighbor_methods :
105
105
raise ValueError (f"Neighbor method is { neighbor_method } , must be one of { supported_neighbor_methods } " )
@@ -934,24 +934,25 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
934
934
raise ValueError (f"The neighbor method { neighbor_method } is not in { supported_neighbor_methods } " )
935
935
936
936
def get_neighbors_indices (self , param_config : tuple , neighbor_method = None ) -> List [int ]:
937
- """Get the neighbors indices for a parameter configuration, possibly cached."""
938
- neighbors = self .__neighbor_cache .get (param_config , None )
937
+ """Get the neighbors indices for a parameter configuration, cached if requested before."""
938
+ if neighbor_method is None :
939
+ neighbor_method = self .neighbor_method
940
+ if neighbor_method is None :
941
+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
942
+ neighbors = self .__neighbor_cache [neighbor_method ].get (param_config , None )
939
943
# if there are no cached neighbors, compute them
940
944
if neighbors is None :
941
945
neighbors = self .get_neighbors_indices_no_cache (param_config , neighbor_method )
942
- self .__neighbor_cache [param_config ] = neighbors
943
- # if the neighbors were cached but the specified neighbor method was different than the one initially used to build the cache, throw an error
944
- elif (
945
- self .neighbor_method is not None and neighbor_method is not None and self .neighbor_method != neighbor_method
946
- ):
947
- raise ValueError (
948
- f"The neighbor method { neighbor_method } differs from the intially set { self .neighbor_method } , can not use cached neighbors. Use 'get_neighbors_no_cache()' when mixing neighbor methods to avoid this."
949
- )
946
+ self .__neighbor_cache [neighbor_method ][param_config ] = neighbors
950
947
return neighbors
951
948
952
- def are_neighbors_indices_cached (self , param_config : tuple ) -> bool :
949
+ def are_neighbors_indices_cached (self , param_config : tuple , neighbor_method = None ) -> bool :
953
950
"""Returns true if the neighbor indices are in the cache, false otherwise."""
954
- return param_config in self .__neighbor_cache
951
+ if neighbor_method is None :
952
+ neighbor_method = self .neighbor_method
953
+ if neighbor_method is None :
954
+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
955
+ return param_config in self .__neighbor_cache [neighbor_method ]
955
956
956
957
def get_neighbors_no_cache (self , param_config : tuple , neighbor_method = None ) -> List [tuple ]:
957
958
"""Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
0 commit comments