66from typing import List , Union
77from warnings import warn
88from copy import deepcopy
9+ from collections import defaultdict
910
1011import numpy as np
1112from constraint import (
@@ -104,6 +105,7 @@ def __init__(
104105 self .build_neighbors_index = build_neighbors_index
105106 self .solver_method = solver_method
106107 self .__neighbor_cache = { method : dict () for method in supported_neighbor_methods }
108+ self .__neighbor_partial_cache = { method : defaultdict (list ) for method in supported_neighbor_methods }
107109 self .neighbors_index = dict ()
108110 self .neighbor_method = neighbor_method
109111 if (neighbor_method is not None or build_neighbors_index ) and neighbor_method not in supported_neighbor_methods :
@@ -837,6 +839,7 @@ def __get_random_neighbor_hamming(self, param_config: tuple) -> tuple:
837839 for i in random_order_indices :
838840 # assert arr[i].shape == target.shape, f"Row {i} shape {arr[i].shape} does not match target shape {target.shape}"
839841 if np .count_nonzero (arr [i ] != target ) == 1 :
842+ self .__add_to_neighbor_partial_cache (param_config , [i ], full_neighbors = False )
840843 return self .get_param_configs_at_indices ([i ])[0 ]
841844 return None
842845
@@ -876,14 +879,33 @@ def __get_random_neighbor_adjacent(self, param_config: tuple) -> tuple:
876879
877880 # if there are matching indices, return a random one
878881 if len (matching_indices ) > 0 :
879- # get the random index from the matching indices
882+ self .__add_to_neighbor_partial_cache (param_config , matching_indices , full_neighbors = allowed_index_difference == max_index_difference )
883+
884+ # get a random index from the matching indices
880885 random_neighbor_index = choice (matching_indices )
881886 return self .get_param_configs_at_indices ([random_neighbor_index ])[0 ]
882887
883888 # if there are no matching indices, increase the allowed index difference and start over
884889 allowed_index_difference += 1
885890 return None
886891
892+ def __add_to_neighbor_partial_cache (self , param_config : tuple , neighbor_indices : List [int ], neighbor_method : str , full_neighbors = False ):
893+ """Add the neighbor indices to the partial cache using the given parameter configuration."""
894+ param_config_index = self .get_param_config_index (param_config )
895+ if param_config_index is None :
896+ return # we need a valid parameter configuration to add to the cache
897+ # add the indices to the partial cache for the parameter configuration
898+ if full_neighbors :
899+ self .__neighbor_partial_cache [neighbor_method ][param_config_index ] = neighbor_indices
900+ else :
901+ for neighbor_index in neighbor_indices :
902+ if neighbor_index not in self .__neighbor_partial_cache [neighbor_method ][param_config_index ]:
903+ self .__neighbor_partial_cache [neighbor_method ][param_config_index ].append (neighbor_index )
904+ # add the parameter configuration index to the partial cache for each neighbor
905+ for neighbor_index in neighbor_indices :
906+ if param_config_index not in self .__neighbor_partial_cache [neighbor_method ][neighbor_index ]:
907+ self .__neighbor_partial_cache [neighbor_method ][neighbor_index ].append (param_config_index )
908+
887909 def __get_neighbors_indices_strictlyadjacent (
888910 self , param_config_index : int = None , param_config : tuple = None
889911 ) -> List [int ]:
@@ -1022,6 +1044,10 @@ def get_neighbors_indices(self, param_config: tuple, neighbor_method=None, build
10221044 if neighbors is None :
10231045 neighbors = self .get_neighbors_indices_no_cache (param_config , neighbor_method , build_full_cache )
10241046 self .__neighbor_cache [neighbor_method ][param_config ] = neighbors
1047+ self .__add_to_neighbor_partial_cache (param_config , neighbors , neighbor_method , full_neighbors = True )
1048+ if neighbor_method == "strictly-adjacent" :
1049+ # any neighbor in strictly-adjacent is also an adjacent neighbor
1050+ self .__add_to_neighbor_partial_cache (param_config , neighbors , "adjacent" , full_neighbors = False )
10251051 return neighbors
10261052
10271053 def are_neighbors_indices_cached (self , param_config : tuple , neighbor_method = None ) -> bool :
@@ -1040,28 +1066,65 @@ def get_neighbors(self, param_config: tuple, neighbor_method=None, build_full_ca
10401066 """Get the neighbors for a parameter configuration."""
10411067 return self .get_param_configs_at_indices (self .get_neighbors_indices (param_config , neighbor_method , build_full_cache ))
10421068
1043- def get_random_neighbor (self , param_config : tuple , neighbor_method = None ) -> tuple :
1044- """Get an approximately random neighbor for a parameter configuration. Much faster than taking a random choice of all neighbors, but does not build cache."""
1069+ def get_partial_neighbors_indices (self , param_config : tuple , neighbor_method = None ) -> List [tuple ]:
1070+ """Get the partial neighbors for a parameter configuration."""
1071+ if neighbor_method is None :
1072+ neighbor_method = self .neighbor_method
1073+ if neighbor_method is None :
1074+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
1075+ param_config_index = self .get_param_config_index (param_config )
1076+ if param_config_index is None or param_config_index not in self .__neighbor_partial_cache [neighbor_method ]:
1077+ return []
1078+ return self .get_param_configs_at_indices (self .__neighbor_partial_cache [neighbor_method ][param_config_index ])
1079+
1080+ def pop_random_partial_neighbor (self , param_config : tuple , neighbor_method = None , threshold = 2 ) -> tuple :
1081+ """Pop a random partial neighbor for a given a parameter configuration if there are at least `threshold` neighbors."""
1082+ if neighbor_method is None :
1083+ neighbor_method = self .neighbor_method
1084+ if neighbor_method is None :
1085+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
1086+ param_config_index = self .get_param_config_index (param_config )
1087+ if param_config_index is None or param_config_index not in self .__neighbor_partial_cache [neighbor_method ]:
1088+ return None
1089+ partial_neighbors = self .get_param_configs_at_indices (self .__neighbor_partial_cache [neighbor_method ][param_config_index ])
1090+ if len (partial_neighbors ) < threshold :
1091+ return None
1092+ partial_neighbor_index = choice (range (len (partial_neighbors )))
1093+ random_neighbor = self .__neighbor_partial_cache [neighbor_method ][param_config_index ].pop (partial_neighbor_index )
1094+ return self .get_param_configs_at_indices ([random_neighbor ])[0 ]
1095+
1096+ def get_random_neighbor (self , param_config : tuple , neighbor_method = None , use_partial_cache = True ) -> tuple :
1097+ """Get an approximately random neighbor for a parameter configuration. Much faster than taking a random choice of all neighbors, but does not build full cache."""
10451098 if self .are_neighbors_indices_cached (param_config , neighbor_method ):
10461099 neighbors = self .get_neighbors (param_config , neighbor_method )
10471100 return choice (neighbors ) if len (neighbors ) > 0 else None
1048- else :
1049- # check if there is a neighbor method to use
1101+ elif use_partial_cache :
1102+ # pop the chosen neighbor from the cache to avoid choosing it again until it is re-added
1103+ random_neighbor = self .pop_random_partial_neighbor (param_config , neighbor_method )
1104+ if random_neighbor is not None :
1105+ return random_neighbor
1106+
1107+ # check if there is a neighbor method to use
1108+ if neighbor_method is None :
1109+ neighbor_method = self .neighbor_method
10501110 if neighbor_method is None :
1051- neighbor_method = self .neighbor_method
1052-
1053- # find the random neighbor based on the method
1054- if neighbor_method == "adjacent" :
1055- return self .__get_random_neighbor_adjacent (param_config )
1056- # elif neighbor_method == "Hamming":
1057- # this implementation is not as efficient as just generating all neighbors
1058- # return self.__get_random_neighbor_hamming(param_config)
1059- else :
1060- # not much performance to be gained for strictly-adjacent neighbors, just generate the neighbors
1061- neighbors = self .get_neighbors (param_config , neighbor_method )
1062- if len (neighbors ) == 0 :
1063- return None
1064- return choice (neighbors )
1111+ raise ValueError ("Neither the neighbor_method argument nor self.neighbor_method was set" )
1112+
1113+ # oddly enough, the custom random neighbor methods are not faster than just generating all neighbor + partials
1114+ # # find the random neighbor based on the method
1115+ # if neighbor_method == "adjacent":
1116+ # return self.__get_random_neighbor_adjacent(param_config)
1117+ # elif neighbor_method == "Hamming":
1118+ # this implementation is not as efficient as just generating all neighbors
1119+ # return self.__get_random_neighbor_hamming(param_config)
1120+ # # else:
1121+ # # not much performance to be gained for strictly-adjacent neighbors, just generate the neighbors
1122+
1123+ # calculate the full neighbors and return a random one
1124+ neighbors = self .get_neighbors (param_config , neighbor_method )
1125+ if len (neighbors ) == 0 :
1126+ return None
1127+ return choice (neighbors )
10651128
10661129 def get_param_neighbors (self , param_config : tuple , index : int , neighbor_method : str , randomize : bool ) -> list :
10671130 """Get the neighboring parameters at an index."""
0 commit comments