diff --git a/botorch_community/models/prior_fitted_network.py b/botorch_community/models/prior_fitted_network.py index e7e705eacd..bd2a11e8e8 100644 --- a/botorch_community/models/prior_fitted_network.py +++ b/botorch_community/models/prior_fitted_network.py @@ -27,6 +27,7 @@ ModelPaths, ) from botorch_community.posteriors.riemann import BoundedRiemannPosterior +from pfns.train import MainConfig # @manual=//pytorch/PFNs:PFNs from torch import Tensor from torch.nn import Module @@ -44,6 +45,7 @@ def __init__( batch_first: bool = False, constant_model_kwargs: dict[str, Any] | None = None, input_transform: InputTransform | None = None, + load_training_checkpoint: bool = False, ) -> None: """Initialize a PFNModel. @@ -71,6 +73,8 @@ def __init__( constant_model_kwargs: A dictionary of model kwargs that will be passed to the model in each forward pass. input_transform: A Botorch input transform. + load_training_checkpoint: Whether to load a training checkpoint as + produced by the PFNs training code, see github.com/automl/PFNs. """ super().__init__() @@ -79,6 +83,15 @@ def __init__( model_path=checkpoint_url, ) + if load_training_checkpoint: + # the model is not an actual model, but a training checkpoint + # make a model out of it + checkpoint = model + config = MainConfig.from_dict(checkpoint["config"]) + model = config.model.create_model() + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + if train_Yvar is not None: logger.debug("train_Yvar provided but ignored for PFNModel.") @@ -113,7 +126,7 @@ def __init__( self.train_X = train_X # shape: `b x n x d` self.train_Y = train_Y # shape: `b x n` - self.pfn = model + self.pfn = model.to(train_X.device) self.batch_first = batch_first self.constant_model_kwargs = constant_model_kwargs if input_transform is not None: diff --git a/pyproject.toml b/pyproject.toml index fc636e4658..80db3d85bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ test = [ "pytest-cov", "requests", "pymoo", + "pfns" ] dev = [ diff --git a/test_community/models/test_prior_fitted_network.py b/test_community/models/test_prior_fitted_network.py index 74df2ed8af..74c6229af5 100644 --- a/test_community/models/test_prior_fitted_network.py +++ b/test_community/models/test_prior_fitted_network.py @@ -19,6 +19,8 @@ download_model, ModelPaths, ) +from pfns.model.transformer_config import CrossEntropyConfig, TransformerConfig +from pfns.train import MainConfig, OptimizerConfig from torch import nn, Tensor @@ -162,6 +164,43 @@ def test_input_transform(self): self.assertIsInstance(model.input_transform, Normalize) self.assertEqual(model.input_transform.bounds.shape, torch.Size([2, 3])) + def test_unpack_checkpoint(self): + config = MainConfig( + priors=[], + optimizer=OptimizerConfig( + optimizer="adam", + lr=0.001, + ), + model=TransformerConfig( + criterion=CrossEntropyConfig(num_classes=3), + ), + batch_shape_sampler=None, + ) + + model = config.model.create_model() + + state_dict = model.state_dict() + checkpoint = { + "config": config.to_dict(), + "model_state_dict": state_dict, + } + + loaded_model = PFNModel( + train_X=torch.rand(10, 3), + train_Y=torch.rand(10, 1), + input_transform=Normalize(d=3), + model=checkpoint, + load_training_checkpoint=True, + ) + + loaded_state_dict = loaded_model.pfn.state_dict() + self.assertEqual( + sorted(loaded_state_dict.keys()), + sorted(state_dict.keys()), + ) + for k in loaded_state_dict.keys(): + self.assertTrue(torch.equal(loaded_state_dict[k], state_dict[k])) + class TestPriorFittedNetworkUtils(BotorchTestCase): @patch("botorch_community.models.utils.prior_fitted_network.requests.get") @@ -215,7 +254,7 @@ def test_download_model_cache_miss( train_X=torch.rand(10, 3), train_Y=torch.rand(10, 1), ) - self.assertEqual(model.pfn, fake_model) + self.assertEqual(model.pfn, fake_model.to("cpu")) @patch("botorch_community.models.utils.prior_fitted_network.torch.load") @patch("botorch_community.models.utils.prior_fitted_network.os.path.exists")