Skip to content

Commit f0856d0

Browse files
authored
Additions for Fully Bayesian Models (#606)
* add map_saas * implement new fully bayesian models * make tests compatible to botorch release * fix test * fix exception * fix tests * fix tests
1 parent aea2f61 commit f0856d0

File tree

13 files changed

+347
-52
lines changed

13 files changed

+347
-52
lines changed

bofire/data_models/surrogates/api.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
LinearDeterministicSurrogate,
1212
)
1313
from bofire.data_models.surrogates.empirical import EmpiricalSurrogate
14-
from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate
14+
from bofire.data_models.surrogates.fully_bayesian import (
15+
FullyBayesianSingleTaskGPSurrogate,
16+
)
1517
from bofire.data_models.surrogates.linear import LinearSurrogate
18+
from bofire.data_models.surrogates.map_saas import AdditiveMapSaasSingleTaskGPSurrogate
1619
from bofire.data_models.surrogates.mixed_single_task_gp import (
1720
MixedSingleTaskGPHyperconfig,
1821
MixedSingleTaskGPSurrogate,
@@ -56,7 +59,7 @@
5659
MixedTanimotoGPSurrogate,
5760
ClassificationMLPEnsemble,
5861
RegressionMLPEnsemble,
59-
SaasSingleTaskGPSurrogate,
62+
FullyBayesianSingleTaskGPSurrogate,
6063
XGBoostSurrogate,
6164
LinearSurrogate,
6265
PolynomialSurrogate,
@@ -66,6 +69,7 @@
6669
MultiTaskGPSurrogate,
6770
SingleTaskIBNNSurrogate,
6871
PiecewiseLinearGPSurrogate,
72+
AdditiveMapSaasSingleTaskGPSurrogate,
6973
]
7074

7175
AnyTrainableSurrogate = Union[
@@ -76,13 +80,14 @@
7680
MixedTanimotoGPSurrogate,
7781
ClassificationMLPEnsemble,
7882
RegressionMLPEnsemble,
79-
SaasSingleTaskGPSurrogate,
83+
FullyBayesianSingleTaskGPSurrogate,
8084
XGBoostSurrogate,
8185
LinearSurrogate,
8286
PolynomialSurrogate,
8387
SingleTaskIBNNSurrogate,
8488
TanimotoGPSurrogate,
8589
PiecewiseLinearGPSurrogate,
90+
AdditiveMapSaasSingleTaskGPSurrogate,
8691
]
8792

8893
AnyRegressionSurrogate = Union[
@@ -93,7 +98,7 @@
9398
MixedSingleTaskGPSurrogate,
9499
MixedTanimotoGPSurrogate,
95100
RegressionMLPEnsemble,
96-
SaasSingleTaskGPSurrogate,
101+
FullyBayesianSingleTaskGPSurrogate,
97102
XGBoostSurrogate,
98103
LinearSurrogate,
99104
PolynomialSurrogate,
@@ -102,6 +107,7 @@
102107
MultiTaskGPSurrogate,
103108
SingleTaskIBNNSurrogate,
104109
PiecewiseLinearGPSurrogate,
110+
AdditiveMapSaasSingleTaskGPSurrogate,
105111
]
106112

107113
AnyClassificationSurrogate = ClassificationMLPEnsemble

