Skip to content

Commit c0fa168

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Storage support for RBFKernel & LogNormalPrior (facebook#2616)
Summary: Pull Request resolved: facebook#2616 Adds storage support for these components that'll soon be used by default. The storage support is only relevant in cases where they are manually specified. I also removed `REVERSE_GPYTORCH_COMPONENT_REGISTRY` since I did not find it to affect anything. We need `GPYTORCH_COMPONENT_REGISTRY` but the reverse does not seem to do much. Reviewed By: mpolson64 Differential Revision: D60458209 fbshipit-source-id: a8f8cd826656d497a35d60fa54e3fc901b88481d
1 parent 8587b30 commit c0fa168

File tree

5 files changed

+57
-36
lines changed

5 files changed

+57
-36
lines changed

ax/storage/botorch_modular_registry.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
# Miscellaneous BoTorch imports
8080
from gpytorch.constraints import Interval
8181
from gpytorch.kernels.kernel import Kernel
82+
from gpytorch.kernels.rbf_kernel import RBFKernel
8283
from gpytorch.likelihoods.gaussian_likelihood import GaussianLikelihood
8384
from gpytorch.likelihoods.likelihood import Likelihood
8485

@@ -156,6 +157,7 @@
156157

157158
KERNEL_REGISTRY: Dict[Type[Kernel], str] = {
158159
ScaleMaternKernel: "ScaleMaternKernel",
160+
RBFKernel: "RBFKernel",
159161
}
160162

161163
LIKELIHOOD_REGISTRY: Dict[Type[GaussianLikelihood], str] = {
@@ -200,6 +202,7 @@
200202
Model: MODEL_REGISTRY,
201203
Interval: GPYTORCH_COMPONENT_REGISTRY,
202204
GammaPrior: GPYTORCH_COMPONENT_REGISTRY,
205+
LogNormalPrior: GPYTORCH_COMPONENT_REGISTRY,
203206
InputTransform: INPUT_TRANSFORM_REGISTRY,
204207
OutcomeTransform: OUTCOME_TRANSFORM_REGISTRY,
205208
}
@@ -239,12 +242,6 @@
239242
v: k for k, v in LIKELIHOOD_REGISTRY.items()
240243
}
241244

242-
243-
REVERSE_GPYTORCH_COMPONENT_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
244-
v: k for k, v in GPYTORCH_COMPONENT_REGISTRY.items()
245-
}
246-
247-
248245
REVERSE_INPUT_TRANSFORM_REGISTRY: Dict[str, Type[InputTransform]] = {
249246
v: k for k, v in INPUT_TRANSFORM_REGISTRY.items()
250247
}
@@ -264,8 +261,6 @@
264261
Likelihood: REVERSE_LIKELIHOOD_REGISTRY,
265262
MarginalLogLikelihood: REVERSE_MLL_REGISTRY,
266263
Model: REVERSE_MODEL_REGISTRY,
267-
Interval: REVERSE_GPYTORCH_COMPONENT_REGISTRY,
268-
GammaPrior: REVERSE_GPYTORCH_COMPONENT_REGISTRY,
269264
InputTransform: REVERSE_INPUT_TRANSFORM_REGISTRY,
270265
OutcomeTransform: REVERSE_OUTCOME_TRANSFORM_REGISTRY,
271266
}

ax/storage/json_store/decoders.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
4242
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
4343
from botorch.utils.types import _DefaultType, DEFAULT
44+
from torch.distributions.transformed_distribution import TransformedDistribution
4445

4546
logger: logging.Logger = get_logger(__name__)
4647

