4343 Probability ,
4444 ProbabilityWeight ,
4545 PyBanditsBaseModel ,
46+ QuantitativeMOProbability ,
47+ QuantitativeMOProbabilityWeight ,
48+ QuantitativeProbability ,
49+ QuantitativeProbabilityWeight ,
4650 Serializable ,
4751 UnifiedActionId ,
4852)
5256 validate_call ,
5357)
5458from pybandits .quantitative_model import QuantitativeModel
55- from pybandits .strategy import Strategy
59+ from pybandits .strategy import BaseStrategy
5660from pybandits .utils import extract_argument_names_from_function
5761
5862
@@ -79,12 +83,12 @@ class BaseMab(PyBanditsBaseModel, ABC):
7983 """
8084
8185 actions_manager : ActionsManager
82- strategy : Strategy
86+ strategy : BaseStrategy
8387 epsilon : Optional [Float01 ] = None
8488 default_action : Optional [UnifiedActionId ] = None
8589 version : Optional [str ] = None
86- deprecated_adwin_keys : ClassVar [List [str ]] = ["adaptive_window_size" , "actions_memory" , "rewards_memory" ]
87- current_supported_version_th : ClassVar [str ] = "3.0.0"
90+ _deprecated_adwin_keys : ClassVar [List [str ]] = ["adaptive_window_size" , "actions_memory" , "rewards_memory" ]
91+ _current_supported_version_th : ClassVar [str ] = "3.0.0"
8892
8993 def __init__ (
9094 self ,
@@ -232,32 +236,13 @@ def update(
232236 def _transform_nested_list (lst : List [List [Dict ]]):
233237 return [{k : v for d in single_action_dicts for k , v in d .items ()} for single_action_dicts in zip (* lst )]
234238
235- @staticmethod
236- def _is_so_standard_action (value : Any ) -> bool :
237- # Probability ProbabilityWeight
238- return isinstance (value , float ) or (isinstance (value , tuple ) and isinstance (value [0 ], float ))
239-
240- @staticmethod
241- def _is_so_quantitative_action (value : Any ) -> bool :
242- return isinstance (value , tuple ) and isinstance (value [0 ], tuple )
243-
244- @classmethod
245- def _is_standard_action (cls , value : Any ) -> bool :
246- return cls ._is_so_standard_action (value ) or (isinstance (value , list ) and cls ._is_so_standard_action (value [0 ]))
247-
248- @classmethod
249- def _is_quantitative_action (cls , value : Any ) -> bool :
250- return cls ._is_so_quantitative_action (value ) or (
251- isinstance (value , list ) and cls ._is_so_quantitative_action (value [0 ])
252- )
253-
254239 def _get_action_probabilities (
255240 self , forbidden_actions : Optional [Set [ActionId ]] = None , ** kwargs
256241 ) -> Union [
257- List [Dict [UnifiedActionId , Probability ]],
258- List [Dict [UnifiedActionId , ProbabilityWeight ]],
259- List [Dict [UnifiedActionId , MOProbability ]],
260- List [Dict [UnifiedActionId , MOProbabilityWeight ]],
242+ List [Dict [ActionId , Union [ Probability , QuantitativeProbability ] ]],
243+ List [Dict [ActionId , Union [ ProbabilityWeight , QuantitativeProbabilityWeight ] ]],
244+ List [Dict [ActionId , Union [ MOProbability , QuantitativeMOProbability ] ]],
245+ List [Dict [ActionId , Union [ MOProbabilityWeight , QuantitativeMOProbabilityWeight ] ]],
261246 ]:
262247 """
263248 Get the probability of getting a positive reward for each action.
@@ -280,34 +265,9 @@ def _get_action_probabilities(
280265 action : model .sample_proba (** kwargs ) for action , model in self .actions .items () if action in valid_actions
281266 }
282267 # Wrap every sampled value of each action in a single-entry {action: value} dict, then merge the dicts per sample
283- actions_transformations = [
284- [{key : proba } for proba in value ]
285- for key , value in action_probabilities .items ()
286- if self ._is_standard_action (value [0 ])
287- ]
288- actions_transformations = self ._transform_nested_list (actions_transformations )
289- # Handle quantitative actions, for which the value is a tuple of
290- # tuples of (quantity, (probability, weight) or probability)
291- quantitative_actions_transformations = [
292- [{(key , quantity ): proba for quantity , proba in sample } for sample in value ]
293- for key , value in action_probabilities .items ()
294- if self ._is_quantitative_action (value [0 ])
295- ]
296- quantitative_actions_transformations = self ._transform_nested_list (quantitative_actions_transformations )
297- if not actions_transformations and not quantitative_actions_transformations :
298- return []
299- if not actions_transformations : # No standard actions
300- actions_transformations = [dict () for _ in range (len (quantitative_actions_transformations ))]
301- if not quantitative_actions_transformations : # No quantitative actions
302- quantitative_actions_transformations = [dict () for _ in range (len (actions_transformations ))]
303- if len (actions_transformations ) != len (quantitative_actions_transformations ):
304- raise ValueError ("The number of standard and quantitative actions should be the same." )
305- action_probabilities = [
306- {** actions_dict , ** quantitative_actions_dict }
307- for actions_dict , quantitative_actions_dict in zip (
308- actions_transformations , quantitative_actions_transformations
309- )
310- ]
268+ actions_transformations = [[{key : proba } for proba in value ] for key , value in action_probabilities .items ()]
269+ action_probabilities = self ._transform_nested_list (actions_transformations )
270+
311271 return action_probabilities
312272
313273 @abstractmethod
@@ -399,7 +359,7 @@ def _select_epsilon_greedy_action(
399359 if self .default_action :
400360 selected_action = self .default_action
401361 else :
402- actions = list (set ( a [ 0 ] if isinstance ( a , tuple ) else a for a in p .keys () ))
362+ actions = list (p .keys ())
403363 selected_action = random .choice (actions )
404364 if isinstance (self .actions [selected_action ], QuantitativeModel ):
405365 selected_action = (
@@ -463,7 +423,7 @@ def update_old_state(
463423 state ["actions_manager" ]["actions" ] = state .pop ("actions" )
464424 state ["actions_manager" ]["delta" ] = delta
465425
466- for key in cls .deprecated_adwin_keys :
426+ for key in cls ._deprecated_adwin_keys :
467427 if key in state ["actions_manager" ]:
468428 state ["actions_manager" ].pop (key )
469429
@@ -496,10 +456,10 @@ def from_old_state(
496456
497457 state_dict = json .loads (state )
498458 if ("version" in state_dict ) and (
499- version .parse (state_dict ["version" ]) >= version .parse (cls .current_supported_version_th )
459+ version .parse (state_dict ["version" ]) >= version .parse (cls ._current_supported_version_th )
500460 ):
501461 raise ValueError (
502- f"The state is expected to be in the old format of PyBandits < { cls .current_supported_version_th } ."
462+ f"The state is expected to be in the old format of PyBandits < { cls ._current_supported_version_th } ."
503463 )
504464 state_dict = cls .update_old_state (state_dict , delta )
505465 state = json .dumps (state_dict )
0 commit comments