Skip to content

Commit f5646cb

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Remove model.train call from get_X_baseline (meta-pytorch#1289)
Summary: Pull Request resolved: meta-pytorch#1289 Removes the `model.train()` call from `get_X_baseline` and uses `_has_transformed_inputs` and `_original_train_inputs` attributes instead to get the original train inputs. Why: `model.train()` destroys caches on the model. For `SaasFullyBayesianSingleTaskGP` it completely destroys the model. Reviewed By: dme65 Differential Revision: D37692137 fbshipit-source-id: d6360a6325b78e72a0e5e729ef958beb1754952d
1 parent a675968 commit f5646cb

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

botorch/optim/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,12 @@ def get_X_baseline(acq_function: AcquisitionFunction) -> Optional[Tensor]:
276276
warnings.warn("Failed to extract X_baseline.", BotorchWarning)
277277
return
278278
try:
279-
# make sure input transforms are not applied
280-
model.train()
281-
if isinstance(model, ModelListGPyTorchModel):
282-
X = model.models[0].train_inputs[0]
279+
# Make sure we get the original train inputs.
280+
m = model.models[0] if isinstance(model, ModelListGPyTorchModel) else model
281+
if m._has_transformed_inputs:
282+
X = m._original_train_inputs
283283
else:
284-
X = model.train_inputs[0]
284+
X = m.train_inputs[0]
285285
except (BotorchError, AttributeError):
286286
warnings.warn("Failed to extract X_baseline.", BotorchWarning)
287287
return

0 commit comments

Comments
 (0)