Commit 2c00401

Unify quantitative actions interface
### Changes:

* Moved `_normalize_field` method to `PyBanditsBaseModel` for default value normalization of optional fields.
* Unified quantitative actions interface to use callables for general form, rather than floats for zooming.
* Improved tests for action selection and quantitative actions, ensuring proper handling of constraints and expected results.
1 parent d295ff4 commit 2c00401
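
To illustrate the interface change described above, here is a minimal sketch based on the new type aliases introduced in `pybandits/base.py`. The toy probability curve and the variable names below are invented for the example; only the callable-over-quantity shape comes from the commit.

```python
from typing import Callable, Dict, Union

import numpy as np

# Under the new interface, a quantitative action's reward probability is a
# callable over the quantity vector (values in [0, 1]) instead of a flat
# grid of (quantity, probability) tuples.
QuantitativeProbability = Callable[[np.ndarray], float]


def toy_probability(quantity: np.ndarray) -> float:
    """Invented reward-probability curve peaking near quantity 0.7."""
    return float(np.exp(-((quantity[0] - 0.7) ** 2) / 0.1))


action_probs: Dict[str, Union[float, QuantitativeProbability]] = {
    "standard_action": 0.5,  # plain probability, unchanged
    "quantitative_action": toy_probability,  # callable of the quantity
}

# Evaluating the quantitative action at a candidate quantity
value = action_probs["quantitative_action"]
if callable(value):
    print(value(np.array([0.3])))  # probability of reward at quantity 0.3
```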

20 files changed (+3279 / -764 lines)

docs/src/tutorials/cmab_zooming.ipynb

Lines changed: 12 additions & 15 deletions
```diff
@@ -28,7 +28,6 @@
 "import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import pandas as pd\n",
-"from sklearn.preprocessing import StandardScaler\n",
 "\n",
 "from pybandits.cmab import CmabBernoulli\n",
 "from pybandits.quantitative_model import CmabZoomingModel"
@@ -202,11 +201,7 @@
 "n_batches = 10\n",
 "batch_size = 100\n",
 "n_rounds = n_batches * batch_size\n",
-"raw_context_data = np.random.normal(0, 1, (n_rounds, n_features))\n",
-"\n",
-"# Standardize the context data\n",
-"scaler = StandardScaler()\n",
-"context_data = scaler.fit_transform(raw_context_data)\n",
+"context_data = np.random.uniform(0, 1, (n_rounds, n_features))\n",
 "\n",
 "# Preview the context data\n",
 "pd.DataFrame(context_data[:5], columns=[f\"Feature {i + 1}\" for i in range(n_features)])"
@@ -313,24 +308,24 @@
 "outputs": [],
 "source": [
 "# Define test contexts\n",
-"test_contexts = [\n",
-"    [2.0, -1.0, 0.0],  # High feature 1, low feature 2\n",
-"    [-1.0, 2.0, 0.0],  # Low feature 1, high feature 2\n",
-"    [1.0, 1.0, 0.0],  # High feature 1 and 2\n",
-"    [-1.0, -1.0, 0.0],  # Low feature 1 and 2\n",
-"]\n",
-"test_contexts = scaler.transform(test_contexts)\n",
+"test_contexts = np.array(\n",
+"    [\n",
+"        [1.0, 0.0, 0.0],  # High feature 1, low feature 2\n",
+"        [0.0, 1.0, 0.0],  # Low feature 1, high feature 2\n",
+"        [1.0, 1.0, 0.0],  # High feature 1 and 2\n",
+"        [0.0, 0.0, 0.0],  # Low feature 1 and 2\n",
+"    ]\n",
+")\n",
 "\n",
 "# Test predictions\n",
 "results = []\n",
 "for i, context in enumerate(test_contexts):\n",
 "    context_reshaped = context.reshape(1, -1)\n",
 "    pred_actions, probs, weighted_sums = cmab.predict(context=context_reshaped)\n",
 "    chosen_action_quantity = pred_actions[0]\n",
-"    chosen_action_probs = {action: probs[0][chosen_action_quantity] for action in actions}\n",
 "    chosen_action = chosen_action_quantity[0]\n",
 "    chosen_quantities = chosen_action_quantity[1][0]\n",
-"    chosen_action_probs = probs[0][chosen_action_quantity]\n",
+"    chosen_action_probs = probs[0][chosen_action](chosen_quantities)\n",
 "\n",
 "    # Sample optimal quantity for the chosen action\n",
 "    # In a real application, you would have a method to test different quantities\n",
@@ -347,6 +342,7 @@
 "        {\n",
 "            \"Context\": context,\n",
 "            \"Chosen Action\": chosen_action,\n",
+"            \"Chosen Qunatity\": chosen_quantities,\n",
 "            \"Action Probabilities\": chosen_action_probs,\n",
 "            \"Optimal Quantity\": optimal_quantity,\n",
 "            \"Expected Reward\": expected_reward,\n",
@@ -368,6 +364,7 @@
 "    print(f\"\\nTest {i + 1}: {context_type}\")\n",
 "    print(f\"Context: {result['Context']}\")\n",
 "    print(f\"Chosen Action: {result['Chosen Action']}\")\n",
+"    print(f\"Chosen Quantity: {result['Chosen Qunatity']}\")\n",
 "    print(f\"Action Probabilities: {result['Action Probabilities']}\")\n",
 "    print(f\"Optimal Quantity: {result['Optimal Quantity']:.2f}\")\n",
 "    print(f\"Expected Reward: {result['Expected Reward']}\")"
```

pybandits/base.py

Lines changed: 47 additions & 9 deletions
```diff
@@ -20,8 +20,22 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from typing import Any, Dict, List, Mapping, NewType, Optional, Tuple, Union, _GenericAlias, get_args, get_origin
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    NewType,
+    Optional,
+    Tuple,
+    Union,
+    _GenericAlias,
+    get_args,
+    get_origin,
+)
 
+import numpy as np
 from typing_extensions import Self
 
 from pybandits.pydantic_version_compatibility import (
@@ -45,10 +59,12 @@
 MOProbability = List[Probability]
 MOProbabilityWeight = List[ProbabilityWeight]
 # QuantitativeProbability generalizes probability to include both action quantities and their associated probability
-QuantitativeProbability = Tuple[Tuple[Tuple[Float01, ...], Probability], ...]
-QuantitativeProbabilityWeight = Tuple[Tuple[Tuple[Float01, ...], ProbabilityWeight], ...]
-QuantitativeMOProbability = Tuple[Tuple[Tuple[Float01, ...], List[Probability]], ...]
-QuantitativeMOProbabilityWeight = Tuple[Tuple[Tuple[Float01, ...], List[ProbabilityWeight]], ...]
+QuantitativeProbability = Callable[[np.ndarray], Probability]
+QuantitativeWeight = Callable[[np.ndarray], float]
+QuantitativeProbabilityWeight = Tuple[QuantitativeProbability, QuantitativeWeight]
+QuantitativeMOProbability = Callable[[np.ndarray], MOProbability]
+QuantitativeMOProbabilityWeight = Tuple[Callable[[np.ndarray], MOProbability], Callable[[np.ndarray], float]]
+
 UnifiedProbability = Union[Probability, QuantitativeProbability]
 UnifiedProbabilityWeight = Union[ProbabilityWeight, QuantitativeProbabilityWeight]
 UnifiedMOProbability = Union[MOProbability, QuantitativeMOProbability]
@@ -79,10 +95,10 @@
 ActionRewardLikelihood = NewType(
     "ActionRewardLikelihood",
     Union[
-        Dict[UnifiedActionId, float],
-        Dict[UnifiedActionId, List[float]],
-        Dict[UnifiedActionId, Probability],
-        Dict[UnifiedActionId, List[Probability]],
+        Dict[ActionId, Union[float, Callable[[np.ndarray], float]]],
+        Dict[ActionId, Union[List[float], Callable[[np.ndarray], List[float]]]],
+        Dict[ActionId, Union[Probability, Callable[[np.ndarray], Probability]]],
+        Dict[ActionId, Union[List[Probability], Callable[[np.ndarray], List[Probability]]]],
     ],
 )
 ACTION_IDS_PREFIX = "action_ids_"
@@ -190,6 +206,28 @@ def _get_field_type(cls, key: str) -> Any:
         annotation = get_args(annotation)
         return annotation
 
+    @classmethod
+    def _normalize_field(cls, v: Any, field_name: str) -> Any:
+        """
+        Normalize a field value to its default if None.
+
+        This utility method ensures that optional fields receive their default
+        values when not explicitly provided.
+
+        Parameters
+        ----------
+        v : Any
+            The field value to normalize.
+        field_name : str
+            Name of the field in the model.
+
+        Returns
+        -------
+        Any
+            The original value if not None, otherwise the field's default value.
+        """
+        return v if v is not None else cls.model_fields[field_name].default
+
     if pydantic_version == PYDANTIC_VERSION_1:
 
         @classproperty
```
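
For context, here is a sketch of how a relocated helper like `_normalize_field` can back a `mode="before"` validator so that an explicit `None` falls back to the declared default. This toy model assumes Pydantic v2 and is not code from this commit.

```python
from typing import Optional

from pydantic import BaseModel, field_validator


class ExampleModel(BaseModel):
    """Toy model showing the default-normalization pattern; not PyBandits code."""

    delta: Optional[float] = 0.1

    @field_validator("delta", mode="before")
    @classmethod
    def _normalize_delta(cls, v):
        # Same idea as PyBanditsBaseModel._normalize_field: fall back to the
        # declared default when the caller explicitly passes None.
        return v if v is not None else cls.model_fields["delta"].default


print(ExampleModel(delta=None).delta)  # 0.1 instead of None
```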

pybandits/cmab_simulator.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -35,7 +35,11 @@
     ParametricActionProbability,
     Simulator,
 )
-from pybandits.utils import extract_argument_names_from_function
+from pybandits.utils import (
+    OptimizationFailedError,
+    extract_argument_names_from_function,
+    maximize_by_quantity,
+)
 
 CmabProbabilityValue = Union[ParametricActionProbability, DoubleParametricActionProbability]
 CmabActionProbabilityGroundTruth = Dict[ActionId, CmabProbabilityValue]
@@ -232,13 +236,20 @@ def _finalize_step(self, batch_results: pd.DataFrame, update_kwargs: Dict[str, n
             for a, q, g, c in zip(action_id, quantity, group_id, update_kwargs["context"])
         ]
         batch_results.loc[:, "selected_prob_reward"] = selected_prob_reward
+
+        def get_max_prob_for_action(g: str, a: ActionId, c: np.ndarray, m) -> float:
+            """Get maximum probability for an action, handling optimization failures."""
+            if isinstance(m, QuantitativeModel):
+                try:
+                    opt_q = maximize_by_quantity((lambda q: self.probs_reward[g][a](c, q)), m.dimension)
+                    return self.probs_reward[g][a](c, opt_q)
+                except OptimizationFailedError as e:
+                    raise ValueError(f"Optimization failed for action {a}: {e}")
+            else:
+                return self.probs_reward[g][a](c)
+
         max_prob_reward = [
-            max(
-                self._maximize_prob_reward((lambda q: self.probs_reward[g][a](c, q)), m.dimension)
-                if isinstance(m, QuantitativeModel)
-                else self.probs_reward[g][a](c)
-                for a, m in self.mab.actions.items()
-            )
+            max(get_max_prob_for_action(g, a, c, m) for a, m in self.mab.actions.items())
             for g, c in zip(group_id, update_kwargs["context"])
         ]
         batch_results.loc[:, "max_prob_reward"] = max_prob_reward
```

pybandits/mab.py

Lines changed: 19 additions & 59 deletions
```diff
@@ -43,6 +43,10 @@
     Probability,
     ProbabilityWeight,
     PyBanditsBaseModel,
+    QuantitativeMOProbability,
+    QuantitativeMOProbabilityWeight,
+    QuantitativeProbability,
+    QuantitativeProbabilityWeight,
     Serializable,
     UnifiedActionId,
 )
@@ -52,7 +56,7 @@
     validate_call,
 )
 from pybandits.quantitative_model import QuantitativeModel
-from pybandits.strategy import Strategy
+from pybandits.strategy import BaseStrategy
 from pybandits.utils import extract_argument_names_from_function
 
 
@@ -79,12 +83,12 @@ class BaseMab(PyBanditsBaseModel, ABC):
     """
 
     actions_manager: ActionsManager
-    strategy: Strategy
+    strategy: BaseStrategy
     epsilon: Optional[Float01] = None
     default_action: Optional[UnifiedActionId] = None
     version: Optional[str] = None
-    deprecated_adwin_keys: ClassVar[List[str]] = ["adaptive_window_size", "actions_memory", "rewards_memory"]
-    current_supported_version_th: ClassVar[str] = "3.0.0"
+    _deprecated_adwin_keys: ClassVar[List[str]] = ["adaptive_window_size", "actions_memory", "rewards_memory"]
+    _current_supported_version_th: ClassVar[str] = "3.0.0"
 
     def __init__(
         self,
@@ -232,32 +236,13 @@ def update(
     def _transform_nested_list(lst: List[List[Dict]]):
         return [{k: v for d in single_action_dicts for k, v in d.items()} for single_action_dicts in zip(*lst)]
 
-    @staticmethod
-    def _is_so_standard_action(value: Any) -> bool:
-        # Probability ProbabilityWeight
-        return isinstance(value, float) or (isinstance(value, tuple) and isinstance(value[0], float))
-
-    @staticmethod
-    def _is_so_quantitative_action(value: Any) -> bool:
-        return isinstance(value, tuple) and isinstance(value[0], tuple)
-
-    @classmethod
-    def _is_standard_action(cls, value: Any) -> bool:
-        return cls._is_so_standard_action(value) or (isinstance(value, list) and cls._is_so_standard_action(value[0]))
-
-    @classmethod
-    def _is_quantitative_action(cls, value: Any) -> bool:
-        return cls._is_so_quantitative_action(value) or (
-            isinstance(value, list) and cls._is_so_quantitative_action(value[0])
-        )
-
     def _get_action_probabilities(
         self, forbidden_actions: Optional[Set[ActionId]] = None, **kwargs
     ) -> Union[
-        List[Dict[UnifiedActionId, Probability]],
-        List[Dict[UnifiedActionId, ProbabilityWeight]],
-        List[Dict[UnifiedActionId, MOProbability]],
-        List[Dict[UnifiedActionId, MOProbabilityWeight]],
+        List[Dict[ActionId, Union[Probability, QuantitativeProbability]]],
+        List[Dict[ActionId, Union[ProbabilityWeight, QuantitativeProbabilityWeight]]],
+        List[Dict[ActionId, Union[MOProbability, QuantitativeMOProbability]]],
+        List[Dict[ActionId, Union[MOProbabilityWeight, QuantitativeMOProbabilityWeight]]],
     ]:
         """
         Get the probability of getting a positive reward for each action.
@@ -280,34 +265,9 @@
             action: model.sample_proba(**kwargs) for action, model in self.actions.items() if action in valid_actions
         }
         # Handle standard actions for which the value is a (probability, weight) tuple
-        actions_transformations = [
-            [{key: proba} for proba in value]
-            for key, value in action_probabilities.items()
-            if self._is_standard_action(value[0])
-        ]
-        actions_transformations = self._transform_nested_list(actions_transformations)
-        # Handle quantitative actions, for which the value is a tuple of
-        # tuples of (quantity, (probability, weight) or probability)
-        quantitative_actions_transformations = [
-            [{(key, quantity): proba for quantity, proba in sample} for sample in value]
-            for key, value in action_probabilities.items()
-            if self._is_quantitative_action(value[0])
-        ]
-        quantitative_actions_transformations = self._transform_nested_list(quantitative_actions_transformations)
-        if not actions_transformations and not quantitative_actions_transformations:
-            return []
-        if not actions_transformations:  # No standard actions
-            actions_transformations = [dict() for _ in range(len(quantitative_actions_transformations))]
-        if not quantitative_actions_transformations:  # No quantitative actions
-            quantitative_actions_transformations = [dict() for _ in range(len(actions_transformations))]
-        if len(actions_transformations) != len(quantitative_actions_transformations):
-            raise ValueError("The number of standard and quantitative actions should be the same.")
-        action_probabilities = [
-            {**actions_dict, **quantitative_actions_dict}
-            for actions_dict, quantitative_actions_dict in zip(
-                actions_transformations, quantitative_actions_transformations
-            )
-        ]
+        actions_transformations = [[{key: proba} for proba in value] for key, value in action_probabilities.items()]
+        action_probabilities = self._transform_nested_list(actions_transformations)
+
         return action_probabilities
 
     @abstractmethod
@@ -399,7 +359,7 @@ def _select_epsilon_greedy_action(
         if self.default_action:
             selected_action = self.default_action
         else:
-            actions = list(set(a[0] if isinstance(a, tuple) else a for a in p.keys()))
+            actions = list(p.keys())
             selected_action = random.choice(actions)
         if isinstance(self.actions[selected_action], QuantitativeModel):
             selected_action = (
@@ -463,7 +423,7 @@ def update_old_state(
         state["actions_manager"]["actions"] = state.pop("actions")
         state["actions_manager"]["delta"] = delta
 
-        for key in cls.deprecated_adwin_keys:
+        for key in cls._deprecated_adwin_keys:
             if key in state["actions_manager"]:
                 state["actions_manager"].pop(key)
 
@@ -496,10 +456,10 @@ def from_old_state(
 
         state_dict = json.loads(state)
        if ("version" in state_dict) and (
-            version.parse(state_dict["version"]) >= version.parse(cls.current_supported_version_th)
+            version.parse(state_dict["version"]) >= version.parse(cls._current_supported_version_th)
         ):
             raise ValueError(
-                f"The state is expected to be in the old format of PyBandits < {cls.current_supported_version_th}."
+                f"The state is expected to be in the old format of PyBandits < {cls._current_supported_version_th}."
             )
         state_dict = cls.update_old_state(state_dict, delta)
         state = json.dumps(state_dict)
```
pybandits/offline_policy_evaluator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,9 +1023,13 @@ def estimate_policy(
10231023
# finalize the dataframe shape to #samples X #mc experiments
10241024
mc_actions = pd.DataFrame(mc_actions).T
10251025

1026+
# Get unique actions that actually appear in the test set (to match validation requirements)
1027+
# The action array contains encoded indices, so we need to map them back to action IDs
1028+
unique_actions_in_test = sorted(set(self._test_data["action_ids"]))
1029+
10261030
# for each sample / each action, count the occurrence frequency during MC iteration
1027-
mc_action_counts = pd.DataFrame(0, index=mc_actions.index, columns=self._test_data["unique_actions"])
1028-
for action in self._test_data["unique_actions"]:
1031+
mc_action_counts = pd.DataFrame(0, index=mc_actions.index, columns=unique_actions_in_test)
1032+
for action in unique_actions_in_test:
10291033
mc_action_counts[action] = (mc_actions == action).sum(axis=1)
10301034
estimated_policy = mc_action_counts / n_mc_experiments
10311035

@@ -1110,6 +1114,7 @@ def evaluate(
11101114
axis=0,
11111115
)
11121116
if save_path:
1117+
os.makedirs(save_path, exist_ok=True)
11131118
multi_objective_estimated_policy_value_df.to_csv(os.path.join(save_path, "estimated_policy_value.csv"))
11141119

11151120
if visualize:
