Skip to content

Commit a42c5f8

Browse files
committed
Improve OPE functionality, logging, and multiprocessing
### Changes: * Add optional dependencies for xgboost and ipython * Rename `_stable_sigmoid` to `_numpy_sigmoid` for clarity * Improve dtype checks in `BaseOfflinePolicyEstimator` * Refactor multiprocessing in `OfflinePolicyEvaluator` for better resource management * Add serialization support for MAB predictions to avoid pickling issues * Enhance tests for Jupyter notebook detection
1 parent a60aa5c commit a42c5f8

File tree

10 files changed

+267
-47
lines changed

10 files changed

+267
-47
lines changed

.github/workflows/continuous_delivery.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ matrix.pydantic-version }}-
4444
- name: Install project dependencies with Poetry
4545
run: |
46-
poetry install
46+
poetry install --all-extras
4747
- name: Restore pyproject.toml
4848
run: |
4949
mv pyproject.toml.bak pyproject.toml

.github/workflows/continuous_documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
export PATH="$HOME/.poetry/bin:$PATH"
3030
- name: Install project dependencies with Poetry
3131
run: |
32-
poetry install
32+
poetry install --all-extras
3333
3434
- name: Install Pandoc
3535
run: |

.github/workflows/continuous_integration.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
- name: Install project dependencies with Poetry
3939
run: |
4040
poetry add pydantic@${{ matrix.pydantic-version }}
41-
poetry install
41+
poetry install --all-extras
4242
- name: Style check
4343
run: |
4444
# run pre-commit hooks

docs/src/authors.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
- Dario d'Andrea, <[email protected]>
44
- Shahar Bar, <[email protected]>
55
- Jerome Carayol, <[email protected]>
6+
- Anastasiia Kabeshova, <[email protected]>
67
- Stefano Piazza, <[email protected]>
78
- Ron Shiff, <[email protected]>
89
- Raphael Steinmann, <[email protected]>

pybandits/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _numpy_gelu(x: np.ndarray) -> np.ndarray:
7979
return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))
8080

8181

82-
def _stable_sigmoid(x):
82+
def _numpy_sigmoid(x):
8383
"""Stable sigmoid activation function for NumPy."""
8484
return np.where(x >= 0, 1 / (1 + np.exp(-x)), np.exp(x) / (1 + np.exp(x)))
8585

@@ -850,7 +850,7 @@ class BaseBayesianNeuralNetwork(Model, ABC):
850850
_numpy_activations: ClassVar[dict] = {
851851
"tanh": np.tanh,
852852
"relu": _numpy_relu,
853-
"sigmoid": _stable_sigmoid,
853+
"sigmoid": _numpy_sigmoid,
854854
"gelu": _numpy_gelu,
855855
}
856856

@@ -1208,7 +1208,7 @@ def sample_proba(self, context: np.ndarray) -> List[ProbabilityWeight]:
12081208
else:
12091209
# Output layer - apply sigmoid
12101210
weighted_sum = linear_transform.squeeze(-1)
1211-
prob = _stable_sigmoid(weighted_sum)
1211+
prob = _numpy_sigmoid(weighted_sum)
12121212

12131213
return list(zip(prob, weighted_sum))
12141214

pybandits/offline_policy_estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,14 @@ def _check_array(
7878
raise ValueError(f"{name} must be a {ndim}D array.")
7979
if array.shape[0] != n_samples:
8080
raise ValueError(f"action and {name} must have the same length.")
81-
if array.dtype != dtype:
81+
# Check dtype compatibility: use issubdtype for numpy dtypes
82+
if dtype is float:
83+
if not np.issubdtype(array.dtype, np.floating):
84+
raise ValueError(f"{name} must be a {dtype} array")
85+
elif dtype is int:
86+
if not np.issubdtype(array.dtype, np.integer):
87+
raise ValueError(f"{name} must be a {dtype} array")
88+
elif array.dtype is not dtype:
8289
raise ValueError(f"{name} must be a {dtype} array")
8390
if ndim > 1:
8491
if array.shape[1] != n_actions:

0 commit comments

Comments
 (0)