Skip to content

Commit bc39aa1

Browse files
esantorella authored and facebook-github-bot committed
Let Pyre know that AcquisitionFunction.model is a Model (#1216)
Summary: X-link: facebook/Ax#1216 ## Motivation Pyre is not smart enough to understand that calling `self.add_module('model', model)` makes `self.model` have the type of `model`, which is true due to some fairly complex underlying logic inherited from `torch.nn.Module`. However, PyTorch is smart enough to properly `add_module` if we just do `self.model = model`. This also works for tensors, but only if the tensor is explicitly registered as a buffer (by name, not necessarily by value) before assignment. ### Have you read the [Contributing Guidelines on pull requests]? Yes Pull Request resolved: #1452 Test Plan: - Unit tests should be unaffected - Pyre error count drops from 1379 to 1309 (down 5%). - Added explicit tests that `_modules` and `_buffers` are properly initialized Reviewed By: Balandat Differential Revision: D40469725 Pulled By: esantorella fbshipit-source-id: 531cec5b77fc74faf478c4c96f1ceaa596ca8162
1 parent 8441d62 commit bc39aa1

File tree

9 files changed

+47
-14
lines changed

9 files changed

+47
-14
lines changed

botorch/acquisition/acquisition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, model: Model) -> None:
3636
model: A fitted model.
3737
"""
3838
super().__init__()
39-
self.add_module("model", model)
39+
self.model: Model = model
4040

4141
@classmethod
4242
def _deprecate_acqf_objective(

botorch/acquisition/analytic.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -462,13 +462,22 @@ def _preprocess_constraint_bounds(
462462
con_upper_inds.append(k)
463463
con_upper.append(constraints[k][1])
464464
# tensor-based indexing is much faster than list-based advanced indexing
465-
self.register_buffer("con_lower_inds", torch.tensor(con_lower_inds))
466-
self.register_buffer("con_upper_inds", torch.tensor(con_upper_inds))
467-
self.register_buffer("con_both_inds", torch.tensor(con_both_inds))
468-
# tensor indexing
469-
self.register_buffer("con_both", torch.tensor(con_both, dtype=torch.float))
470-
self.register_buffer("con_lower", torch.tensor(con_lower, dtype=torch.float))
471-
self.register_buffer("con_upper", torch.tensor(con_upper, dtype=torch.float))
465+
for k in [
466+
"con_lower_inds",
467+
"con_upper_inds",
468+
"con_both_inds",
469+
"con_both",
470+
"con_lower",
471+
"con_upper",
472+
]:
473+
self.register_buffer(k, tensor=None)
474+
475+
self.con_lower_inds = torch.tensor(con_lower_inds)
476+
self.con_upper_inds = torch.tensor(con_upper_inds)
477+
self.con_both_inds = torch.tensor(con_both_inds)
478+
self.con_both = torch.tensor(con_both)
479+
self.con_lower = torch.tensor(con_lower)
480+
self.con_upper = torch.tensor(con_upper)
472481

473482
def _compute_prob_feas(self, X: Tensor, means: Tensor, sigmas: Tensor) -> Tensor:
474483
r"""Compute feasibility probability for each batch of X.

botorch/acquisition/knowledge_gradient.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,13 @@ def __init__(
150150
"If using a multi-output model without an objective, "
151151
"posterior_transform must scalarize the output."
152152
)
153-
self.sampler = sampler
153+
self.sampler: MCSampler = sampler
154154
self.objective = objective
155155
self.posterior_transform = posterior_transform
156156
self.set_X_pending(X_pending)
157+
self.X_pending: Tensor = self.X_pending
157158
self.inner_sampler = inner_sampler
158-
self.num_fantasies = num_fantasies
159+
self.num_fantasies: int = num_fantasies
159160
self.current_value = current_value
160161

