
Commit 1664e90

Unify quantitative actions interface
### Changes

* Moved the `_normalize_field` method to `PyBanditsBaseModel` for default-value normalization of optional fields.
* Unified the quantitative actions interface to use callables (the general form) rather than floats (specific to zooming).
* Improved tests for action selection and quantitative actions, ensuring proper handling of constraints and expected results.
1 parent 60c6b46 commit 1664e90

19 files changed (+3183, -726 lines)
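To orient the reader before the per-file diffs: the old interface represented a quantitative action's probability as a fixed tuple of (quantities, probability) samples, while the new one represents it as a callable that can be evaluated at any quantity. A minimal sketch (not part of the commit, with a made-up response curve):

```python
import numpy as np

# Old form: a finite tuple of ((quantities, ...), probability) samples
old_quantitative_probability = (((0.25,), 0.42), ((0.75,), 0.63))


# New form: a callable mapping any quantity in [0, 1] to a probability
def new_quantitative_probability(quantity: np.ndarray) -> float:
    # hypothetical smooth response, peaking around quantity ~= 0.7
    return float(0.6 - 0.5 * (quantity.mean() - 0.7) ** 2)


new_quantitative_probability(np.array([0.75]))  # evaluate at an arbitrary quantity
```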

docs/src/tutorials/cmab_zooming.ipynb

Lines changed: 12 additions & 15 deletions
```diff
@@ -28,7 +28,6 @@
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import pandas as pd\n",
-"from sklearn.preprocessing import StandardScaler\n",
 "\n",
 "from pybandits.cmab import CmabBernoulli\n",
 "from pybandits.quantitative_model import CmabZoomingModel"
@@ -202,11 +201,7 @@
 "n_batches = 10\n",
 "batch_size = 100\n",
 "n_rounds = n_batches * batch_size\n",
-"raw_context_data = np.random.normal(0, 1, (n_rounds, n_features))\n",
-"\n",
-"# Standardize the context data\n",
-"scaler = StandardScaler()\n",
-"context_data = scaler.fit_transform(raw_context_data)\n",
+"context_data = np.random.uniform(0, 1, (n_rounds, n_features))\n",
 "\n",
 "# Preview the context data\n",
 "pd.DataFrame(context_data[:5], columns=[f\"Feature {i + 1}\" for i in range(n_features)])"
@@ -313,24 +308,24 @@
 "outputs": [],
 "source": [
 "# Define test contexts\n",
-"test_contexts = [\n",
-" [2.0, -1.0, 0.0], # High feature 1, low feature 2\n",
-" [-1.0, 2.0, 0.0], # Low feature 1, high feature 2\n",
-" [1.0, 1.0, 0.0], # High feature 1 and 2\n",
-" [-1.0, -1.0, 0.0], # Low feature 1 and 2\n",
-"]\n",
-"test_contexts = scaler.transform(test_contexts)\n",
+"test_contexts = np.array(\n",
+" [\n",
+" [1.0, 0.0, 0.0], # High feature 1, low feature 2\n",
+" [0.0, 1.0, 0.0], # Low feature 1, high feature 2\n",
+" [1.0, 1.0, 0.0], # High feature 1 and 2\n",
+" [0.0, 0.0, 0.0], # Low feature 1 and 2\n",
+" ]\n",
+")\n",
 "\n",
 "# Test predictions\n",
 "results = []\n",
 "for i, context in enumerate(test_contexts):\n",
 " context_reshaped = context.reshape(1, -1)\n",
 " pred_actions, probs, weighted_sums = cmab.predict(context=context_reshaped)\n",
 " chosen_action_quantity = pred_actions[0]\n",
-" chosen_action_probs = {action: probs[0][chosen_action_quantity] for action in actions}\n",
 " chosen_action = chosen_action_quantity[0]\n",
 " chosen_quantities = chosen_action_quantity[1][0]\n",
-" chosen_action_probs = probs[0][chosen_action_quantity]\n",
+" chosen_action_probs = probs[0][chosen_action](chosen_quantities)\n",
 "\n",
 " # Sample optimal quantity for the chosen action\n",
 " # In a real application, you would have a method to test different quantities\n",
@@ -347,6 +342,7 @@
 " {\n",
 " \"Context\": context,\n",
 " \"Chosen Action\": chosen_action,\n",
+" \"Chosen Qunatity\": chosen_quantities,\n",
 " \"Action Probabilities\": chosen_action_probs,\n",
 " \"Optimal Quantity\": optimal_quantity,\n",
 " \"Expected Reward\": expected_reward,\n",
@@ -368,6 +364,7 @@
 " print(f\"\\nTest {i + 1}: {context_type}\")\n",
 " print(f\"Context: {result['Context']}\")\n",
 " print(f\"Chosen Action: {result['Chosen Action']}\")\n",
+" print(f\"Chosen Quantity: {result['Chosen Qunatity']}\")\n",
 " print(f\"Action Probabilities: {result['Action Probabilities']}\")\n",
 " print(f\"Optimal Quantity: {result['Optimal Quantity']:.2f}\")\n",
 " print(f\"Expected Reward: {result['Expected Reward']}\")"
```

pybandits/base.py

Lines changed: 47 additions & 9 deletions
```diff
@@ -20,8 +20,22 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from typing import Any, Dict, List, Mapping, NewType, Optional, Tuple, Union, _GenericAlias, get_args, get_origin
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    NewType,
+    Optional,
+    Tuple,
+    Union,
+    _GenericAlias,
+    get_args,
+    get_origin,
+)
 
+import numpy as np
 from typing_extensions import Self
 
 from pybandits.pydantic_version_compatibility import (
@@ -45,10 +59,12 @@
 MOProbability = List[Probability]
 MOProbabilityWeight = List[ProbabilityWeight]
 # QuantitativeProbability generalizes probability to include both action quantities and their associated probability
-QuantitativeProbability = Tuple[Tuple[Tuple[Float01, ...], Probability], ...]
-QuantitativeProbabilityWeight = Tuple[Tuple[Tuple[Float01, ...], ProbabilityWeight], ...]
-QuantitativeMOProbability = Tuple[Tuple[Tuple[Float01, ...], List[Probability]], ...]
-QuantitativeMOProbabilityWeight = Tuple[Tuple[Tuple[Float01, ...], List[ProbabilityWeight]], ...]
+QuantitativeProbability = Callable[[np.ndarray], Probability]
+QuantitativeWeight = Callable[[np.ndarray], float]
+QuantitativeProbabilityWeight = Tuple[QuantitativeProbability, QuantitativeWeight]
+QuantitativeMOProbability = Callable[[np.ndarray], MOProbability]
+QuantitativeMOProbabilityWeight = Tuple[Callable[[np.ndarray], MOProbability], Callable[[np.ndarray], float]]
+
 UnifiedProbability = Union[Probability, QuantitativeProbability]
 UnifiedProbabilityWeight = Union[ProbabilityWeight, QuantitativeProbabilityWeight]
 UnifiedMOProbability = Union[MOProbability, QuantitativeMOProbability]
@@ -79,10 +95,10 @@
 ActionRewardLikelihood = NewType(
     "ActionRewardLikelihood",
     Union[
-        Dict[UnifiedActionId, float],
-        Dict[UnifiedActionId, List[float]],
-        Dict[UnifiedActionId, Probability],
-        Dict[UnifiedActionId, List[Probability]],
+        Dict[ActionId, Union[float, Callable[[np.ndarray], float]]],
+        Dict[ActionId, Union[List[float], Callable[[np.ndarray], List[float]]]],
+        Dict[ActionId, Union[Probability, Callable[[np.ndarray], Probability]]],
+        Dict[ActionId, Union[List[Probability], Callable[[np.ndarray], List[Probability]]]],
     ],
 )
 ACTION_IDS_PREFIX = "action_ids_"
@@ -190,6 +206,28 @@ def _get_field_type(cls, key: str) -> Any:
             annotation = get_args(annotation)
         return annotation
 
+    @classmethod
+    def _normalize_field(cls, v: Any, field_name: str) -> Any:
+        """
+        Normalize a field value to its default if None.
+
+        This utility method ensures that optional fields receive their default
+        values when not explicitly provided.
+
+        Parameters
+        ----------
+        v : Any
+            The field value to normalize.
+        field_name : str
+            Name of the field in the model.
+
+        Returns
+        -------
+        Any
+            The original value if not None, otherwise the field's default value.
+        """
+        return v if v is not None else cls.model_fields[field_name].default
+
     if pydantic_version == PYDANTIC_VERSION_1:
 
         @classproperty
```
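A quick illustration of the behavior of the new `_normalize_field` helper, using a hypothetical subclass and field, and assuming pydantic v2 semantics where `model_fields[...].default` holds the declared default: `None` falls back to the default, anything else passes through unchanged.

```python
from typing import Optional

from pybandits.base import PyBanditsBaseModel


class ExampleModel(PyBanditsBaseModel):  # hypothetical model, for illustration only
    epsilon: Optional[float] = 0.05


ExampleModel._normalize_field(None, "epsilon")  # -> 0.05, the declared default
ExampleModel._normalize_field(0.2, "epsilon")   # -> 0.2, explicit values pass through
```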

pybandits/cmab_simulator.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -35,7 +35,7 @@
     ParametricActionProbability,
     Simulator,
 )
-from pybandits.utils import extract_argument_names_from_function
+from pybandits.utils import extract_argument_names_from_function, maximize_by_quantity
 
 CmabProbabilityValue = Union[ParametricActionProbability, DoubleParametricActionProbability]
 CmabActionProbabilityGroundTruth = Dict[ActionId, CmabProbabilityValue]
@@ -234,7 +234,7 @@ def _finalize_step(self, batch_results: pd.DataFrame, update_kwargs: Dict[str, n
         batch_results.loc[:, "selected_prob_reward"] = selected_prob_reward
         max_prob_reward = [
             max(
-                self._maximize_prob_reward((lambda q: self.probs_reward[g][a](c, q)), m.dimension)
+                maximize_by_quantity((lambda q: self.probs_reward[g][a](c, q)), m.dimension)
                 if isinstance(m, QuantitativeModel)
                 else self.probs_reward[g][a](c)
                 for a, m in self.mab.actions.items()
```
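From the call site above, `maximize_by_quantity` is handed a scalar function of the quantity vector together with the model's quantity dimension, and returns the best achievable value; it replaces the simulator-private `_maximize_prob_reward`. A naive grid-search stand-in (not the library's actual implementation) conveys the contract, assuming quantities live in [0, 1]:

```python
import itertools

import numpy as np


def grid_search_max(f, dimension: int, resolution: int = 21) -> float:
    """Naive stand-in for maximize_by_quantity: evaluate f on a grid over [0, 1]^dimension and take the max."""
    grid = np.linspace(0.0, 1.0, resolution)
    return max(f(np.array(point)) for point in itertools.product(grid, repeat=dimension))


# best reachable value of a toy response curve over a 1-D quantity
grid_search_max(lambda q: float(0.6 - 0.5 * (q[0] - 0.7) ** 2), dimension=1)
```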

pybandits/mab.py

Lines changed: 19 additions & 59 deletions
```diff
@@ -43,6 +43,10 @@
     Probability,
     ProbabilityWeight,
     PyBanditsBaseModel,
+    QuantitativeMOProbability,
+    QuantitativeMOProbabilityWeight,
+    QuantitativeProbability,
+    QuantitativeProbabilityWeight,
     Serializable,
     UnifiedActionId,
 )
@@ -52,7 +56,7 @@
     validate_call,
 )
 from pybandits.quantitative_model import QuantitativeModel
-from pybandits.strategy import Strategy
+from pybandits.strategy import BaseStrategy
 from pybandits.utils import extract_argument_names_from_function
 
 
@@ -79,12 +83,12 @@ class BaseMab(PyBanditsBaseModel, ABC):
     """
 
     actions_manager: ActionsManager
-    strategy: Strategy
+    strategy: BaseStrategy
     epsilon: Optional[Float01] = None
     default_action: Optional[UnifiedActionId] = None
     version: Optional[str] = None
-    deprecated_adwin_keys: ClassVar[List[str]] = ["adaptive_window_size", "actions_memory", "rewards_memory"]
-    current_supported_version_th: ClassVar[str] = "3.0.0"
+    _deprecated_adwin_keys: ClassVar[List[str]] = ["adaptive_window_size", "actions_memory", "rewards_memory"]
+    _current_supported_version_th: ClassVar[str] = "3.0.0"
 
     def __init__(
         self,
@@ -228,32 +232,13 @@ def update(
     def _transform_nested_list(lst: List[List[Dict]]):
         return [{k: v for d in single_action_dicts for k, v in d.items()} for single_action_dicts in zip(*lst)]
 
-    @staticmethod
-    def _is_so_standard_action(value: Any) -> bool:
-        # Probability ProbabilityWeight
-        return isinstance(value, float) or (isinstance(value, tuple) and isinstance(value[0], float))
-
-    @staticmethod
-    def _is_so_quantitative_action(value: Any) -> bool:
-        return isinstance(value, tuple) and isinstance(value[0], tuple)
-
-    @classmethod
-    def _is_standard_action(cls, value: Any) -> bool:
-        return cls._is_so_standard_action(value) or (isinstance(value, list) and cls._is_so_standard_action(value[0]))
-
-    @classmethod
-    def _is_quantitative_action(cls, value: Any) -> bool:
-        return cls._is_so_quantitative_action(value) or (
-            isinstance(value, list) and cls._is_so_quantitative_action(value[0])
-        )
-
     def _get_action_probabilities(
         self, forbidden_actions: Optional[Set[ActionId]] = None, **kwargs
     ) -> Union[
-        List[Dict[UnifiedActionId, Probability]],
-        List[Dict[UnifiedActionId, ProbabilityWeight]],
-        List[Dict[UnifiedActionId, MOProbability]],
-        List[Dict[UnifiedActionId, MOProbabilityWeight]],
+        List[Dict[ActionId, Union[Probability, QuantitativeProbability]]],
+        List[Dict[ActionId, Union[ProbabilityWeight, QuantitativeProbabilityWeight]]],
+        List[Dict[ActionId, Union[MOProbability, QuantitativeMOProbability]]],
+        List[Dict[ActionId, Union[MOProbabilityWeight, QuantitativeMOProbabilityWeight]]],
     ]:
         """
         Get the probability of getting a positive reward for each action.
@@ -276,34 +261,9 @@ def _get_action_probabilities(
             action: model.sample_proba(**kwargs) for action, model in self.actions.items() if action in valid_actions
         }
         # Handle standard actions for which the value is a (probability, weight) tuple
-        actions_transformations = [
-            [{key: proba} for proba in value]
-            for key, value in action_probabilities.items()
-            if self._is_standard_action(value[0])
-        ]
-        actions_transformations = self._transform_nested_list(actions_transformations)
-        # Handle quantitative actions, for which the value is a tuple of
-        # tuples of (quantity, (probability, weight) or probability)
-        quantitative_actions_transformations = [
-            [{(key, quantity): proba for quantity, proba in sample} for sample in value]
-            for key, value in action_probabilities.items()
-            if self._is_quantitative_action(value[0])
-        ]
-        quantitative_actions_transformations = self._transform_nested_list(quantitative_actions_transformations)
-        if not actions_transformations and not quantitative_actions_transformations:
-            return []
-        if not actions_transformations:  # No standard actions
-            actions_transformations = [dict() for _ in range(len(quantitative_actions_transformations))]
-        if not quantitative_actions_transformations:  # No quantitative actions
-            quantitative_actions_transformations = [dict() for _ in range(len(actions_transformations))]
-        if len(actions_transformations) != len(quantitative_actions_transformations):
-            raise ValueError("The number of standard and quantitative actions should be the same.")
-        action_probabilities = [
-            {**actions_dict, **quantitative_actions_dict}
-            for actions_dict, quantitative_actions_dict in zip(
-                actions_transformations, quantitative_actions_transformations
-            )
-        ]
+        actions_transformations = [[{key: proba} for proba in value] for key, value in action_probabilities.items()]
+        action_probabilities = self._transform_nested_list(actions_transformations)
+
         return action_probabilities
 
     @abstractmethod
@@ -386,7 +346,7 @@ def _select_epsilon_greedy_action(
                 if self.default_action:
                     selected_action = self.default_action
                 else:
-                    actions = list(set(a[0] if isinstance(a, tuple) else a for a in p.keys()))
+                    actions = list(p.keys())
                     selected_action = random.choice(actions)
                 if isinstance(self.actions[selected_action], QuantitativeModel):
                     selected_action = (
@@ -450,7 +410,7 @@ def update_old_state(
         state["actions_manager"]["actions"] = state.pop("actions")
         state["actions_manager"]["delta"] = delta
 
-        for key in cls.deprecated_adwin_keys:
+        for key in cls._deprecated_adwin_keys:
             if key in state["actions_manager"]:
                 state["actions_manager"].pop(key)
 
@@ -483,10 +443,10 @@ def from_old_state(
 
         state_dict = json.loads(state)
        if ("version" in state_dict) and (
-            version.parse(state_dict["version"]) >= version.parse(cls.current_supported_version_th)
+            version.parse(state_dict["version"]) >= version.parse(cls._current_supported_version_th)
         ):
             raise ValueError(
-                f"The state is expected to be in the old format of PyBandits < {cls.current_supported_version_th}."
+                f"The state is expected to be in the old format of PyBandits < {cls._current_supported_version_th}."
             )
         state_dict = cls.update_old_state(state_dict, delta)
         state = json.dumps(state_dict)
```