44import pandas as pd
55import torch
66from botorch import fit_fully_bayesian_model_nuts
7- from botorch .models .fully_bayesian import SaasFullyBayesianSingleTaskGP
7+ from botorch .models .fully_bayesian import (
8+ FullyBayesianLinearSingleTaskGP ,
9+ FullyBayesianSingleTaskGP ,
10+ SaasFullyBayesianSingleTaskGP ,
11+ )
812from botorch .models .transforms .outcome import Standardize
913
1014from bofire .data_models .enum import OutputFilteringEnum
11- from bofire .data_models .surrogates .api import SaasSingleTaskGPSurrogate as DataModel
15+ from bofire .data_models .surrogates .api import (
16+ FullyBayesianSingleTaskGPSurrogate as DataModel ,
17+ )
1218from bofire .data_models .surrogates .scaler import ScalerEnum
1319from bofire .surrogates .botorch import BotorchSurrogate
1420from bofire .surrogates .trainable import TrainableSurrogate
1521from bofire .surrogates .utils import get_scaler
1622from bofire .utils .torch_tools import tkwargs
1723
1824
19- class SaasSingleTaskGPSurrogate (BotorchSurrogate , TrainableSurrogate ):
# Dispatch table: surrogate `model_type` string -> BoTorch fully Bayesian
# single-task GP class used to construct the model in `_fit`.
#   "saas"     -> sparse axis-aligned subspace priors (SAAS)
#   "linear"   -> fully Bayesian linear-kernel GP
#   "hvarfner" -> Hvarfner-style dimension-scaled priors
_model_mapper = dict(
    saas=SaasFullyBayesianSingleTaskGP,
    linear=FullyBayesianLinearSingleTaskGP,
    hvarfner=FullyBayesianSingleTaskGP,
)
30+
31+
32+ class FullyBayesianSingleTaskGPSurrogate (BotorchSurrogate , TrainableSurrogate ):
2033 def __init__ (
2134 self ,
2235 data_model : DataModel ,
@@ -27,6 +40,8 @@ def __init__(
2740 self .thinning = data_model .thinning
2841 self .scaler = data_model .scaler
2942 self .output_scaler = data_model .output_scaler
43+ self .features_to_warp = data_model .features_to_warp
44+ self .model_type = data_model .model_type
3045 super ().__init__ (data_model = data_model , ** kwargs )
3146
3247 model : Optional [SaasFullyBayesianSingleTaskGP ] = None
@@ -41,19 +56,40 @@ def _fit(self, X: pd.DataFrame, Y: pd.DataFrame, disable_progbar: bool = True):
4156 torch .from_numpy (transformed_X .values ).to (** tkwargs ),
4257 torch .from_numpy (Y .values ).to (** tkwargs ),
4358 )
59+ try :
60+ self .model = _model_mapper [self .model_type ](
61+ train_X = tX ,
62+ train_Y = tY ,
63+ outcome_transform = (
64+ Standardize (m = 1 )
65+ if self .output_scaler == ScalerEnum .STANDARDIZE
66+ else None
67+ ),
68+ input_transform = scaler ,
69+ use_input_warping = True if len (self .features_to_warp ) > 0 else False ,
70+ indices_to_warp = self .inputs .get_feature_indices (
71+ self .input_preprocessing_specs , self .features_to_warp
72+ )
73+ if len (self .features_to_warp ) > 0
74+ else None , # type: ignore
75+ )
76+ except TypeError :
77+ # For the current release versions of BoTorch,
78+ # the `use_input_warping` argument is not available
79+ # we have to wait for the next release
80+ self .model = _model_mapper [self .model_type ](
81+ train_X = tX ,
82+ train_Y = tY ,
83+ outcome_transform = (
84+ Standardize (m = 1 )
85+ if self .output_scaler == ScalerEnum .STANDARDIZE
86+ else None
87+ ),
88+ input_transform = scaler ,
89+ )
4490
45- self .model = SaasFullyBayesianSingleTaskGP (
46- train_X = tX ,
47- train_Y = tY ,
48- outcome_transform = (
49- Standardize (m = 1 )
50- if self .output_scaler == ScalerEnum .STANDARDIZE
51- else None
52- ),
53- input_transform = scaler ,
54- )
5591 fit_fully_bayesian_model_nuts (
56- self .model ,
92+ self .model , # type: ignore
5793 warmup_steps = self .warmup_steps ,
5894 num_samples = self .num_samples ,
5995 thinning = self .thinning ,
0 commit comments