     Probability,
     ProbabilityWeight,
     PyBanditsBaseModel,
+    QuantitativeMOProbability,
+    QuantitativeMOProbabilityWeight,
+    QuantitativeProbability,
+    QuantitativeProbabilityWeight,
     Serializable,
     UnifiedActionId,
 )
@@ -52,7 +56,7 @@
     validate_call,
 )
 from pybandits.quantitative_model import QuantitativeModel
-from pybandits.strategy import Strategy
+from pybandits.strategy import BaseStrategy
 from pybandits.utils import extract_argument_names_from_function
 
 
@@ -79,12 +83,12 @@ class BaseMab(PyBanditsBaseModel, ABC):
     """
 
     actions_manager: ActionsManager
-    strategy: Strategy
+    strategy: BaseStrategy
     epsilon: Optional[Float01] = None
     default_action: Optional[UnifiedActionId] = None
     version: Optional[str] = None
-    deprecated_adwin_keys: ClassVar[List[str]] = ["adaptive_window_size", "actions_memory", "rewards_memory"]
-    current_supported_version_th: ClassVar[str] = "3.0.0"
+    _deprecated_adwin_keys: ClassVar[List[str]] = ["adaptive_window_size", "actions_memory", "rewards_memory"]
+    _current_supported_version_th: ClassVar[str] = "3.0.0"
 
     def __init__(
         self,
@@ -228,32 +232,13 @@ def update(
     def _transform_nested_list(lst: List[List[Dict]]):
         return [{k: v for d in single_action_dicts for k, v in d.items()} for single_action_dicts in zip(*lst)]
 
-    @staticmethod
-    def _is_so_standard_action(value: Any) -> bool:
-        # Probability ProbabilityWeight
-        return isinstance(value, float) or (isinstance(value, tuple) and isinstance(value[0], float))
-
-    @staticmethod
-    def _is_so_quantitative_action(value: Any) -> bool:
-        return isinstance(value, tuple) and isinstance(value[0], tuple)
-
-    @classmethod
-    def _is_standard_action(cls, value: Any) -> bool:
-        return cls._is_so_standard_action(value) or (isinstance(value, list) and cls._is_so_standard_action(value[0]))
-
-    @classmethod
-    def _is_quantitative_action(cls, value: Any) -> bool:
-        return cls._is_so_quantitative_action(value) or (
-            isinstance(value, list) and cls._is_so_quantitative_action(value[0])
-        )
-
     def _get_action_probabilities(
         self, forbidden_actions: Optional[Set[ActionId]] = None, **kwargs
     ) -> Union[
-        List[Dict[UnifiedActionId, Probability]],
-        List[Dict[UnifiedActionId, ProbabilityWeight]],
-        List[Dict[UnifiedActionId, MOProbability]],
-        List[Dict[UnifiedActionId, MOProbabilityWeight]],
+        List[Dict[ActionId, Union[Probability, QuantitativeProbability]]],
+        List[Dict[ActionId, Union[ProbabilityWeight, QuantitativeProbabilityWeight]]],
+        List[Dict[ActionId, Union[MOProbability, QuantitativeMOProbability]]],
+        List[Dict[ActionId, Union[MOProbabilityWeight, QuantitativeMOProbabilityWeight]]],
     ]:
         """
         Get the probability of getting a positive reward for each action.
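As an aside for readers of this hunk, here is a minimal, self-contained sketch of what the retained _transform_nested_list helper (the context lines at the top of the hunk) does; the action names and probability values are invented for illustration.

from typing import Dict, List


def _transform_nested_list(lst: List[List[Dict]]):
    # For each action we have a list of per-sample dicts; zip them up and merge
    # them so the result is one dict per sample covering all actions.
    return [{k: v for d in single_action_dicts for k, v in d.items()} for single_action_dicts in zip(*lst)]


per_action = [
    [{"a1": 0.3}, {"a1": 0.7}],  # two sampled probabilities for action "a1" (invented)
    [{"a2": 0.6}, {"a2": 0.1}],  # two sampled probabilities for action "a2" (invented)
]
print(_transform_nested_list(per_action))
# -> [{'a1': 0.3, 'a2': 0.6}, {'a1': 0.7, 'a2': 0.1}]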
@@ -276,34 +261,9 @@ def _get_action_probabilities(
             action: model.sample_proba(**kwargs) for action, model in self.actions.items() if action in valid_actions
         }
         # Handle standard actions for which the value is a (probability, weight) tuple
-        actions_transformations = [
-            [{key: proba} for proba in value]
-            for key, value in action_probabilities.items()
-            if self._is_standard_action(value[0])
-        ]
-        actions_transformations = self._transform_nested_list(actions_transformations)
-        # Handle quantitative actions, for which the value is a tuple of
-        # tuples of (quantity, (probability, weight) or probability)
-        quantitative_actions_transformations = [
-            [{(key, quantity): proba for quantity, proba in sample} for sample in value]
-            for key, value in action_probabilities.items()
-            if self._is_quantitative_action(value[0])
-        ]
-        quantitative_actions_transformations = self._transform_nested_list(quantitative_actions_transformations)
-        if not actions_transformations and not quantitative_actions_transformations:
-            return []
-        if not actions_transformations:  # No standard actions
-            actions_transformations = [dict() for _ in range(len(quantitative_actions_transformations))]
-        if not quantitative_actions_transformations:  # No quantitative actions
-            quantitative_actions_transformations = [dict() for _ in range(len(actions_transformations))]
-        if len(actions_transformations) != len(quantitative_actions_transformations):
-            raise ValueError("The number of standard and quantitative actions should be the same.")
-        action_probabilities = [
-            {**actions_dict, **quantitative_actions_dict}
-            for actions_dict, quantitative_actions_dict in zip(
-                actions_transformations, quantitative_actions_transformations
-            )
-        ]
+        actions_transformations = [[{key: proba} for proba in value] for key, value in action_probabilities.items()]
+        action_probabilities = self._transform_nested_list(actions_transformations)
+
         return action_probabilities
 
     @abstractmethod
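A hedged sketch of the simplified path above, reusing the _transform_nested_list definition from the previous sketch. The per-action sampled values are invented, and the exact shape of QuantitativeProbability is not shown in this diff, so the quantitative payloads below are placeholders only; the point is that every sample_proba() output is now carried through unchanged, keyed by the plain action id, with no standard/quantitative branching.

from typing import Dict, List


def _transform_nested_list(lst: List[List[Dict]]):
    return [{k: v for d in single_action_dicts for k, v in d.items()} for single_action_dicts in zip(*lst)]


# One entry per action, each holding a list of per-sample outputs (values invented).
sampled = {
    "a1": [0.3, 0.7],                      # placeholder standard-model samples
    "a2": [((0.5, 0.2),), ((0.5, 0.9),)],  # placeholder quantitative-model samples
}
actions_transformations = [[{key: proba} for proba in value] for key, value in sampled.items()]
print(_transform_nested_list(actions_transformations))
# -> [{'a1': 0.3, 'a2': ((0.5, 0.2),)}, {'a1': 0.7, 'a2': ((0.5, 0.9),)}]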
@@ -386,7 +346,7 @@ def _select_epsilon_greedy_action(
                 if self.default_action:
                     selected_action = self.default_action
                 else:
-                    actions = list(set(a[0] if isinstance(a, tuple) else a for a in p.keys()))
+                    actions = list(p.keys())
                     selected_action = random.choice(actions)
                 if isinstance(self.actions[selected_action], QuantitativeModel):
                     selected_action = (
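Because _get_action_probabilities now keys every entry by a plain action id (see the return-type hunk above), the random pick no longer needs to unwrap (action, quantity) tuple keys. A tiny standalone sketch with invented values for p:

import random

p = {"a1": 0.3, "a2": 0.6, "a3": 0.1}  # invented per-action values for one sample
actions = list(p.keys())               # keys are plain action ids, nothing to unwrap
selected_action = random.choice(actions)
print(selected_action)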
@@ -450,7 +410,7 @@ def update_old_state(
         state["actions_manager"]["actions"] = state.pop("actions")
         state["actions_manager"]["delta"] = delta
 
-        for key in cls.deprecated_adwin_keys:
+        for key in cls._deprecated_adwin_keys:
             if key in state["actions_manager"]:
                 state["actions_manager"].pop(key)
 
@@ -483,10 +443,10 @@ def from_old_state(
 
         state_dict = json.loads(state)
         if ("version" in state_dict) and (
-            version.parse(state_dict["version"]) >= version.parse(cls.current_supported_version_th)
+            version.parse(state_dict["version"]) >= version.parse(cls._current_supported_version_th)
         ):
             raise ValueError(
-                f"The state is expected to be in the old format of PyBandits < {cls.current_supported_version_th}."
+                f"The state is expected to be in the old format of PyBandits < {cls._current_supported_version_th}."
             )
         state_dict = cls.update_old_state(state_dict, delta)
         state = json.dumps(state_dict)
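A standalone sketch of the version threshold used by from_old_state above, with invented version strings; it assumes the version module used in this file is packaging.version, consistent with the version.parse calls in the hunk.

import json

from packaging import version

_current_supported_version_th = "3.0.0"  # mirrors the class-level threshold above

for raw in (json.dumps({"version": "2.2.0"}), json.dumps({"version": "3.1.0"})):
    state_dict = json.loads(raw)
    too_new = ("version" in state_dict) and (
        version.parse(state_dict["version"]) >= version.parse(_current_supported_version_th)
    )
    # States at or above the threshold are rejected by from_old_state with a ValueError.
    print(state_dict["version"], "rejected" if too_new else "accepted as old format")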