Skip to content

Commit b7a9763

Browse files
SamuelGabriel and facebook-github-bot
authored and committed
PFN changes to enable PFNs in new MAST FBFlow (#2998)
Summary: To get PFNs to work in the new MAST PFN flow, I changed a few dependencies, added new ones, and fixed device handling. I needed to override `construct_inputs` and got errors from changing its signature. I don't think this should be an issue, so I added an ignore statement. To fully enable PFNs, the next two commits in this stack are also needed. Differential Revision: D80944578
1 parent 55b8911 commit b7a9763

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

botorch_community/models/prior_fitted_network.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ModelPaths,
2828
)
2929
from botorch_community.posteriors.riemann import BoundedRiemannPosterior
30+
from pfns.train import MainConfig # @manual=//pytorch/PFNs:PFNs
3031
from torch import Tensor
3132
from torch.nn import Module
3233

@@ -44,6 +45,7 @@ def __init__(
4445
batch_first: bool = False,
4546
constant_model_kwargs: dict[str, Any] | None = None,
4647
input_transform: InputTransform | None = None,
48+
load_training_checkpoint: bool = False,
4749
) -> None:
4850
"""Initialize a PFNModel.
4951
@@ -71,6 +73,8 @@ def __init__(
7173
constant_model_kwargs: A dictionary of model kwargs that
7274
will be passed to the model in each forward pass.
7375
input_transform: A Botorch input transform.
76+
load_training_checkpoint: Whether to load a training checkpoint as
77+
produced by the PFNs training code, see github.com/automl/PFNs.
7478
7579
"""
7680
super().__init__()
@@ -79,6 +83,15 @@ def __init__(
7983
model_path=checkpoint_url,
8084
)
8185

86+
if load_training_checkpoint:
87+
# the model is not an actual model, but a training checkpoint
88+
# make a model out of it
89+
checkpoint = model
90+
config = MainConfig.from_dict(checkpoint["config"])
91+
model = config.model.create_model()
92+
model.load_state_dict(checkpoint["model_state_dict"])
93+
model.eval()
94+
8295
if train_Yvar is not None:
8396
logger.debug("train_Yvar provided but ignored for PFNModel.")
8497

@@ -113,7 +126,7 @@ def __init__(
113126

114127
self.train_X = train_X # shape: `b x n x d`
115128
self.train_Y = train_Y # shape: `b x n`
116-
self.pfn = model
129+
self.pfn = model.to(train_X.device)
117130
self.batch_first = batch_first
118131
self.constant_model_kwargs = constant_model_kwargs
119132
if input_transform is not None:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ test = [
4242
"pytest-cov",
4343
"requests",
4444
"pymoo",
45+
"pfns"
4546
]
4647

4748
dev = [

test_community/models/test_prior_fitted_network.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
download_model,
2020
ModelPaths,
2121
)
22+
from pfns.model.transformer_config import CrossEntropyConfig, TransformerConfig
23+
from pfns.train import MainConfig, OptimizerConfig
2224
from torch import nn, Tensor
2325

2426

@@ -162,6 +164,42 @@ def test_input_transform(self):
162164
self.assertIsInstance(model.input_transform, Normalize)
163165
self.assertEqual(model.input_transform.bounds.shape, torch.Size([2, 3]))
164166

167+
def test_unpack_checkpoint(self):
168+
config = MainConfig(
169+
priors=[],
170+
optimizer=OptimizerConfig(
171+
optimizer="adam",
172+
lr=0.001,
173+
),
174+
model=TransformerConfig(
175+
criterion=CrossEntropyConfig(num_classes=3),
176+
),
177+
batch_shape_sampler=None,
178+
)
179+
180+
model = config.model.create_model()
181+
182+
checkpoint = {
183+
"config": config.to_dict(),
184+
"model_state_dict": model.state_dict(),
185+
}
186+
187+
loaded_model = PFNModel(
188+
train_X=torch.rand(10, 3),
189+
train_Y=torch.rand(10, 1),
190+
input_transform=Normalize(d=3),
191+
model=checkpoint,
192+
load_training_checkpoint=True,
193+
)
194+
195+
loaded_state_dict = loaded_model.pfn.state_dict()
196+
self.assertEqual(
197+
sorted(loaded_state_dict.keys()),
198+
sorted(model.state_dict().keys()),
199+
)
200+
for k in loaded_state_dict.keys():
201+
self.assertTrue(torch.equal(loaded_state_dict[k], model.state_dict()[k]))
202+
165203

166204
class TestPriorFittedNetworkUtils(BotorchTestCase):
167205
@patch("botorch_community.models.utils.prior_fitted_network.requests.get")
@@ -215,7 +253,7 @@ def test_download_model_cache_miss(
215253
train_X=torch.rand(10, 3),
216254
train_Y=torch.rand(10, 1),
217255
)
218-
self.assertEqual(model.pfn, fake_model)
256+
self.assertEqual(model.pfn, fake_model.to("cpu"))
219257

220258
@patch("botorch_community.models.utils.prior_fitted_network.torch.load")
221259
@patch("botorch_community.models.utils.prior_fitted_network.os.path.exists")

0 commit comments

Comments
 (0)