Skip to content

Commit 583e8cb

Browse files
TobyBoynemeta-codesync[bot]
authored andcommitted
Check shape of state dict when comparing input transforms (#3051)
Summary: ## Motivation See #3050 ### Have you read the [Contributing Guidelines on pull requests](https://github.com/meta-pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Y Pull Request resolved: #3051 Test Plan: I have added an additional test case to `test/models/transforms/test_input.py`. Without the changes in this PR, the new test case would fail with a runtime error, as described in #3050. The test now correctly passes. Reviewed By: saitcakmak Differential Revision: D85149051 Pulled By: Balandat fbshipit-source-id: 4f6a972601994154d19edd44fef43ddfa092170f
1 parent b0d492d commit 583e8cb

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

botorch/models/transforms/input.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
from torch.nn.functional import one_hot
4343

4444

45+
def _allclose(input: Tensor, other: Tensor) -> bool:
46+
"""Check if `input` and `other` are the same shape, and satisfy `torch.allclose`."""
47+
if input.shape != other.shape:
48+
return False
49+
return torch.allclose(input, other)
50+
51+
4552
class InputTransform(Module, ABC):
4653
r"""Abstract base class for input transforms.
4754
@@ -124,7 +131,7 @@ def equals(self, other: InputTransform) -> bool:
124131
and (self.transform_on_eval == other.transform_on_eval)
125132
and (self.transform_on_fantasize == other.transform_on_fantasize)
126133
and all(
127-
torch.allclose(v, other_state_dict[k].to(v))
134+
_allclose(v, other_state_dict[k].to(v))
128135
for k, v in self.state_dict().items()
129136
)
130137
)

test/models/transforms/test_input.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ def test_normalize(self) -> None:
400400
self.assertTrue(nlz6.equals(nlz8))
401401
nlz9 = Normalize(d=3, batch_shape=batch_shape, indices=[0, 1])
402402
nlz10 = Normalize(d=3, batch_shape=batch_shape, indices=[0, 2])
403+
nlz11 = Normalize(d=3, batch_shape=batch_shape, indices=[0, 1, 2])
403404
self.assertFalse(nlz9.equals(nlz10))
405+
self.assertFalse(nlz9.equals(nlz11))
404406

405407
# test with grad
406408
nlz = Normalize(d=1)

0 commit comments

Comments
 (0)