Skip to content

Commit bb4aed6

Browse files
authored
Merge branch 'testing' into mcminn-test
2 parents 14f989d + 47fa89c commit bb4aed6

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

demo/Cost-penalized custom objective.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@
232232
"metadata": {},
233233
"outputs": [],
234234
"source": [
235-
"for iter in range(3):\n",
235+
"for i in range(3):\n",
236236
" campaign.fit()\n",
237237
" X_suggest, eval_suggest = campaign.suggest(m_batch=3)\n",
238238
" y_iter = pd.DataFrame(simulator.simulate(X_suggest))\n",

demo/Simple single objective.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@
190190
"metadata": {},
191191
"outputs": [],
192192
"source": [
193-
"for iter in range(3):\n",
193+
"for i in range(3):\n",
194194
" campaign.fit()\n",
195195
" X_suggest, eval_suggest = campaign.suggest(m_batch=3)\n",
196196
" y_iter = pd.DataFrame(simulator.simulate(X_suggest))\n",

obsidian/parameters/transforms.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Transformation functions to normalize output responses"""
22

3-
from torch import logit, sigmoid
3+
from torch import logit, sigmoid, zeros_like
44
from abc import ABC, abstractmethod
55
import warnings
66
from torch import Tensor
@@ -84,7 +84,12 @@ def forward(self,
8484
self.params = {'mu': X_v.mean(), 'sd': X_v.std()}
8585
else:
8686
self._validate_fit()
87-
return (X-self.params['mu'])/self.params['sd']
87+
if self.params["sd"] == 0:
88+
# In the edge case where `X` is degenerate, avoid 0 divided by 0
89+
warnings.warn('Transform constant target values by mean subtraction', UserWarning)
90+
return zeros_like(X)
91+
else:
92+
return (X-self.params['mu'])/self.params['sd']
8893

8994
def inverse(self, X):
9095
"""Inverse transform the transformed data X_t"""

obsidian/plotting/plotly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def MDS_plot(campaign: Campaign) -> Figure:
147147
except ImportError:
148148
raise ImportError(
149149
"The `sklearn` package (>1.0) is required for the MDS plot. \
150-
Please install it using `pip install scikit-learn`"
150+
Please install it using `pip install scikit-learn`"
151151
)
152152

153153
mds = MDS(n_components=2)

obsidian/tests/test_parameters.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,16 @@ def test_target_validation():
255255
with pytest.warns(UserWarning):
256256
transform_func = Logit_Scaler(range_response=100)
257257
transform_func.forward(test_neg_response, fit=False)
258-
259-
258+
259+
# Transform constant target values
260+
test_constant_response = torch.zeros(10) + 9.0
261+
with pytest.warns(UserWarning):
262+
Target('Response1', f_transform='Standard').transform_f(test_constant_response, fit=True)
263+
264+
# Corner case for Logit_Scaler
265+
transform_func = Logit_Scaler(standardize=False)
266+
transform_func.forward(test_response, fit=True)
267+
268+
260269
if __name__ == '__main__':
261270
pytest.main([__file__, '-m', 'fast'])

0 commit comments

Comments
 (0)