Commit 4624372

AnastasiiaKabeshova and Anastasiia authored
Feature/bnn enhancements (#113)
Add configurable activation and residuals to BNN model

* Add configurable activation function
* Add residual connection when the layer dimension stays the same or expands
* Make activation functions module-level for pickling compatibility
* Improve MAB tests

Co-authored-by: Anastasiia <anastasiiak@playtika.com>
1 parent 60c6b46 commit 4624372

7 files changed, with 307 additions and 45 deletions.
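As a rough usage sketch of the new options (the import path and argument values here are assumptions, not part of this commit):

from pybandits.model import BayesianNeuralNetwork

# Cold-start a BNN that uses GELU hidden activations and residual
# connections wherever a layer keeps or expands its dimension.
bnn = BayesianNeuralNetwork.cold_start(
    n_features=5,
    hidden_dim_list=[8, 8],
    activation="gelu",
    use_residual_connections=True,
)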

pybandits/mab.py

Lines changed: 19 additions & 6 deletions
@@ -135,8 +135,10 @@ def model_post_init(self, __context: Any) -> None:
             raise ValueError("Adaptive window requires epsilon greedy super strategy with not default action.")
         if not self.epsilon and self.default_action:
             raise AttributeError("A default action should only be defined when epsilon is defined.")
-        if self.default_action and self.default_action not in self.actions:
-            raise AttributeError("The default action must be valid action defined in the actions set.")
+        if self.default_action:
+            action_id = self.default_action[0] if isinstance(self.default_action, tuple) else self.default_action
+            if action_id not in self.actions:
+                raise AttributeError("The default action must be valid action defined in the actions set.")
         if (
             self.default_action
             and isinstance(self.default_action, tuple)
@@ -174,8 +176,10 @@ def _get_valid_actions(self, forbidden_actions: Optional[Set[ActionId]]) -> Set[
         valid_actions = action_ids - forbidden_actions
         if len(valid_actions) == 0:
             raise ValueError("All actions are forbidden. You must allow at least 1 action.")
-        if self.default_action and self.default_action not in valid_actions:
-            raise ValueError("The default action is forbidden.")
+        if self.default_action:
+            action_id = self.default_action[0] if isinstance(self.default_action, tuple) else self.default_action
+            if action_id not in valid_actions:
+                raise ValueError("The default action is forbidden.")

         return valid_actions

@@ -380,8 +384,17 @@ def _select_epsilon_greedy_action(
         """

         if self.epsilon:
-            if self.default_action and self.default_action not in p.keys():
-                raise KeyError(f"Default action {self.default_action} not in actions.")
+            if self.default_action:
+                if isinstance(self.default_action, tuple):
+                    # For quantitative models, check if any key has the same action_id
+                    default_action_id = self.default_action[0]
+                    if not any(
+                        (isinstance(key, tuple) and key[0] == default_action_id) or key == default_action_id
+                        for key in p.keys()
+                    ):
+                        raise KeyError(f"Default action {self.default_action} not in actions.")
+                elif self.default_action not in p.keys():
+                    raise KeyError(f"Default action {self.default_action} not in actions.")
             if np.random.binomial(1, self.epsilon):
                 if self.default_action:
                     selected_action = self.default_action
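For context, p maps plain action ids or (action_id, quantity) tuples to probabilities. A minimal stand-alone sketch of the new matching rule, with illustrative values only:

p = {("action1", (0.2, 0.7)): 0.6, "action2": 0.4}
default_action = ("action1", (0.5, 0.5))

# The exact tuple need not be a key of p; it is enough that some key
# carries the same action_id, whether as a tuple or as a bare id.
default_action_id = default_action[0]
assert any(
    (isinstance(key, tuple) and key[0] == default_action_id) or key == default_action_id
    for key in p.keys()
)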

pybandits/model.py

Lines changed: 125 additions & 10 deletions
@@ -32,6 +32,7 @@
 from pymc import Bernoulli, Data, Deterministic, Minibatch, fit, math, sample
 from pymc import Model as PymcModel
 from pymc import StudentT as PymcStudentT
+from scipy.special import erf
 from scipy.stats import t
 from typing_extensions import Self

@@ -53,6 +54,33 @@
 )

 UpdateMethods = Literal["VI", "MCMC"]
+ActivationFunctions = Literal["tanh", "relu", "sigmoid", "gelu"]
+
+
+# Module-level activation functions for pickling compatibility
+def _pymc_relu(x):
+    """ReLU activation function for PyMC."""
+    return math.maximum(0, x)
+
+
+def _pymc_gelu(x):
+    """GELU activation function for PyMC."""
+    return 0.5 * x * (1 + math.erf(x / np.sqrt(2.0)))
+
+
+def _numpy_relu(x: np.ndarray) -> np.ndarray:
+    """ReLU activation function for NumPy."""
+    return np.maximum(0, x)
+
+
+def _numpy_gelu(x: np.ndarray) -> np.ndarray:
+    """GELU activation function for NumPy."""
+    return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
+
+
+def _stable_sigmoid(x):
+    """Stable sigmoid activation function for NumPy."""
+    return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))


 class Model(BaseModelSO, ABC):
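The "module-level for pickling compatibility" point in the commit message comes from pickle serializing functions by qualified name, so only importable module-level functions survive a round trip. A stand-alone illustration (names here are hypothetical, not from this diff):

import pickle

def module_level_relu(x):
    """Importable by its qualified name, so it pickles."""
    return max(0, x)

pickle.dumps(module_level_relu)  # fine

try:
    pickle.dumps(lambda x: max(0, x))  # lambdas have no importable name
except (pickle.PicklingError, AttributeError):
    print("lambda cannot be pickled")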
@@ -452,12 +480,19 @@ class BaseBayesianNeuralNetwork(Model, ABC):
     update_kwargs : Optional[dict], optional
         A dictionary of keyword arguments for the update method. For MCMC, it contains 'trace' settings.
         For VI, it contains both 'trace' and 'fit' settings.
+    activation : str, optional
+        The activation function to use for hidden layers. Supported values are: "tanh", "relu", "sigmoid", "gelu" (default is "tanh").
+    use_residual_connections : bool, optional
+        Whether to use residual connections in the network. Residual connections are only added when
+        the layer output dimension is greater than or equal to the input dimension (default is False).

     Notes
     -----
-    - The model uses tanh activation for hidden layers and sigmoid activation for the output layer.
+    - The model uses the specified activation function for hidden layers and sigmoid activation for the output layer.
     - The output layer is designed for binary classification tasks, with probabilities modeled
       using a Bernoulli likelihood.
+    - When use_residual_connections is True, residual connections are added to hidden layers where the output
+      dimension is >= input dimension. For expanding dimensions, the residual is zero-padded.
     """

     model_params: BnnParams
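A stand-alone NumPy sketch of the zero-padding rule described in the Notes above, with illustrative shapes:

import numpy as np

batch, input_dim, output_dim = 2, 3, 5
layer_input = np.ones((batch, input_dim))
activated_output = np.zeros((batch, output_dim))  # stand-in for the activated linear transform

# Expanding layer (output_dim > input_dim): right-pad the residual with zeros
# so it can be added to the wider activated output.
residual_padded = np.pad(layer_input, ((0, 0), (0, output_dim - input_dim)), mode="constant", constant_values=0)
next_layer_input = activated_output + residual_padded
assert next_layer_input.shape == (batch, output_dim)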
@@ -477,9 +512,23 @@ class BaseBayesianNeuralNetwork(Model, ABC):
         "adam",
         "adamax",
     ]
+    _pymc_activations: ClassVar[dict] = {
+        "tanh": math.tanh,
+        "relu": _pymc_relu,
+        "sigmoid": math.sigmoid,
+        "gelu": _pymc_gelu,
+    }
+    _numpy_activations: ClassVar[dict] = {
+        "tanh": np.tanh,
+        "relu": _numpy_relu,
+        "sigmoid": _stable_sigmoid,
+        "gelu": _numpy_gelu,
+    }

     update_method: str = "VI"
     update_kwargs: Optional[dict] = None
+    activation: ActivationFunctions = "tanh"
+    use_residual_connections: bool = False

     _default_mcmc_trace_kwargs: ClassVar[dict] = dict(
         tune=500,
@@ -495,6 +544,8 @@ class BaseBayesianNeuralNetwork(Model, ABC):
     _default_variational_inference_fit_kwargs: ClassVar[dict] = dict(method="advi")

     _approx_history: np.ndarray = PrivateAttr(None)
+    _numpy_activation_fn: Callable = PrivateAttr(None)
+    _pymc_activation_fn: Callable = PrivateAttr(None)

     class Config:
         arbitrary_types_allowed = True
@@ -569,6 +620,15 @@ def arrange_update_kwargs(self):
         else:
             raise ValueError(f"Unsupported pydantic version: {pydantic_version}")

+    @field_validator("activation")
+    @classmethod
+    def validate_activation(cls, v):
+        if v not in cls._pymc_activations.keys():
+            raise ValueError(
+                f"Invalid activation function: {v}. Supported activations are: {list(cls._pymc_activations.keys())}"
+            )
+        return v
+
     @property
     def approx_history(self) -> Optional[np.ndarray]:
         return self._approx_history
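The validator is a plain membership check against the activation mapping; a pure-Python sketch of the same logic (stand-alone, not the class method itself):

SUPPORTED_ACTIVATIONS = ["tanh", "relu", "sigmoid", "gelu"]

def validate_activation(v: str) -> str:
    # Mirrors the field_validator above: reject anything outside the supported keys.
    if v not in SUPPORTED_ACTIVATIONS:
        raise ValueError(f"Invalid activation function: {v}. Supported activations are: {SUPPORTED_ACTIVATIONS}")
    return v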
@@ -585,10 +645,6 @@ def optimizer(self) -> Callable:

         return _optimizer

-    @classmethod
-    def _stable_sigmoid(cls, x):
-        return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
-
     @classmethod
     def get_layer_params_name(cls, layer_ind: PositiveInt) -> Tuple[str, str]:
         weight_layer_params_name = f"{cls._weight_var_name}_{layer_ind}"
@@ -676,6 +732,14 @@ def input_dim(self) -> PositiveInt:
         """
         return self.model_params.bnn_layer_params[0].weight.shape[0]

+    def model_post_init(self, __context: Any) -> None:
+        """
+        Initialize activation function PrivateAttr based on the activation setting.
+        """
+        # Initialize activation functions (always set to ensure they're available after model_copy)
+        self._numpy_activation_fn = self._numpy_activations[self.activation]
+        self._pymc_activation_fn = self._pymc_activations[self.activation]
+
     def create_update_model(
         self, x: ArrayLike, y: Union[List[BinaryReward], np.ndarray], batch_size: Optional[PositiveInt] = None
     ) -> PymcModel:
@@ -720,6 +784,8 @@ def create_update_model(
             w_shape = layer_params.weight.shape  # without it n_features = 1 doesn't work
             b_shape = layer_params.bias.shape
             weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
+            input_dim = w_shape[0]
+            output_dim = w_shape[1]

             # For training, use shared weights and biases
             w = PymcStudentT(
@@ -732,7 +798,20 @@ def create_update_model(
             linear_transform = math.dot(next_layer_input, w) + b

             if layer_ind < len(self.model_params.bnn_layer_params) - 1:
-                next_layer_input = math.tanh(linear_transform)
+                activated_output = self._pymc_activation_fn(linear_transform)
+
+                # Add residual connection if enabled and dimensions allow
+                if self.use_residual_connections and output_dim >= input_dim:
+                    if output_dim == input_dim:
+                        next_layer_input = activated_output + next_layer_input
+                    else:
+                        residual_padded = math.concatenate(
+                            [next_layer_input, math.zeros((next_layer_input.shape[0], output_dim - input_dim))],
+                            axis=1,
+                        )
+                        next_layer_input = activated_output + residual_padded
+                else:
+                    next_layer_input = activated_output

         # Final output processing
         logit = Deterministic(self._logit_var_name, linear_transform.squeeze())
@@ -769,6 +848,8 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
             # Sample weights and biases from StudentT distributions
             w_params = layer_params.weight.params
             b_params = layer_params.bias.params
+            input_dim = layer_params.weight.shape[0]
+            output_dim = layer_params.weight.shape[1]

             # Sample weights and biases using scipy.stats
             w = t.rvs(
@@ -784,13 +865,25 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
             # Linear transformation
             linear_transform = np.einsum("...i,...ij->...j", next_layer_input, w) + b

-            # Apply activation function (tanh for hidden layers, sigmoid for output)
+            # Apply activation function for hidden layers, sigmoid for output
             if layer_ind < len(self.model_params.bnn_layer_params) - 1:
-                next_layer_input = np.tanh(linear_transform)
+                activated_output = self._numpy_activation_fn(linear_transform)
+
+                # Add residual connection if enabled and dimensions allow
+                if self.use_residual_connections and output_dim >= input_dim:
+                    if output_dim == input_dim:
+                        next_layer_input = activated_output + next_layer_input
+                    else:
+                        residual_padded = np.pad(
+                            next_layer_input, ((0, 0), (0, output_dim - input_dim)), mode="constant", constant_values=0
+                        )
+                        next_layer_input = activated_output + residual_padded
+                else:
+                    next_layer_input = activated_output
             else:
                 # Output layer - apply sigmoid
                 weighted_sum = linear_transform.squeeze(-1)
-                prob = self._stable_sigmoid(weighted_sum)
+                prob = _stable_sigmoid(weighted_sum)

         return list(zip(prob, weighted_sum))

@@ -884,6 +977,8 @@ def cold_start(
         update_method: UpdateMethods = "VI",
         update_kwargs: Optional[dict] = None,
         dist_params_init: Optional[Dict[str, float]] = None,
+        activation: ActivationFunctions = "tanh",
+        use_residual_connections: bool = False,
         **kwargs,
     ) -> Self:
         """
@@ -901,6 +996,10 @@ def cold_start(
             Additional keyword arguments for the update method. Default is None.
         dist_params_init : Optional[Dict[str, float]], optional
             Initial distribution parameters for the network weights and biases. Default is None.
+        activation : str
+            The activation function to use for hidden layers. Supported values are: "tanh", "relu", "sigmoid", "gelu" (default is "tanh").
+        use_residual_connections : bool
+            Whether to use residual connections in the network (default is False).
         **kwargs
             Additional keyword arguments for the BayesianNeuralNetwork constructor.

@@ -916,7 +1015,14 @@ def cold_start(
         model_params = cls.create_model_params(
             n_features=n_features, hidden_dim_list=hidden_dim_list, **dist_params_init
         )
-        return cls(model_params=model_params, update_method=update_method, update_kwargs=update_kwargs, **kwargs)
+        return cls(
+            model_params=model_params,
+            update_method=update_method,
+            update_kwargs=update_kwargs,
+            activation=activation,
+            use_residual_connections=use_residual_connections,
+            **kwargs,
+        )

     def _reset(self):
         """
@@ -1001,6 +1107,8 @@ def cold_start(
         update_method: UpdateMethods = "VI",
         update_kwargs: Optional[dict] = None,
         dist_params_init: Optional[Dict[str, float]] = None,
+        activation: ActivationFunctions = "tanh",
+        use_residual_connections: bool = False,
         **kwargs,
     ) -> Self:
         """
@@ -1020,6 +1128,10 @@ def cold_start(
             Additional keyword arguments for the update method.
         dist_params_init : Optional[Dict[str, float]], optional
             Initial distribution parameters for the network weights and biases.
+        activation : str
+            The activation function to use for hidden layers. Supported values are: "tanh", "relu", "sigmoid", "gelu" (default is "tanh").
+        use_residual_connections : bool
+            Whether to use residual connections in the network (default is False).
         **kwargs
             Additional keyword arguments.

@@ -1028,13 +1140,16 @@ def cold_start(
         BayesianNeuralNetworkMO
             A multi-objective BNN with the specified number of objectives.
         """
+
         models = [
             BayesianNeuralNetwork.cold_start(
                 n_features=n_features,
                 hidden_dim_list=hidden_dim_list,
                 update_method=update_method,
                 update_kwargs=update_kwargs,
                 dist_params_init=dist_params_init,
+                activation=activation,
+                use_residual_connections=use_residual_connections,
             )
             for _ in range(n_objectives)
         ]
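The multi-objective variant forwards the same settings to every per-objective network; a hedged usage sketch (keyword names follow the visible hunks, the full signature is not shown in this diff):

from pybandits.model import BayesianNeuralNetworkMO

bnn_mo = BayesianNeuralNetworkMO.cold_start(
    n_features=5,
    hidden_dim_list=[8],
    n_objectives=2,
    activation="relu",
    use_residual_connections=True,
)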

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pybandits"
-version = "4.0.16"
+version = "4.0.17"
 description = "Python Multi-Armed Bandit Library"
 authors = [
     "Dario d'Andrea <dariod@playtika.com>",

tests/test_actions_manager.py

Lines changed: 2 additions & 2 deletions
@@ -92,12 +92,12 @@ def test_update_with_missing_memory_delta_set(action_list):
 @given(
     data_len=st.integers(min_value=1, max_value=200),
     regular_kwargs=st.dictionaries(
-        st.text().filter(lambda x: not x.endswith("_memory") and x not in ["actions", "rewards"]),
+        st.text().filter(lambda x: not x.endswith("_memory") and x not in ["actions", "rewards", "self"]),
         st.integers(),
         min_size=1,
     ),
     memory_kwargs=st.dictionaries(
-        st.text().filter(lambda x: x not in ["actions", "rewards"]).map(lambda x: x + "_memory"),
+        st.text().filter(lambda x: x not in ["actions", "rewards", "self"]).map(lambda x: x + "_memory"),
         st.integers(),
         min_size=1,
     ),

tests/test_cmab.py

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ def create_cmab_and_actions(
         action_ids, costs, n_features, hidden_dim_list, update_method, update_kwargs, n_objectives
     )
     default_action = action_ids[0] if epsilon and not delta else None
-    if default_action and isinstance(self.model_types[0], QuantitativeModel):
+    if default_action and isinstance(actions[default_action], QuantitativeModel):
         default_action = (default_action, tuple(np.random.random(actions[default_action].dimension)))
     epsilon = epsilon if not delta else 0.1
     kwargs = {

tests/test_mab.py

Lines changed: 3 additions & 7 deletions
@@ -220,15 +220,11 @@ def test_mab_model_post_init_invalid_default_action(epsilon=0.1):
 def test_mab_model_post_init_quantitative_default_action_validation(epsilon=0.1):
     """Test model_post_init validation for quantitative default action requirements."""

-    # This test is demonstrating that the current validation logic has an issue:
-    # When default_action is a tuple, it checks if the entire tuple is in self.actions keys,
-    # but actions only contains string keys. This causes the validation to fail at line 138-139
-    # before it reaches the quantitative validation at lines 140-145.
-
-    # Test case: quantitative default action (tuple) with any actions will fail the basic validation
+    # Test case: quantitative default action (tuple) with standard (non-quantitative) actions
+    # should fail the quantitative model validation
     actions = {"action1": Beta(), "action2": Beta()}

-    with pytest.raises(AttributeError, match="The default action must be valid action defined in the actions set."):
+    with pytest.raises(AttributeError, match="Quantitative default action requires a quantitative action model."):
         DummyMab(actions=actions, strategy=ClassicBandit(), epsilon=epsilon, default_action=("action1", (0.5, 0.5)))