Skip to content

Commit 08bfa91

Browse files
SamuelGabrielfacebook-github-bot
authored andcommitted
Enable evaluating PFNs trained with pytorch/PFNs to be evaluated in Ax
Summary: This PR enables PFNs generally to work in our MAST evaluation, as the registry didn't quite work before, and it additionally allows to use training checkpoints from pytorch/PFNs to be used to do evaluations. Reviewed By: Balandat Differential Revision: D80944578
1 parent 4170f58 commit 08bfa91

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

test_community/models/test_prior_fitted_network.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,10 @@ def test_unpack_checkpoint(self):
179179

180180
model = config.model.create_model()
181181

182+
state_dict = model.state_dict()
182183
checkpoint = {
183184
"config": config.to_dict(),
184-
"model_state_dict": model.state_dict(),
185+
"model_state_dict": state_dict,
185186
}
186187

187188
loaded_model = PFNModel(
@@ -195,10 +196,10 @@ def test_unpack_checkpoint(self):
195196
loaded_state_dict = loaded_model.pfn.state_dict()
196197
self.assertEqual(
197198
sorted(loaded_state_dict.keys()),
198-
sorted(model.state_dict().keys()),
199+
sorted(state_dict.keys()),
199200
)
200201
for k in loaded_state_dict.keys():
201-
self.assertTrue(torch.equal(loaded_state_dict[k], model.state_dict()[k]))
202+
self.assertTrue(torch.equal(loaded_state_dict[k], state_dict[k]))
202203

203204

204205
class TestPriorFittedNetworkUtils(BotorchTestCase):

0 commit comments

Comments
 (0)