Commit a97b0d7

Refactor BNN for Acceleration (#110)
### Changes

* Updated the `StudentTArray` input validation method to avoid redundant conversions from array to list and back.
* Optimized the calculation of linear transformations in the Bayesian Neural Network using `np.einsum` (at least a 2x speedup).
* Optimized the posterior update in VI (at least a 1.1x speedup).
1 parent d4f4a28 commit a97b0d7
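The `np.einsum` rewrite in `sample_proba` (see `pybandits/model.py` below) contracts the layer input against the sampled weight tensors in one step instead of materializing the full broadcast product and reducing it. A minimal, self-contained equivalence and timing check; the shapes are illustrative, not taken from the library:

```python
import numpy as np
from time import perf_counter

rng = np.random.default_rng(0)
s, i, j = 5000, 16, 16              # posterior samples, input dim, output dim
x = rng.standard_normal((s, i))     # one layer input per posterior sample
w = rng.standard_normal((s, i, j))  # sampled weights
b = rng.standard_normal((s, j))     # sampled biases

old = np.sum(x[..., None] * w, axis=1) + b         # broadcast to (s, i, j), then reduce
new = np.einsum("...i,...ij->...j", x, w) + b      # single contraction, no intermediate
assert np.allclose(old, new)

for name, fn in [("sum", lambda: np.sum(x[..., None] * w, axis=1)),
                 ("einsum", lambda: np.einsum("...i,...ij->...j", x, w))]:
    t0 = perf_counter()
    for _ in range(100):
        fn()
    print(name, f"{perf_counter() - t0:.3f}s")
```

On typical shapes the einsum path avoids the `(samples, in, out)` temporary entirely, which is consistent with the reported speedup of at least 2x.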

10 files changed, +241 −106 lines


docs/src/tutorials/bnn.ipynb

Lines changed: 5 additions & 7 deletions

@@ -138,7 +138,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "blr = BayesianLogisticRegression.cold_start(n_features=2, update_method=\"VI\", update_kwargs={\"fit\": {\"n\": 40000}})"
+ "blr = BayesianLogisticRegression.cold_start(\n",
+ " n_features=2, update_method=\"VI\", update_kwargs={\"fit\": {\"n\": 10000}, \"batch_size\": 256, \"optimizer_type\": \"adam\"}\n",
+ ")"
  ]
 },
 {
@@ -346,7 +348,7 @@
  "dist_params_init = {\"mu\": 0, \"sigma\": 1, \"nu\": 5}\n",
  "bnn = BayesianNeuralNetwork.cold_start(\n",
  " n_features=2,\n",
- " hidden_dim_list=[5, 5],\n",
+ " hidden_dim_list=[16, 16],\n",
  " update_method=\"VI\",\n",
  " dist_params_init=dist_params_init,\n",
  " update_kwargs={\"fit\": {\"n\": 10000}, \"batch_size\": 256, \"optimizer_type\": \"adam\"},\n",
@@ -363,11 +365,7 @@
 {
  "cell_type": "code",
  "execution_count": null,
- "metadata": {
-  "jupyter": {
-   "is_executing": true
-  }
- },
+ "metadata": {},
  "outputs": [],
  "source": [
 "bnn.update(context=x_train, rewards=y_train)"

docs/src/tutorials/cmab.ipynb

Lines changed: 9 additions & 2 deletions

@@ -93,9 +93,16 @@
  "layer_params = BnnLayerParams(weight=weight, bias=bias)\n",
  "model_params = BnnParams(bnn_layer_params=[layer_params])\n",
  "\n",
+ "update_method = \"VI\"\n",
+ "update_kwargs = {\"fit\": {\"n\": 100}, \"batch_size\": 128, \"optimizer_type\": \"adam\"}\n",
+ "\n",
  "actions = {\n",
- " \"a1\": BayesianLogisticRegression(model_params=model_params),\n",
- " \"a2\": BayesianLogisticRegression(model_params=model_params),\n",
+ " \"a1\": BayesianLogisticRegression(\n",
+ " model_params=model_params, update_method=update_method, update_kwargs=update_kwargs\n",
+ " ),\n",
+ " \"a2\": BayesianLogisticRegression(\n",
+ " model_params=model_params, update_method=update_method, update_kwargs=update_kwargs\n",
+ " ),\n",
  "}"
  ]
 },

docs/src/tutorials/cmab_simulator.ipynb

Lines changed: 11 additions & 3 deletions

@@ -111,15 +111,23 @@
  " return model_params\n",
  "\n",
  "\n",
+ "update_method = \"VI\"\n",
+ "update_kwargs = {\"fit\": {\"n\": 100}, \"batch_size\": 128, \"optimizer_type\": \"adam\"}\n",
  "actions = {\n",
  " \"a1\": BayesianLogisticRegression(\n",
- " model_params=create_model_params(n_features=n_features, bias_mu=1, bias_sigma=2), update_method=\"VI\"\n",
+ " model_params=create_model_params(n_features=n_features, bias_mu=1, bias_sigma=2),\n",
+ " update_method=update_method,\n",
+ " update_kwargs=update_kwargs,\n",
  " ),\n",
  " \"a2\": BayesianLogisticRegression(\n",
- " model_params=create_model_params(n_features=n_features, bias_mu=1, bias_sigma=2), update_method=\"VI\"\n",
+ " model_params=create_model_params(n_features=n_features, bias_mu=1, bias_sigma=2),\n",
+ " update_method=update_method,\n",
+ " update_kwargs=update_kwargs,\n",
  " ),\n",
  " \"a3\": BayesianLogisticRegression(\n",
- " model_params=create_model_params(n_features=n_features, bias_mu=1, bias_sigma=2), update_method=\"VI\"\n",
+ " model_params=create_model_params(n_features=n_features, bias_mu=1, bias_sigma=2),\n",
+ " update_method=update_method,\n",
+ " update_kwargs=update_kwargs,\n",
  " ),\n",
  "}\n",
  "# init contextual Multi-Armed Bandit model\n",

docs/src/tutorials/cmab_zooming.ipynb

Lines changed: 4 additions & 1 deletion

@@ -56,10 +56,13 @@
  "n_features = 3\n",
  "# Define number of segments for each action\n",
  "n_max_segments = 16 # Maximum number of segments for each action\n",
+ "\n",
  "# Define cold start parameters for the base model\n",
+ "update_method = \"VI\" # Variational Inference for Bayesian updates\n",
+ "update_kwargs = {\"fit\": {\"n\": 1000}, \"batch_size\": 256, \"optimizer_type\": \"adam\"}\n",
  "base_model_cold_start_kwargs = {\n",
  " \"n_features\": n_features, # Number of context features\n",
- " \"update_method\": \"VI\", # Variational Inference for Bayesian updates\n",
+ " \"update_method\": \"VI\",\n",
  "}\n",
  "\n",
  "\n",

docs/src/tutorials/ope.ipynb

Lines changed: 8 additions & 0 deletions

@@ -4,6 +4,14 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+ "# Offline Policy Evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
  "### Introduction\n",
  "\n",
  "This notebook demonstrates the use of offline policy evaluation for MABs.\n",

pybandits/model.py

Lines changed: 58 additions & 46 deletions

@@ -274,7 +274,7 @@ class StudentTArray(PyBanditsBaseModel):
     nu: Union[List[PositiveFloat], List[List[PositiveFloat]]]

     @staticmethod
-    def convert_list_to_array(input_list: Union[List[float], List[List[float]]]) -> bool:
+    def maybe_convert_list_to_array(input_list: Union[List[float], List[List[float]]]) -> bool:
         if len(input_list) == 0:
             is_valid_input = False

@@ -292,19 +292,16 @@ def convert_list_to_array(input_list: Union[List[float], List[List[float]]]) ->
         else:
             raise ValueError("Input list must be a 1D or 2D list with the same length for all inner lists.")

-    @model_validator(mode="after")
+    @model_validator(mode="before")
     @classmethod
     def validate_input_shapes(cls, values):
-        if pydantic_version == PYDANTIC_VERSION_1:
-            mu_arr = cls.convert_list_to_array(values.get("mu"))
-            sigma_arr = cls.convert_list_to_array(values.get("sigma"))
-            nu_arr = cls.convert_list_to_array(values.get("nu"))
-        elif pydantic_version == PYDANTIC_VERSION_2:
-            mu_arr = cls.convert_list_to_array(values.mu)
-            sigma_arr = cls.convert_list_to_array(values.sigma)
-            nu_arr = cls.convert_list_to_array(values.nu)
-        else:
-            raise ValueError(f"Unsupported pydantic version: {pydantic_version}")
+        mu_input = values.get("mu")
+        sigma_input = values.get("sigma")
+        nu_input = values.get("nu")
+
+        mu_arr = cls.maybe_convert_list_to_array(mu_input)
+        sigma_arr = cls.maybe_convert_list_to_array(sigma_input)
+        nu_arr = cls.maybe_convert_list_to_array(nu_input)

         if (mu_arr.shape != sigma_arr.shape) or (mu_arr.shape != nu_arr.shape):
             raise ValueError(
@@ -315,6 +312,9 @@ def validate_input_shapes(cls, values):
         if any(dim_len == 0 for dim_len in mu_arr.shape):
             raise ValueError("mu, sigma, and nu must have at least one element in every dimension.")

+        for key, value in zip(["mu", "sigma", "nu"], [mu_input, sigma_input, nu_input]):
+            if isinstance(value, np.ndarray):
+                values[key] = value.tolist()
         return values

     @classmethod
@@ -331,9 +331,9 @@ def cold_start(
         if any(dim_len == 0 for dim_len in shape):
             raise ValueError("shape of mu, sigma, and nu must have at least one element in every dimension.")

-        mu = np.full(shape, mu).tolist()
-        sigma = np.full(shape, sigma).tolist()
-        nu = np.full(shape, nu).tolist()
+        mu = np.full(shape, mu)
+        sigma = np.full(shape, sigma)
+        nu = np.full(shape, nu)
         return cls(mu=mu, sigma=sigma, nu=nu)

     @property
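Moving the validator from `mode="after"` to `mode="before"` means it receives the raw input mapping on both pydantic versions, which removes the version branch and lets `cold_start` pass `np.full(...)` arrays straight in: the inputs are validated once and converted to lists once, instead of round-tripping list → array → list. A stripped-down sketch of the pattern under pydantic v2; the `Demo` model and its fields are illustrative, not the library's:

```python
import numpy as np
from pydantic import BaseModel, model_validator  # v2 shown; v1 uses root_validator(pre=True)

class Demo(BaseModel):
    mu: list
    sigma: list

    @model_validator(mode="before")
    @classmethod
    def normalize_arrays(cls, values: dict) -> dict:
        # Raw inputs arrive here before field coercion, so ndarrays are still ndarrays.
        for key in ("mu", "sigma"):
            value = values.get(key)
            if isinstance(value, np.ndarray):
                values[key] = value.tolist()  # convert once, instead of list -> array -> list
        return values

Demo(mu=np.full((2, 3), 0.0), sigma=np.full((2, 3), 1.0))  # accepted without a prior .tolist()
```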
@@ -449,9 +449,6 @@ class BaseBayesianNeuralNetwork(Model, ABC):
     )

     _default_variational_inference_fit_kwargs: ClassVar[dict] = dict(method="advi")
-    _default_variational_inference_trace_kwargs: ClassVar[dict] = dict(
-        draws=1000, progressbar=False, return_inferencedata=False
-    )

     _approx_history: np.ndarray = PrivateAttr(None)

@@ -470,12 +467,7 @@ def arrange_update_kwargs(cls, values):
             update_kwargs = dict()

         if update_method == "VI":
-            update_kwargs["trace"] = {
-                **cls._default_variational_inference_trace_kwargs,
-                **update_kwargs.get("trace", {}),
-            }
             update_kwargs["fit"] = {**cls._default_variational_inference_fit_kwargs, **update_kwargs.get("fit", {})}
-
             optimizer_type = update_kwargs.get("optimizer_type", None)

             if optimizer_type is not None:
@@ -507,10 +499,6 @@ def arrange_update_kwargs(self):
             self.update_kwargs = dict()

         if self.update_method == "VI":
-            self.update_kwargs["trace"] = {
-                **self._default_variational_inference_trace_kwargs,
-                **self.update_kwargs.get("trace", {}),
-            }
             self.update_kwargs["fit"] = {
                 **self._default_variational_inference_fit_kwargs,
                 **self.update_kwargs.get("fit", {}),
@@ -673,14 +661,14 @@ def create_update_model(
         3. Apply sigmoid activation at the output
         4. Use Bernoulli likelihood for binary classification
         """
-
+        y = np.array(y, dtype=np.int32)
         with PymcModel() as _model:
             # Define data variables
             if batch_size is None:
                 bnn_output = Data("bnn_output", y)
                 bnn_input = Data("bnn_input", x)
             else:
-                bnn_input, bnn_output = Minibatch(x, np.array(y).astype("int32"), batch_size=batch_size)
+                bnn_input, bnn_output = Minibatch(x, y, batch_size=batch_size)

             next_layer_input = bnn_input

@@ -750,7 +738,7 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
             )

             # Linear transformation
-            linear_transform = np.sum(next_layer_input[..., None] * w, axis=1) + b
+            linear_transform = np.einsum("...i,...ij->...j", next_layer_input, w) + b

             # Apply activation function (tanh for hidden layers, sigmoid for output)
             if layer_ind < len(self.model_params.bnn_layer_params) - 1:
@@ -797,29 +785,53 @@ def _update(self, context: np.ndarray, rewards: List[BinaryReward]):
             else:
                 approx = fit(**update_kwargs["fit"])

-            trace = approx.sample(**update_kwargs["trace"])
             self._approx_history = approx.hist
+            approx_mean_eval = approx.mean.eval()
+            approx_std_eval = approx.std.eval()
+            approx_posterior_mapping = {
+                param: (approx_mean_eval[slice_], approx_std_eval[slice_])
+                for (param, (_, slice_, _, _)) in approx.ordering.items()
+            }
+            for layer_ind, layer_params in enumerate(self.model_params.bnn_layer_params):
+                weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
+                w_shape = layer_params.weight.shape
+                b_shape = layer_params.bias.shape
+                w_mu = approx_posterior_mapping[weight_layer_params_name][0].reshape(w_shape)
+                w_sigma = approx_posterior_mapping[weight_layer_params_name][1].reshape(w_shape)
+                b_mu = approx_posterior_mapping[bias_layer_params_name][0].reshape(b_shape)
+                b_sigma = approx_posterior_mapping[bias_layer_params_name][1].reshape(b_shape)
+                layer_params.weight = StudentTArray(
+                    mu=w_mu, sigma=w_sigma, nu=self.model_params.bnn_layer_params[layer_ind].weight.nu
+                )
+                layer_params.bias = StudentTArray(
+                    mu=b_mu, sigma=b_sigma, nu=self.model_params.bnn_layer_params[layer_ind].bias.nu
+                )
+                self.model_params.bnn_layer_params[layer_ind] = layer_params
         elif self.update_method == "MCMC":
             # MCMC
             trace = sample(**self.update_kwargs["trace"])
+
+            for layer_ind, layer_params in enumerate(self.model_params.bnn_layer_params):
+                weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
+
+                w_mu = np.mean(trace[weight_layer_params_name], axis=0)
+                w_sigma = np.std(trace[weight_layer_params_name], axis=0)
+                layer_params.weight = StudentTArray(
+                    mu=w_mu.tolist(),
+                    sigma=w_sigma.tolist(),
+                    nu=self.model_params.bnn_layer_params[layer_ind].weight.nu,
+                )
+
+                b_mu = np.mean(trace[bias_layer_params_name], axis=0)
+                b_sigma = np.std(trace[bias_layer_params_name], axis=0)
+                layer_params.bias = StudentTArray(
+                    mu=b_mu.tolist(),
+                    sigma=b_sigma.tolist(),
+                    nu=self.model_params.bnn_layer_params[layer_ind].bias.nu,
+                )
         else:
             raise ValueError("Invalid update method.")

-        for layer_ind, layer_params in enumerate(self.model_params.bnn_layer_params):
-            weight_layer_params_name, bias_layer_params_name = self.get_layer_params_name(layer_ind)
-
-            w_mu = np.mean(trace[weight_layer_params_name], axis=0)
-            w_sigma = np.std(trace[weight_layer_params_name], axis=0)
-            layer_params.weight = StudentTArray(
-                mu=w_mu.tolist(), sigma=w_sigma.tolist(), nu=self.model_params.bnn_layer_params[layer_ind].weight.nu
-            )
-
-            b_mu = np.mean(trace[bias_layer_params_name], axis=0)
-            b_sigma = np.std(trace[bias_layer_params_name], axis=0)
-            layer_params.bias = StudentTArray(
-                mu=b_mu.tolist(), sigma=b_sigma.tolist(), nu=self.model_params.bnn_layer_params[layer_ind].bias.nu
-            )
-
     @classmethod
     def cold_start(
         cls,
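The `_update` change above is where the "at least 1.1x" VI speedup comes from: for a mean-field ADVI fit, the per-parameter posterior mean and standard deviation can be read straight off the approximation, so the old `approx.sample(draws=1000)` followed by `np.mean`/`np.std` over the draws is unnecessary. A hedged PyMC sketch of the two paths; the toy model is illustrative, and `approx.ordering` is a PyMC internal the commit relies on whose exact layout may vary across versions:

```python
import numpy as np
import pymc as pm

with pm.Model():
    w = pm.Normal("w", mu=0.0, sigma=1.0, shape=(4, 3))
    pm.Normal("obs", mu=w.sum(), sigma=1.0, observed=0.0)
    approx = pm.fit(n=5000, method="advi", progressbar=False)

# Old path: draw 1000 samples from the approximation, then estimate moments.
trace = approx.sample(draws=1000, return_inferencedata=False)
w_mu_sampled = np.mean(trace["w"], axis=0)

# New path: read the variational parameters directly -- no sampling needed.
mean_flat, std_flat = approx.mean.eval(), approx.std.eval()
(_, slice_, _, _) = approx.ordering["w"]          # maps each parameter to its flat slice
w_mu = mean_flat[slice_].reshape(4, 3)
w_sigma = std_flat[slice_].reshape(4, 3)
```

For MCMC the trace is still needed, which is why the commit keeps the sample-then-average loop only in the MCMC branch.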

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pybandits"
-version = "4.0.14"
+version = "4.0.15"
 description = "Python Multi-Armed Bandit Library"
 authors = [
     "Dario d'Andrea <dariod@playtika.com>",
