Skip to content

Commit b83c835

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add storage support for SobolQMCNormalSampler (facebook#2622)
Summary: Pull Request resolved: facebook#2622 Reviewed By: mgarrard Differential Revision: D60432564 fbshipit-source-id: 56026fe2b327f44df2dda5c8518c4943286ff9ad
1 parent 7e6a7db commit b83c835

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

ax/storage/botorch_modular_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
OutcomeTransform,
7676
Standardize,
7777
)
78+
from botorch.sampling.normal import SobolQMCNormalSampler
7879

7980
# Miscellaneous BoTorch imports
8081
from gpytorch.constraints import Interval
@@ -168,6 +169,7 @@
168169
Interval: "Interval",
169170
GammaPrior: "GammaPrior",
170171
LogNormalPrior: "LogNormalPrior",
172+
SobolQMCNormalSampler: "SobolQMCNormalSampler",
171173
}
172174

173175
"""
@@ -205,6 +207,7 @@
205207
LogNormalPrior: GPYTORCH_COMPONENT_REGISTRY,
206208
InputTransform: INPUT_TRANSFORM_REGISTRY,
207209
OutcomeTransform: OUTCOME_TRANSFORM_REGISTRY,
210+
SobolQMCNormalSampler: GPYTORCH_COMPONENT_REGISTRY,
208211
}
209212

210213

ax/storage/json_store/encoders.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from ax.utils.common.typeutils import not_none
6464
from ax.utils.common.typeutils_torch import torch_type_to_str
6565
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
66+
from botorch.sampling.base import MCSampler
6667
from botorch.utils.types import _DefaultType
6768
from torch import Tensor
6869

@@ -593,6 +594,11 @@ def botorch_component_to_dict(input_obj: Any) -> Dict[str, Any]:
593594
state_dict = botorch_input_transform_to_init_args(input_transform=input_obj)
594595
else:
595596
state_dict = dict(input_obj.state_dict())
597+
if isinstance(input_obj, MCSampler):
598+
# The sampler args are not part of the state dict. Manually add them.
599+
# Sample shape cannot be added to torch state dict since it is not a tensor.
600+
state_dict["sample_shape"] = input_obj.sample_shape
601+
state_dict["seed"] = input_obj.seed
596602
return {
597603
"__type": f"{class_type.__name__}",
598604
"index": class_type,

ax/storage/json_store/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@
167167
from botorch.acquisition.acquisition import AcquisitionFunction
168168
from botorch.models.model import Model
169169
from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round
170+
from botorch.sampling.normal import SobolQMCNormalSampler
170171
from botorch.utils.types import DEFAULT
171172
from gpytorch.constraints import Interval
172173
from gpytorch.likelihoods.likelihood import Likelihood
@@ -252,6 +253,7 @@
252253
SearchSpace: search_space_to_dict,
253254
SingleDiagnosticBestModelSelector: best_model_selector_to_dict,
254255
HierarchicalSearchSpace: search_space_to_dict,
256+
SobolQMCNormalSampler: botorch_component_to_dict,
255257
SumConstraint: sum_parameter_constraint_to_dict,
256258
Surrogate: surrogate_to_dict,
257259
BenchmarkMetric: metric_to_dict,
@@ -385,6 +387,7 @@
385387
"SurrogateMetric": BenchmarkMetric, # backward-compatiblity
386388
# NOTE: SurrogateRunners -> SyntheticRunner on load due to complications
387389
"SurrogateRunner": SyntheticRunner,
390+
"SobolQMCNormalSampler": SobolQMCNormalSampler,
388391
"SyntheticRunner": SyntheticRunner,
389392
"SurrogateSpec": SurrogateSpec,
390393
"Trial": Trial,

ax/storage/json_store/tests/test_json_store.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
sobol_gpei_generation_node_gs,
131131
)
132132
from ax.utils.testing.utils import generic_equals
133+
from botorch.sampling.normal import SobolQMCNormalSampler
133134

134135

135136
# pyre-fixme[5]: Global expression must be annotated.
@@ -731,3 +732,20 @@ def test_generation_step_backwards_compatibility(self) -> None:
731732
generation_step = object_from_json(json)
732733
self.assertIsInstance(generation_step, GenerationStep)
733734
self.assertEqual(generation_step.model_kwargs, {"other_kwarg": 5})
735+
736+
def test_SobolQMCNormalSampler(self) -> None:
737+
# This fails default equality checks, so testing it separately.
738+
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
739+
sampler_json = object_to_json(
740+
sampler,
741+
encoder_registry=CORE_ENCODER_REGISTRY,
742+
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
743+
)
744+
sampler_loaded = object_from_json(
745+
sampler_json,
746+
decoder_registry=CORE_DECODER_REGISTRY,
747+
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
748+
)
749+
self.assertIsInstance(sampler_loaded, SobolQMCNormalSampler)
750+
self.assertEqual(sampler.sample_shape, sampler_loaded.sample_shape)
751+
self.assertEqual(sampler.seed, sampler_loaded.seed)

0 commit comments

Comments
 (0)