Skip to content

Commit 4e7e903

Browse files
generatedunixname537391475639613meta-codesync[bot]
authored andcommitted
Fix for T244375028 ("An automatically generated diff you reviewed, D86492815, broke one test") (meta-pytorch#3076)
Summary: Pull Request resolved: meta-pytorch#3076 Reviewed By: Balandat Differential Revision: D86578948 fbshipit-source-id: 7289c3a40d4706cde69ed60f0eee0d33ae3fba8d
1 parent 7635818 commit 4e7e903

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

test/models/test_gp_regression_mixed.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,6 @@ def test_condition_on_observations__(self):
201201

202202
# check that fantasies of batched model are correct
203203
if len(batch_shape) > 0 and test_X.dim() == 2:
204-
state_dict_non_batch = {
205-
key: (val[0] if val.ndim > 1 else val)
206-
for key, val in model.state_dict().items()
207-
}
208204
model_kwargs_non_batch = {
209205
"train_X": train_X[0],
210206
"train_Y": train_Y[0],
@@ -213,6 +209,20 @@ def test_condition_on_observations__(self):
213209
if observed_noise:
214210
model_kwargs_non_batch["train_Yvar"] = train_Yvar[0]
215211
model_non_batch = type(model)(**model_kwargs_non_batch)
212+
non_batch_shapes = {
213+
key: val.shape
214+
for key, val in model_non_batch.state_dict().items()
215+
}
216+
state_dict_non_batch = {}
217+
for key, val in model.state_dict().items():
218+
if key in non_batch_shapes:
219+
expected_shape = non_batch_shapes[key]
220+
if val.ndim > len(expected_shape):
221+
state_dict_non_batch[key] = val[0]
222+
else:
223+
state_dict_non_batch[key] = val
224+
else:
225+
state_dict_non_batch[key] = val
216226
model_non_batch.load_state_dict(state_dict_non_batch)
217227
model_non_batch.eval()
218228
model_non_batch.likelihood.eval()

0 commit comments

Comments
 (0)