@@ -64,7 +64,7 @@ def __init__(
64
64
self .params_values_indices = None
65
65
self .build_neighbors_index = build_neighbors_index
66
66
self .solver_method = solver_method
67
- self .__neighbor_cache = dict ()
67
+ self .__neighbor_cache = { method : dict () for method in supported_neighbor_methods }
68
68
self .neighbor_method = neighbor_method
69
69
if (neighbor_method is not None or build_neighbors_index ) and neighbor_method not in supported_neighbor_methods :
70
70
raise ValueError (f"Neighbor method is { neighbor_method } , must be one of { supported_neighbor_methods } " )
@@ -758,24 +758,25 @@ def get_neighbors_indices_no_cache(self, param_config: tuple, neighbor_method=No
758
758
raise ValueError (f"The neighbor method { neighbor_method } is not in { supported_neighbor_methods } " )
759
759
760
760
def get_neighbors_indices (self , param_config : tuple , neighbor_method = None ) -> List [int ]:
761
- """Get the neighbors indices for a parameter configuration, possibly cached."""
762
- neighbors = self .__neighbor_cache .get (param_config , None )
761
+ """Get the neighbors indices for a parameter configuration, cached if requested before."""
762
+ if neighbor_method is None :
763
+ neighbor_method = self .neighbor_method
764
+ if neighbor_method is None :
765
+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
766
+ neighbors = self .__neighbor_cache [neighbor_method ].get (param_config , None )
763
767
# if there are no cached neighbors, compute them
764
768
if neighbors is None :
765
769
neighbors = self .get_neighbors_indices_no_cache (param_config , neighbor_method )
766
- self .__neighbor_cache [param_config ] = neighbors
767
- # if the neighbors were cached but the specified neighbor method was different than the one initially used to build the cache, throw an error
768
- elif (
769
- self .neighbor_method is not None and neighbor_method is not None and self .neighbor_method != neighbor_method
770
- ):
771
- raise ValueError (
772
- f"The neighbor method { neighbor_method } differs from the intially set { self .neighbor_method } , can not use cached neighbors. Use 'get_neighbors_no_cache()' when mixing neighbor methods to avoid this."
773
- )
770
+ self .__neighbor_cache [neighbor_method ][param_config ] = neighbors
774
771
return neighbors
775
772
776
- def are_neighbors_indices_cached (self , param_config : tuple ) -> bool :
773
+ def are_neighbors_indices_cached (self , param_config : tuple , neighbor_method = None ) -> bool :
777
774
"""Returns true if the neighbor indices are in the cache, false otherwise."""
778
- return param_config in self .__neighbor_cache
775
+ if neighbor_method is None :
776
+ neighbor_method = self .neighbor_method
777
+ if neighbor_method is None :
778
+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
779
+ return param_config in self .__neighbor_cache [neighbor_method ]
779
780
780
781
def get_neighbors_no_cache (self , param_config : tuple , neighbor_method = None ) -> List [tuple ]:
781
782
"""Get the neighbors for a parameter configuration (does not check running cache, useful when mixing neighbor methods)."""
0 commit comments