Skip to content

Commit ccee894

Browse files
committed
Improve OPE functionality, logging, and multiprocessing
### Changes: * Add optional dependencies for xgboost and ipython * Rename `_stable_sigmoid` to `_numpy_sigmoid` for clarity * Improve dtype checks in `BaseOfflinePolicyEstimator` * Refactor multiprocessing in `OfflinePolicyEvaluator` for better resource management * Add serialization support for MAB predictions to avoid pickling issues * Enhance tests for Jupyter notebook detection
1 parent 4624372 commit ccee894

File tree

10 files changed

+210
-46
lines changed

10 files changed

+210
-46
lines changed

.github/workflows/continuous_delivery.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ matrix.pydantic-version }}-
4444
- name: Install project dependencies with Poetry
4545
run: |
46-
poetry install
46+
poetry install --all-extras
4747
- name: Restore pyproject.toml
4848
run: |
4949
mv pyproject.toml.bak pyproject.toml

.github/workflows/continuous_documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
export PATH="$HOME/.poetry/bin:$PATH"
3030
- name: Install project dependencies with Poetry
3131
run: |
32-
poetry install
32+
poetry install --all-extras
3333
3434
- name: Install Pandoc
3535
run: |

.github/workflows/continuous_integration.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
- name: Install project dependencies with Poetry
3939
run: |
4040
poetry add pydantic@${{ matrix.pydantic-version }}
41-
poetry install
41+
poetry install --all-extras
4242
- name: Style check
4343
run: |
4444
# run pre-commit hooks

docs/src/authors.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Dario d'Andrea, <dariod@playtika.com>
44
- Shahar Bar, <shaharbar1@gmail.com>
55
- Jerome Carayol, <jeromec@playtika.com>
6+
- Anastasiia Kabeshove, <anastasiiak@playtika.com>
67
- Stefano Piazza, <stefanop@playtika.com>
78
- Ron Shiff, <ron.shiff1@gmail.com>
89
- Raphael Steinmann, <raphaels@playtika.com>

pybandits/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _numpy_gelu(x: np.ndarray) -> np.ndarray:
7878
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
7979

8080

81-
def _stable_sigmoid(x):
81+
def _numpy_sigmoid(x):
8282
"""Stable sigmoid activation function for NumPy."""
8383
return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
8484

@@ -521,7 +521,7 @@ class BaseBayesianNeuralNetwork(Model, ABC):
521521
_numpy_activations: ClassVar[dict] = {
522522
"tanh": np.tanh,
523523
"relu": _numpy_relu,
524-
"sigmoid": _stable_sigmoid,
524+
"sigmoid": _numpy_sigmoid,
525525
"gelu": _numpy_gelu,
526526
}
527527

@@ -883,7 +883,7 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
883883
else:
884884
# Output layer - apply sigmoid
885885
weighted_sum = linear_transform.squeeze(-1)
886-
prob = _stable_sigmoid(weighted_sum)
886+
prob = _numpy_sigmoid(weighted_sum)
887887

888888
return list(zip(prob, weighted_sum))
889889

pybandits/offline_policy_estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,14 @@ def _check_array(
7878
raise ValueError(f"{name} must be a {ndim}D array.")
7979
if array.shape[0] != n_samples:
8080
raise ValueError(f"action and {name} must have the same length.")
81-
if array.dtype != dtype:
81+
# Check dtype compatibility: use issubdtype for numpy dtypes
82+
if dtype is float:
83+
if not np.issubdtype(array.dtype, np.floating):
84+
raise ValueError(f"{name} must be a {dtype} array")
85+
elif dtype is int:
86+
if not np.issubdtype(array.dtype, np.integer):
87+
raise ValueError(f"{name} must be a {dtype} array")
88+
elif array.dtype is not dtype:
8289
raise ValueError(f"{name} must be a {dtype} array")
8390
if ndim > 1:
8491
if array.shape[1] != n_actions:

0 commit comments

Comments
 (0)