bofire/data_models/surrogates/botorch_surrogates.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
LinearDeterministicSurrogate,
1111
)
1212
from bofire.data_models.surrogates.empirical import EmpiricalSurrogate
13-
from bofire.data_models.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate
13+
from bofire.data_models.surrogates.fully_bayesian import (
14+
FullyBayesianSingleTaskGPSurrogate,
15+
)
1416
from bofire.data_models.surrogates.linear import LinearSurrogate
1517
from bofire.data_models.surrogates.mixed_single_task_gp import (
1618
MixedSingleTaskGPSurrogate,
@@ -37,7 +39,7 @@
3739
MixedTanimotoGPSurrogate,
3840
RegressionMLPEnsemble,
3941
ClassificationMLPEnsemble,
40-
SaasSingleTaskGPSurrogate,
42+
FullyBayesianSingleTaskGPSurrogate,
4143
TanimotoGPSurrogate,
4244
LinearSurrogate,
4345
PolynomialSurrogate,

bofire/data_models/surrogates/fully_bayesian.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
1-
from typing import Annotated, Literal, Type
1+
from typing import Annotated, List, Literal, Type
22

3-
from pydantic import Field, field_validator
3+
from pydantic import AfterValidator, Field, field_validator, model_validator
44

55
from bofire.data_models.features.api import AnyOutput, ContinuousOutput
66
from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate
7+
from bofire.data_models.types import make_unique_validator
78

89

9-
class SaasSingleTaskGPSurrogate(TrainableBotorchSurrogate):
10-
type: Literal["SaasSingleTaskGPSurrogate"] = "SaasSingleTaskGPSurrogate"
10+
class FullyBayesianSingleTaskGPSurrogate(TrainableBotorchSurrogate):
11+
type: Literal["FullyBayesianSingleTaskGPSurrogate"] = (
12+
"FullyBayesianSingleTaskGPSurrogate"
13+
)
14+
model_type: Literal["linear", "saas", "hvarfner"] = "saas"
1115
warmup_steps: Annotated[int, Field(ge=1)] = 256
1216
num_samples: Annotated[int, Field(ge=1)] = 128
1317
thinning: Annotated[int, Field(ge=1)] = 16
18+
features_to_warp: Annotated[
19+
List[str], AfterValidator(make_unique_validator("Features"))
20+
] = []
21+
22+
@model_validator(mode="after")
23+
def validate_features_to_warp(self):
24+
input_keys = self.inputs.get_keys()
25+
for feature in self.features_to_warp:
26+
if feature not in input_keys:
27+
raise ValueError(
28+
f"Feature '{feature}' in features_to_warp is not a valid input key."
29+
)
30+
return self
1431

1532
@field_validator("thinning")
1633
@classmethod
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Literal, Type
2+
3+
from pydantic import PositiveInt
4+
5+
from bofire.data_models.features.api import AnyOutput, ContinuousOutput
6+
from bofire.data_models.surrogates.trainable_botorch import TrainableBotorchSurrogate
7+
8+
9+
class TestSurrogate:
10+
pass
11+
12+
13+
class AdditiveMapSaasSingleTaskGPSurrogate(TrainableBotorchSurrogate):
14+
"""Additive MAP SAAS single-task GP
15+
16+
Maximum-a-posteriori (MAP) version of the sparse axis-aligned subspace
17+
`FullyBayesianSingleTaskGPSurrogate` with `model_type` equals to "saas".
18+
19+
Attributes:
20+
n_taus (PositiveInt): Number of sub-kernels to use in the SAAS model.
21+
"""
22+
23+
type: Literal["AdditiveMapSaasSingleTaskGPSurrogate"] = (
24+
"AdditiveMapSaasSingleTaskGPSurrogate"
25+
)
26+
n_taus: PositiveInt = 4
27+
28+
@classmethod
29+
def is_output_implemented(cls, my_type: Type[AnyOutput]) -> bool:
30+
"""Abstract method to check output type for surrogate models
31+
Args:
32+
my_type: continuous or categorical output
33+
Returns:
34+
bool: True if the output type is valid for the surrogate chosen, False otherwise
35+
"""
36+
return isinstance(my_type, type(ContinuousOutput))

bofire/surrogates/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from bofire.surrogates.botorch_surrogates import BotorchSurrogates
22
from bofire.surrogates.deterministic import LinearDeterministicSurrogate
33
from bofire.surrogates.empirical import EmpiricalSurrogate
4+
from bofire.surrogates.map_saas import AdditiveMapSaasSingleTaskGPSurrogate
45
from bofire.surrogates.mapper import map
56
from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate
67
from bofire.surrogates.mixed_tanimoto_gp import MixedTanimotoGPSurrogate

bofire/surrogates/fully_bayesian.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,32 @@
44
import pandas as pd
55
import torch
66
from botorch import fit_fully_bayesian_model_nuts
7-
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
7+
from botorch.models.fully_bayesian import (
8+
FullyBayesianLinearSingleTaskGP,
9+
FullyBayesianSingleTaskGP,
10+
SaasFullyBayesianSingleTaskGP,
11+
)
812
from botorch.models.transforms.outcome import Standardize
913

1014
from bofire.data_models.enum import OutputFilteringEnum
11-
from bofire.data_models.surrogates.api import SaasSingleTaskGPSurrogate as DataModel
15+
from bofire.data_models.surrogates.api import (
16+
FullyBayesianSingleTaskGPSurrogate as DataModel,
17+
)
1218
from bofire.data_models.surrogates.scaler import ScalerEnum
1319
from bofire.surrogates.botorch import BotorchSurrogate
1420
from bofire.surrogates.trainable import TrainableSurrogate
1521
from bofire.surrogates.utils import get_scaler
1622
from bofire.utils.torch_tools import tkwargs
1723

1824

19-
class SaasSingleTaskGPSurrogate(BotorchSurrogate, TrainableSurrogate):
25+
_model_mapper = {
26+
"saas": SaasFullyBayesianSingleTaskGP,
27+
"linear": FullyBayesianLinearSingleTaskGP,
28+
"hvarfner": FullyBayesianSingleTaskGP,
29+
}
30+
31+
32+
class FullyBayesianSingleTaskGPSurrogate(BotorchSurrogate, TrainableSurrogate):
2033
def __init__(
2134
self,
2235
data_model: DataModel,
@@ -27,6 +40,8 @@ def __init__(
2740
self.thinning = data_model.thinning
2841
self.scaler = data_model.scaler
2942
self.output_scaler = data_model.output_scaler
43+
self.features_to_warp = data_model.features_to_warp
44+
self.model_type = data_model.model_type
3045
super().__init__(data_model=data_model, **kwargs)
3146

3247
model: Optional[SaasFullyBayesianSingleTaskGP] = None
@@ -41,19 +56,40 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame, disable_progbar: bool = True):
4156
torch.from_numpy(transformed_X.values).to(**tkwargs),
4257
torch.from_numpy(Y.values).to(**tkwargs),
4358
)
59+
try:
60+
self.model = _model_mapper[self.model_type](
61+
train_X=tX,
62+
train_Y=tY,
63+
outcome_transform=(
64+
Standardize(m=1)
65+
if self.output_scaler == ScalerEnum.STANDARDIZE
66+
else None
67+
),
68+
input_transform=scaler,
69+
use_input_warping=True if len(self.features_to_warp) > 0 else False,
70+
indices_to_warp=self.inputs.get_feature_indices(
71+
self.input_preprocessing_specs, self.features_to_warp
72+
)
73+
if len(self.features_to_warp) > 0
74+
else None, # type: ignore
75+
)
76+
except TypeError:
77+
# For the current release versions of BoTorch,
78+
# the `use_input_warping` argument is not available
79+
# we have to wait for the next release
80+
self.model = _model_mapper[self.model_type](
81+
train_X=tX,
82+
train_Y=tY,
83+
outcome_transform=(
84+
Standardize(m=1)
85+
if self.output_scaler == ScalerEnum.STANDARDIZE
86+
else None
87+
),
88+
input_transform=scaler,
89+
)
4490

45-
self.model = SaasFullyBayesianSingleTaskGP(
46-
train_X=tX,
47-
train_Y=tY,
48-
outcome_transform=(
49-
Standardize(m=1)
50-
if self.output_scaler == ScalerEnum.STANDARDIZE
51-
else None
52-
),
53-
input_transform=scaler,
54-
)
5591
fit_fully_bayesian_model_nuts(
56-
self.model,
92+
self.model, # type: ignore
5793
warmup_steps=self.warmup_steps,
5894
num_samples=self.num_samples,
5995
thinning=self.thinning,

bofire/surrogates/map_saas.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Dict, Optional
2+
3+
import pandas as pd
4+
import torch
5+
from botorch.fit import fit_gpytorch_mll
6+
from botorch.models.map_saas import AdditiveMapSaasSingleTaskGP
7+
from botorch.models.transforms.outcome import Standardize
8+
from gpytorch.mlls import ExactMarginalLogLikelihood
9+
10+
from bofire.data_models.enum import OutputFilteringEnum
11+
from bofire.data_models.surrogates.api import (
12+
AdditiveMapSaasSingleTaskGPSurrogate as DataModel,
13+
)
14+
from bofire.data_models.surrogates.scaler import ScalerEnum
15+
from bofire.surrogates.botorch import BotorchSurrogate
16+
from bofire.surrogates.trainable import TrainableSurrogate
17+
from bofire.surrogates.utils import get_scaler
18+
from bofire.utils.torch_tools import tkwargs
19+
20+
21+
class AdditiveMapSaasSingleTaskGPSurrogate(BotorchSurrogate, TrainableSurrogate):
22+
def __init__(
23+
self,
24+
data_model: DataModel,
25+
**kwargs,
26+
):
27+
self.n_taus = data_model.n_taus
28+
self.scaler = data_model.scaler
29+
self.output_scaler = data_model.output_scaler
30+
super().__init__(data_model=data_model, **kwargs)
31+
32+
model: Optional[AdditiveMapSaasSingleTaskGP] = None
33+
_output_filtering: OutputFilteringEnum = OutputFilteringEnum.ALL
34+
training_specs: Dict = {}
35+
36+
def _fit(self, X: pd.DataFrame, Y: pd.DataFrame, disable_progbar: bool = True): # type: ignore
37+
scaler = get_scaler(self.inputs, self.input_preprocessing_specs, self.scaler, X)
38+
transformed_X = self.inputs.transform(X, self.input_preprocessing_specs)
39+
40+
tX, tY = (
41+
torch.from_numpy(transformed_X.values).to(**tkwargs),
42+
torch.from_numpy(Y.values).to(**tkwargs),
43+
)
44+
45+
self.model = AdditiveMapSaasSingleTaskGP(
46+
train_X=tX,
47+
train_Y=tY,
48+
outcome_transform=(
49+
Standardize(m=1)
50+
if self.output_scaler == ScalerEnum.STANDARDIZE
51+
else None
52+
),
53+
input_transform=scaler,
54+
num_taus=self.n_taus,
55+
)
56+
mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)
57+
fit_gpytorch_mll(mll, options=self.training_specs, max_attempts=10)

bofire/surrogates/mapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
LinearDeterministicSurrogate,
77
)
88
from bofire.surrogates.empirical import EmpiricalSurrogate
9-
from bofire.surrogates.fully_bayesian import SaasSingleTaskGPSurrogate
9+
from bofire.surrogates.fully_bayesian import FullyBayesianSingleTaskGPSurrogate
10+
from bofire.surrogates.map_saas import AdditiveMapSaasSingleTaskGPSurrogate
1011
from bofire.surrogates.mixed_single_task_gp import MixedSingleTaskGPSurrogate
1112
from bofire.surrogates.mixed_tanimoto_gp import MixedTanimotoGPSurrogate
1213
from bofire.surrogates.mlp import ClassificationMLPEnsemble, RegressionMLPEnsemble
@@ -28,7 +29,7 @@
2829
data_models.MixedTanimotoGPSurrogate: MixedTanimotoGPSurrogate,
2930
data_models.RegressionMLPEnsemble: RegressionMLPEnsemble,
3031
data_models.ClassificationMLPEnsemble: ClassificationMLPEnsemble,
31-
data_models.SaasSingleTaskGPSurrogate: SaasSingleTaskGPSurrogate,
32+
data_models.FullyBayesianSingleTaskGPSurrogate: FullyBayesianSingleTaskGPSurrogate,
3233
data_models.XGBoostSurrogate: XGBoostSurrogate,
3334
data_models.LinearSurrogate: SingleTaskGPSurrogate,
3435
data_models.PolynomialSurrogate: SingleTaskGPSurrogate,
@@ -38,6 +39,7 @@
3839
data_models.SingleTaskIBNNSurrogate: SingleTaskGPSurrogate,
3940
data_models.PiecewiseLinearGPSurrogate: PiecewiseLinearGPSurrogate,
4041
data_models.CategoricalDeterministicSurrogate: CategoricalDeterministicSurrogate,
42+
data_models.AdditiveMapSaasSingleTaskGPSurrogate: AdditiveMapSaasSingleTaskGPSurrogate,
4143
}
4244

4345

0 commit comments

Comments
 (0)