Skip to content

Commit 60c6b46

Browse files
authored
Add private numpy array attributes and initialize them in StudentTArray (#112)
### Changes: * Introduced private attributes `_mu_array`, `_sigma_array`, and `_nu_array` to store numpy arrays. * Added `model_post_init` method to convert lists to numpy arrays during initialization. * Updated `shape` and `params` properties to utilize the new private attributes for improved performance and clarity.
1 parent a97b0d7 commit 60c6b46

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

pybandits/model.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,20 @@ class StudentTArray(PyBanditsBaseModel):
273273
sigma: Union[List[NonNegativeFloat], List[List[NonNegativeFloat]]]
274274
nu: Union[List[PositiveFloat], List[List[PositiveFloat]]]
275275

276+
_mu_array: np.ndarray = PrivateAttr()
277+
_sigma_array: np.ndarray = PrivateAttr()
278+
_nu_array: np.ndarray = PrivateAttr()
279+
_params: Dict[str, np.ndarray] = PrivateAttr()
280+
281+
def __eq__(self, other: Any) -> bool:
282+
if not isinstance(other, StudentTArray):
283+
return False
284+
return (
285+
np.all(self._mu_array == other._mu_array)
286+
and np.all(self._sigma_array == other._sigma_array)
287+
and np.all(self._nu_array == other._nu_array)
288+
)
289+
276290
@staticmethod
277291
def maybe_convert_list_to_array(input_list: Union[List[float], List[List[float]]]) -> bool:
278292
if len(input_list) == 0:
@@ -336,13 +350,43 @@ def cold_start(
336350
nu = np.full(shape, nu)
337351
return cls(mu=mu, sigma=sigma, nu=nu)
338352

353+
def model_post_init(self, __context: Any) -> None:
354+
"""
355+
Initialize private numpy array attributes by converting lists to arrays once at initialization.
356+
357+
Parameters
358+
----------
359+
__context : Any
360+
Pydantic context (unused).
361+
"""
362+
self._mu_array = np.array(self.mu)
363+
self._sigma_array = np.array(self.sigma)
364+
self._nu_array = np.array(self.nu)
365+
self._params = dict(mu=self._mu_array, sigma=self._sigma_array, nu=self._nu_array)
366+
339367
@property
340368
def shape(self) -> Tuple[PositiveInt, ...]:
341-
return np.array(self.mu).shape
369+
"""
370+
Get the shape of the mu array.
371+
372+
Returns
373+
-------
374+
Tuple[PositiveInt, ...]
375+
The shape of the mu array.
376+
"""
377+
return self._mu_array.shape
342378

343379
@property
344-
def params(self):
345-
return dict(mu=np.array(self.mu), sigma=np.array(self.sigma), nu=np.array(self.nu))
380+
def params(self) -> Dict[str, np.ndarray]:
381+
"""
382+
Get the parameters as a dictionary of numpy arrays.
383+
384+
Returns
385+
-------
386+
Dict[str, np.ndarray]
387+
Dictionary containing 'mu', 'sigma', and 'nu' as numpy arrays.
388+
"""
389+
return self._params
346390

347391

348392
class BnnLayerParams(PyBanditsBaseModel):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pybandits"
3-
version = "4.0.15"
3+
version = "4.0.16"
44
description = "Python Multi-Armed Bandit Library"
55
authors = [
66
"Dario d'Andrea <dariod@playtika.com>",

0 commit comments

Comments
 (0)