@@ -64,7 +64,7 @@ def __init__(
6464 self .params_values_indices = None
6565 self .build_neighbors_index = build_neighbors_index
6666 self .solver_method = solver_method
67- self .__neighbor_cache = dict ()
67+ self .__neighbor_cache = { method : dict () for method in supported_neighbor_methods }
6868 self .neighbor_method = neighbor_method
6969 if (neighbor_method is not None or build_neighbors_index ) and neighbor_method not in supported_neighbor_methods :
7070 raise ValueError (f"Neighbor method is { neighbor_method } , must be one of { supported_neighbor_methods } " )
@@ -758,24 +758,25 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
758758 raise ValueError (f"The neighbor method { neighbor_method } is not in { supported_neighbor_methods } " )
759759
760760 def get_neighbors_indices (self , param_config : tuple , neighbor_method = None ) -> List [int ]:
761- """Get the neighbors indices for a parameter configuration, possibly cached."""
762- neighbors = self .__neighbor_cache .get (param_config , None )
761+ """Get the neighbors indices for a parameter configuration, cached if requested before."""
762+ if neighbor_method is None :
763+ neighbor_method = self .neighbor_method
764+ if neighbor_method is None :
765+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
766+ neighbors = self .__neighbor_cache [neighbor_method ].get (param_config , None )
763767 # if there are no cached neighbors, compute them
764768 if neighbors is None :
765769 neighbors = self .get_neighbors_indices_no_cache (param_config , neighbor_method )
766- self .__neighbor_cache [param_config ] = neighbors
767- # if the neighbors were cached but the specified neighbor method was different than the one initially used to build the cache, throw an error
768- elif (
769- self .neighbor_method is not None and neighbor_method is not None and self .neighbor_method != neighbor_method
770- ):
771- raise ValueError (
772- f"The neighbor method { neighbor_method } differs from the intially set { self .neighbor_method } , can not use cached neighbors. Use 'get_neighbors_no_cache()' when mixing neighbor methods to avoid this."
773- )
770+ self .__neighbor_cache [neighbor_method ][param_config ] = neighbors
774771 return neighbors
775772
776- def are_neighbors_indices_cached (self , param_config : tuple ) -> bool :
773+ def are_neighbors_indices_cached (self , param_config : tuple , neighbor_method = None ) -> bool :
777774 """Returns true if the neighbor indices are in the cache, false otherwise."""
778- return param_config in self .__neighbor_cache
775+ if neighbor_method is None :
776+ neighbor_method = self .neighbor_method
777+ if neighbor_method is None :
778+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
779+ return param_config in self .__neighbor_cache [neighbor_method ]
779780
780781 def get_neighbors_no_cache (self , param_config : tuple , neighbor_method = None ) -> List [tuple ]:
781782 """Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
0 commit comments