@@ -279,6 +280,11 @@ def botorch_component_from_json(botorch_class: Any, json: Dict[str, Any]) -> Typ
279280
for k, v in state_dict.items()
280281
}
281282
)
283+
if issubclass(botorch_class, TransformedDistribution):
284+
# Extract the transformed attributes for transformed priors.
285+
for k in list(state_dict.keys()):
286+
if k.startswith("_transformed_"):
287+
state_dict[k[13:]] = state_dict.pop(k)
282288
class_path = json.pop("class")
283289
init_args = inspect.signature(botorch_class).parameters
284290
required_args = {

ax/storage/json_store/registry.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@
171171
from gpytorch.constraints import Interval
172172
from gpytorch.likelihoods.likelihood import Likelihood
173173
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
174-
from gpytorch.priors.torch_priors import GammaPrior
174+
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior
175175

176176

177177
# pyre-fixme[5]: Global annotation cannot contain `Any`.
@@ -206,6 +206,7 @@
206206
Interval: botorch_component_to_dict,
207207
JenattonMetric: metric_to_dict,
208208
L2NormMetric: metric_to_dict,
209+
LogNormalPrior: botorch_component_to_dict,
209210
MapData: map_data_to_dict,
210211
MapKeyInfo: map_key_info_to_dict,
211212
MapMetric: metric_to_dict,
@@ -327,6 +328,7 @@
327328
"LifecycleStage": LifecycleStage,
328329
"ListSurrogate": Surrogate, # For backwards compatibility
329330
"L2NormMetric": L2NormMetric,
331+
"LogNormalPrior": LogNormalPrior,
330332
"MapData": MapData,
331333
"MapMetric": MapMetric,
332334
"MapKeyInfo": MapKeyInfo,

ax/storage/json_store/tests/test_json_store.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
get_sum_constraint2,
115115
get_surrogate,
116116
get_surrogate_spec_with_default,
117+
get_surrogate_spec_with_lognormal,
117118
get_synthetic_runner,
118119
get_threshold_early_stopping_strategy,
119120
get_trial,
@@ -238,6 +239,7 @@
238239
class JSONStoreTest(TestCase):
239240
def setUp(self) -> None:
240241
super().setUp()
242+
self.maxDiff = None
241243
self.experiment = get_experiment_with_batch_and_single_trial()
242244

243245
def test_JSONEncodeFailure(self) -> None:
@@ -526,32 +528,35 @@ def test_EncodeDecodeSet(self) -> None:
526528
def test_encode_decode_surrogate_spec(self) -> None:
527529
# Test SurrogateSpec separately since the GPyTorch components
528530
# fail simple equality checks.
529-
org_object = get_surrogate_spec_with_default()
530-
converted_object = object_from_json(
531-
object_to_json(
532-
org_object,
533-
encoder_registry=CORE_ENCODER_REGISTRY,
534-
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
535-
),
536-
decoder_registry=CORE_DECODER_REGISTRY,
537-
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
538-
)
539-
org_as_dict = dataclasses.asdict(org_object)
540-
converted_as_dict = dataclasses.asdict(converted_object)
541-
# Covar module kwargs will fail comparison. Manually compare.
542-
org_covar_kwargs = org_as_dict.pop("covar_module_kwargs")
543-
converted_covar_kwargs = converted_as_dict.pop("covar_module_kwargs")
544-
self.assertEqual(org_covar_kwargs.keys(), converted_covar_kwargs.keys())
545-
for k in org_covar_kwargs:
546-
org_ = org_covar_kwargs[k]
547-
converted_ = converted_covar_kwargs[k]
548-
if isinstance(org_, torch.nn.Module):
549-
self.assertEqual(org_.__class__, converted_.__class__)
550-
self.assertEqual(org_.__dict__, converted_.__dict__)
551-
else:
552-
self.assertEqual(org_, converted_)
553-
# Compare the rest.
554-
self.assertEqual(org_as_dict, converted_as_dict)
531+
for org_object in (
532+
get_surrogate_spec_with_default(),
533+
get_surrogate_spec_with_lognormal(),
534+
):
535+
converted_object = object_from_json(
536+
object_to_json(
537+
org_object,
538+
encoder_registry=CORE_ENCODER_REGISTRY,
539+
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
540+
),
541+
decoder_registry=CORE_DECODER_REGISTRY,
542+
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
543+
)
544+
org_as_dict = dataclasses.asdict(org_object)
545+
converted_as_dict = dataclasses.asdict(converted_object)
546+
# Covar module kwargs will fail comparison. Manually compare.
547+
org_covar_kwargs = org_as_dict.pop("covar_module_kwargs")
548+
converted_covar_kwargs = converted_as_dict.pop("covar_module_kwargs")
549+
self.assertEqual(org_covar_kwargs.keys(), converted_covar_kwargs.keys())
550+
for k in org_covar_kwargs:
551+
org_ = org_covar_kwargs[k]
552+
converted_ = converted_covar_kwargs[k]
553+
if isinstance(org_, torch.nn.Module):
554+
self.assertEqual(org_.__class__, converted_.__class__)
555+
self.assertEqual(org_.state_dict(), converted_.state_dict())
556+
else:
557+
self.assertEqual(org_, converted_)
558+
# Compare the rest.
559+
self.assertEqual(org_as_dict, converted_as_dict)
555560

556561
def test_RegistryAdditions(self) -> None:
557562
class MyRunner(Runner):

ax/utils/testing/core_stubs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,10 @@
122122
from botorch.utils.datasets import SupervisedDataset
123123
from botorch.utils.types import DEFAULT
124124
from gpytorch.constraints import Interval
125+
from gpytorch.kernels.rbf_kernel import RBFKernel
125126
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
126127
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
127-
from gpytorch.priors.torch_priors import GammaPrior
128+
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior
128129

129130
logger: Logger = get_logger(__name__)
130131

@@ -2251,6 +2252,18 @@ def get_surrogate_spec_with_default() -> SurrogateSpec:
22512252
)
22522253

22532254

2255+
def get_surrogate_spec_with_lognormal() -> SurrogateSpec:
2256+
return SurrogateSpec(
2257+
botorch_model_class=SingleTaskGP,
2258+
covar_module_class=RBFKernel,
2259+
covar_module_kwargs={
2260+
"ard_num_dims": DEFAULT,
2261+
"lengthscale_prior": LogNormalPrior(-4.0, 1.0),
2262+
"batch_shape": DEFAULT,
2263+
},
2264+
)
2265+
2266+
22542267
def get_acquisition_type() -> Type[Acquisition]:
22552268
return Acquisition
22562269

0 commit comments

Comments
 (0)