Skip to content

Commit b730a5f

Browse files
docusaurus-botfacebook-github-bot
authored andcommitted
Make "device" attribute of contextual models a property instead (#1611)
Summary: ## Motivation As of cornellius-gp/gpytorch#2234, the parent class of BoTorch kernels now has a property "device." This means that if a subclass tries to set `self.device`, it will error. This is why the BoTorch CI is currently breaking: https://github.com/pytorch/botorch/actions/runs/3841992968/jobs/6542850176 Pull Request resolved: #1611 Test Plan: Tests should pass Reviewed By: saitcakmak, Balandat Differential Revision: D42354199 Pulled By: esantorella fbshipit-source-id: c53e5b508dd75f4116870cd30ab90d11cd3eb573
1 parent 2998cfe commit b730a5f

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

botorch/models/kernels/contextual_lcea.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(
6161
self.decomposition = decomposition
6262
self.batch_shape = batch_shape
6363
self.train_embedding = train_embedding
64-
self.device = device
64+
self._device = device
6565

6666
num_param = len(next(iter(decomposition.values())))
6767
self.context_list = list(decomposition.keys())
@@ -128,6 +128,10 @@ def __init__(
128128
)
129129
self.register_constraint("raw_outputscale_list", Positive())
130130

131+
@property
132+
def device(self) -> Optional[torch.device]:
133+
return self._device
134+
131135
@property
132136
def outputscale_list(self) -> Tensor:
133137
return self.raw_outputscale_list_constraint.transform(self.raw_outputscale_list)

botorch/models/kernels/contextual_sac.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757

5858
super().__init__(batch_shape=batch_shape)
5959
self.decomposition = decomposition
60-
self.device = device
60+
self._device = device
6161

6262
num_param = len(next(iter(decomposition.values())))
6363
for active_parameters in decomposition.values():
@@ -86,6 +86,10 @@ def __init__(
8686
)
8787
self.kernel_dict = ModuleDict(self.kernel_dict)
8888

89+
@property
90+
def device(self) -> Optional[torch.device]:
91+
return self._device
92+
8993
def forward(
9094
self,
9195
x1: Tensor,

0 commit comments

Comments
 (0)