Skip to content

Commit ca83ff5

Browse files
sdaultonfacebook-github-bot
authored andcommitted
use no_grad when using input transform in model.__init__ (#610)
Summary: Pull Request resolved: #610 See title. Reviewed By: Balandat Differential Revision: D25158547 fbshipit-source-id: 393c5b6ba45909c13cf876893d28ca9bbcd7e83e
1 parent 4cb3587 commit ca83ff5

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

botorch/models/gp_regression.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ def __init__(
9292
"""
9393
if input_transform is not None:
9494
input_transform.to(train_X)
95-
transformed_X = self.transform_inputs(
96-
X=train_X, input_transform=input_transform
97-
)
95+
with torch.no_grad():
96+
transformed_X = self.transform_inputs(
97+
X=train_X, input_transform=input_transform
98+
)
9899
if outcome_transform is not None:
99100
train_Y, _ = outcome_transform(train_Y)
100101
self._validate_tensor_args(X=transformed_X, Y=train_Y)
@@ -206,9 +207,10 @@ def __init__(
206207
"""
207208
if input_transform is not None:
208209
input_transform.to(train_X)
209-
transformed_X = self.transform_inputs(
210-
X=train_X, input_transform=input_transform
211-
)
210+
with torch.no_grad():
211+
transformed_X = self.transform_inputs(
212+
X=train_X, input_transform=input_transform
213+
)
212214
if outcome_transform is not None:
213215
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
214216
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)

botorch/models/gp_regression_fidelity.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ def __init__(
9797
)
9898
if input_transform is not None:
9999
input_transform.to(train_X)
100-
transformed_X = self.transform_inputs(
101-
X=train_X, input_transform=input_transform
102-
)
100+
with torch.no_grad():
101+
transformed_X = self.transform_inputs(
102+
X=train_X, input_transform=input_transform
103+
)
103104

104105
self._set_dimensions(train_X=transformed_X, train_Y=train_Y)
105106
covar_module, subset_batch_dict = _setup_multifidelity_covar_module(
@@ -208,9 +209,10 @@ def __init__(
208209
)
209210
if input_transform is not None:
210211
input_transform.to(train_X)
211-
transformed_X = self.transform_inputs(
212-
X=train_X, input_transform=input_transform
213-
)
212+
with torch.no_grad():
213+
transformed_X = self.transform_inputs(
214+
X=train_X, input_transform=input_transform
215+
)
214216
self._set_dimensions(train_X=transformed_X, train_Y=train_Y)
215217
covar_module, subset_batch_dict = _setup_multifidelity_covar_module(
216218
dim=transformed_X.size(-1),

botorch/models/multitask.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,10 @@ def __init__(
8383
"""
8484
if input_transform is not None:
8585
input_transform.to(train_X)
86-
transformed_X = self.transform_inputs(
87-
X=train_X, input_transform=input_transform
88-
)
86+
with torch.no_grad():
87+
transformed_X = self.transform_inputs(
88+
X=train_X, input_transform=input_transform
89+
)
8990
self._validate_tensor_args(X=transformed_X, Y=train_Y)
9091
all_tasks, task_feature, d = self.get_all_tasks(
9192
transformed_X, task_feature, output_tasks
@@ -292,9 +293,10 @@ def __init__(
292293
"""
293294
if input_transform is not None:
294295
input_transform.to(train_X)
295-
transformed_X = self.transform_inputs(
296-
X=train_X, input_transform=input_transform
297-
)
296+
with torch.no_grad():
297+
transformed_X = self.transform_inputs(
298+
X=train_X, input_transform=input_transform
299+
)
298300
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
299301
# We'll instatiate a MultiTaskGP and simply override the likelihood
300302
super().__init__(

0 commit comments

Comments
 (0)