161162
@t_batch_mode_transform()
@@ -338,7 +339,7 @@ def __init__(
338339
project: Callable[[Tensor], Tensor] = lambda X: X,
339340
expand: Callable[[Tensor], Tensor] = lambda X: X,
340341
valfunc_cls: Optional[Type[AcquisitionFunction]] = None,
341-
valfunc_argfac: Optional[Callable[[Model, Dict[str, Any]]]] = None,
342+
valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None,
342343
**kwargs: Any,
343344
) -> None:
344345
r"""Multi-Fidelity q-Knowledge Gradient (one-shot optimization).
@@ -529,7 +530,7 @@ def _get_value_function(
529530
sampler: Optional[MCSampler] = None,
530531
project: Optional[Callable[[Tensor], Tensor]] = None,
531532
valfunc_cls: Optional[Type[AcquisitionFunction]] = None,
532-
valfunc_argfac: Optional[Callable[[Model, Dict[str, Any]]]] = None,
533+
valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None,
533534
) -> AcquisitionFunction:
534535
r"""Construct value function (i.e. inner acquisition function)."""
535536
if valfunc_cls is not None:

botorch/acquisition/monte_carlo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(
7676
super().__init__(model=model)
7777
if sampler is None:
7878
sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
79-
self.add_module("sampler", sampler)
79+
self.sampler: MCSampler = sampler
8080
if objective is None and model.num_outputs != 1:
8181
if posterior_transform is None:
8282
raise UnsupportedError(
@@ -91,7 +91,7 @@ def __init__(
9191
if objective is None:
9292
objective = IdentityMCObjective()
9393
self.posterior_transform = posterior_transform
94-
self.add_module("objective", objective)
94+
self.objective: MCAcquisitionObjective = objective
9595
self.set_X_pending(X_pending)
9696

9797
@abstractmethod

test/acquisition/multi_objective/test_monte_carlo.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def test_q_expected_hypervolume_improvement(self):
159159
samples2 = torch.zeros(1, 2, 2, **tkwargs)
160160
mm2 = MockModel(MockPosterior(samples=samples2))
161161
acqf.model = mm2
162+
self.assertEqual(acqf.model, mm2)
163+
self.assertIn("model", acqf._modules)
164+
self.assertEqual(acqf._modules["model"], mm2)
162165
res = acqf(X2)
163166
self.assertEqual(res.item(), 0.0)
164167
# check cached indices

test/acquisition/multi_objective/test_multi_fidelity.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def test_momf(self):
7373
samples2 = torch.zeros(1, 2, 2, **tkwargs)
7474
mm2 = MockModel(MockPosterior(samples=samples2))
7575
acqf.model = mm2
76+
self.assertEqual(acqf.model, mm2)
7677
res = acqf(X2)
7778
self.assertEqual(res.item(), 0.0)
7879
# check cached indices

test/acquisition/test_analytic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,17 @@ def test_constrained_expected_improvement(self):
344344
module = ConstrainedExpectedImprovement(
345345
model=mm, best_f=0.0, objective_index=0, constraints={1: [None, 0]}
346346
)
347+
# test initialization
348+
for k in [
349+
"con_lower_inds",
350+
"con_upper_inds",
351+
"con_both_inds",
352+
"con_both",
353+
"con_lower",
354+
"con_upper",
355+
]:
356+
self.assertIn(k, module._buffers)
357+
347358
X = torch.empty(1, 1, device=self.device, dtype=dtype) # dummy
348359
ei = module(X)
349360
ei_expected_unconstrained = torch.tensor(

test/acquisition/test_knowledge_gradient.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,10 @@ def test_get_value_function(self):
610610
mm = MockModel(None)
611611
# test PosteriorMean
612612
vf = _get_value_function(mm)
613+
# test initialization
614+
self.assertIn("model", vf._modules)
615+
self.assertEqual(vf._modules["model"], mm)
616+
613617
self.assertIsInstance(vf, PosteriorMean)
614618
self.assertIsNone(vf.posterior_transform)
615619
# test SimpleRegret

test/acquisition/test_monte_carlo.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def test_q_expected_improvement(self):
8383
# basic test
8484
sampler = IIDNormalSampler(num_samples=2)
8585
acqf = qExpectedImprovement(model=mm, best_f=0, sampler=sampler)
86+
# test initialization
87+
for k in ["objective", "sampler"]:
88+
self.assertIn(k, acqf._modules)
89+
8690
res = acqf(X)
8791
self.assertEqual(res.item(), 0.0)
8892

0 commit comments

Comments
 (0)