Skip to content

Commit 0be800e

Browse files
sdaulton authored and facebook-github-bot committed
fix StratifiedStandardize dtype/nan issue (#2757)
Summary: Pull Request resolved: #2757 If dtype passed to `get_task_value_remapping` is not float or double, an exception is raised because NaN cannot be used in an int/long tensor. The docstring states this, but we did it anyway in StratifiedStandardize, which meant that remapping didn't work. This fixes the issue. Reviewed By: esantorella Differential Revision: D70111086 fbshipit-source-id: c01044dadd2de33c84a346a3ee500dcf504cfe23
1 parent 9a7c517 commit 0be800e

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

botorch/models/transforms/outcome.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def __init__(
532532
OutcomeTransform.__init__(self)
533533
self._stratification_idx = stratification_idx
534534
task_values = task_values.unique(sorted=True)
535-
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
535+
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
536536
if self.strata_mapping is None:
537537
self.strata_mapping = task_values
538538
n_strata = self.strata_mapping.shape[0]
@@ -576,7 +576,7 @@ def forward(
576576
strata = X[..., self._stratification_idx].long()
577577
unique_strata = strata.unique()
578578
for s in unique_strata:
579-
mapped_strata = self.strata_mapping[s]
579+
mapped_strata = self.strata_mapping[s].long()
580580
mask = strata != s
581581
Y_strata = Y.clone()
582582
Y_strata[..., mask, :] = float("nan")
@@ -616,7 +616,7 @@ def _get_per_input_means_stdvs(
616616
- The per-input stdvs squared.
617617
"""
618618
strata = X[..., self._stratification_idx].long()
619-
mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
619+
mapped_strata = self.strata_mapping[strata].unsqueeze(-1).long()
620620
# get means and stdvs for each strata
621621
n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
622622
expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape

botorch/models/utils/assorted.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor
422422
return value will be `None`, when the task values are contiguous
423423
integers starting from zero.
424424
"""
425+
if dtype not in (torch.float, torch.double):
426+
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
425427
task_range = torch.arange(
426428
len(task_values), dtype=task_values.dtype, device=task_values.device
427429
)

test/models/test_multitask.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,3 +700,12 @@ def test_get_task_value_remapping(self) -> None:
700700
mapping = get_task_value_remapping(task_values, dtype)
701701
self.assertTrue(torch.equal(mapping[[1, 3]], expected_mapping_no_nan))
702702
self.assertTrue(torch.isnan(mapping[[0, 2]]).all())
703+
704+
def test_get_task_value_remapping_invalid_dtype(self) -> None:
705+
task_values = torch.tensor([1, 3])
706+
for dtype in (torch.int32, torch.long, torch.bool):
707+
with self.assertRaisesRegex(
708+
ValueError,
709+
f"dtype must be torch.float or torch.double, but got {dtype}.",
710+
):
711+
get_task_value_remapping(task_values, dtype)

test/models/transforms/test_outcome.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -372,16 +372,24 @@ def test_stratified_standardize(self):
372372
n = 5
373373
seed = randint(0, 100)
374374
torch.manual_seed(seed)
375-
for dtype, batch_shape in itertools.product(
376-
(torch.float, torch.double), (torch.Size([]), torch.Size([3]))
375+
for dtype, batch_shape, task_values in itertools.product(
376+
(torch.float, torch.double),
377+
(torch.Size([]), torch.Size([3])),
378+
(
379+
torch.tensor([0, 1], dtype=torch.long, device=self.device),
380+
torch.tensor([0, 3], dtype=torch.long, device=self.device),
381+
),
377382
):
378383
torch.manual_seed(seed)
384+
tval = task_values[1].item()
379385
X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device)
380-
X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device)
386+
X[..., -1] = torch.tensor(
387+
[0, tval, 0, tval, 0], dtype=dtype, device=self.device
388+
)
381389
Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device)
382390
Yvar = torch.rand(*batch_shape, n, 1, dtype=dtype, device=self.device)
383391
strata_tf = StratifiedStandardize(
384-
task_values=torch.tensor([0, 1], dtype=torch.long, device=self.device),
392+
task_values=task_values,
385393
stratification_idx=-1,
386394
batch_shape=batch_shape,
387395
)
@@ -400,9 +408,11 @@ def test_stratified_standardize(self):
400408
tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1)
401409
# check that stratified means are expected
402410
self.assertAllClose(strata_tf.means[..., :1, :], tf0.means)
403-
self.assertAllClose(strata_tf.means[..., 1:, :], tf1.means)
411+
# use remapped task values to index
412+
self.assertAllClose(strata_tf.means[..., 1:2, :], tf1.means)
404413
self.assertAllClose(strata_tf.stdvs[..., :1, :], tf0.stdvs)
405-
self.assertAllClose(strata_tf.stdvs[..., 1:, :], tf1.stdvs)
414+
# use remapped task values to index
415+
self.assertAllClose(strata_tf.stdvs[..., 1:2, :], tf1.stdvs)
406416
# check the transformed values
407417
self.assertAllClose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
408418
self.assertAllClose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))

0 commit comments

Comments (0)