@@ -44,6 +44,9 @@ class LinearTruncatedFidelityKernel(Kernel):
     We assume the last two dimensions of input `x` are the fidelity parameters.

     Args:
+        :attr:`dimension` (int):
+            The dimension of `x`. This is not needed if `active_dims` is specified.
+            Default: `3`
         :attr:`nu` (float):
             The smoothness parameter for the Matern kernel: either 1/2, 3/2, or 5/2.
             Default: `2.5`
@@ -73,6 +76,19 @@ class LinearTruncatedFidelityKernel(Kernel):
         :attr:`power_constraint` (Constraint, optional):
             Set this if you want to apply a constraint to the power parameter of the
             polynomial kernel. Default: `Positive`
+        :attr:`train_iteration_fidelity` (bool):
+            Set this to True if your data contains an iteration fidelity parameter.
+            Default: `True`
+        :attr:`train_data_fidelity` (bool):
+            Set this to True if your data contains a training data fidelity parameter.
+            Default: `True`
+        :attr:`covar_module_1` (Kernel):
+            Set this if you want a different kernel for the unbiased part.
+            Default: `MaternKernel`
+        :attr:`covar_module_2` (Kernel):
+            Set this if you want a different kernel for the biased part.
+            Default: `MaternKernel`
+

     Attributes:
         :attr:`lengthscale` (Tensor):
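For reference, a minimal usage sketch of the new constructor arguments. The import path and data shapes here are assumptions, not taken from this diff; adjust the import to wherever the class lives in your checkout:

    import torch
    from botorch.models.fidelity_kernels.linear_truncated_fidelity import (
        LinearTruncatedFidelityKernel,
    )

    # 5 design columns plus 2 fidelity columns appended as the last two
    # dimensions of x, matching the convention stated in the docstring.
    x = torch.rand(10, 7)

    kernel = LinearTruncatedFidelityKernel(
        dimension=7,
        nu=2.5,
        train_iteration_fidelity=True,
        train_data_fidelity=True,
    )
    covar = kernel(x, x)  # lazy 10 x 10 covariance matrix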
@@ -92,6 +108,7 @@ class LinearTruncatedFidelityKernel(Kernel):

     def __init__(
         self,
+        dimension: int = 3,
         nu: float = 2.5,
         train_iteration_fidelity: bool = True,
         train_data_fidelity: bool = True,
@@ -100,32 +117,29 @@ def __init__(
         power_constraint: Optional[Interval] = None,
         lengthscale_2_prior: Optional[Prior] = None,
         lengthscale_2_constraint: Optional[Interval] = None,
+        lengthscale_constraint: Optional[Interval] = None,
+        covar_module_1: Optional[Kernel] = None,
+        covar_module_2: Optional[Kernel] = None,
         **kwargs: Any,
     ):
         if not train_iteration_fidelity and not train_data_fidelity:
             raise UnsupportedError("You should have at least one fidelity parameter.")
         if nu not in {0.5, 1.5, 2.5}:
             raise ValueError("nu expected to be 0.5, 1.5, or 2.5")
-        super().__init__(has_lengthscale=True, **kwargs)
+        super().__init__(**kwargs)
         self.train_iteration_fidelity = train_iteration_fidelity
         self.train_data_fidelity = train_data_fidelity
         if power_constraint is None:
             power_constraint = Positive()

         if lengthscale_prior is None:
-            self.lengthscale_prior = GammaPrior(1.1, 1 / 20)
-        else:
-            self.lengthscale_prior = lengthscale_prior
+            lengthscale_prior = GammaPrior(3, 6)

         if lengthscale_2_prior is None:
-            self.lengthscale_2_prior = GammaPrior(5, 1 / 20)
-        else:
-            self.register_prior(
-                "lengthscale_2_prior",
-                lengthscale_2_prior,
-                lambda: self.lengthscale_2,
-                lambda v: self._set_lengthscale_2(v),
-            )
+            lengthscale_2_prior = GammaPrior(6, 2)
+
+        if lengthscale_constraint is None:
+            lengthscale_constraint = Positive()

         if lengthscale_2_constraint is None:
             lengthscale_2_constraint = Positive()
@@ -134,10 +148,6 @@ def __init__(
             name="raw_power",
             parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
         )
-        self.register_parameter(
-            name="raw_lengthscale_2",
-            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
-        )

         if power_prior is not None:
             self.register_prior(
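The lengthscale priors now default to GammaPrior(3, 6) for the unbiased part and GammaPrior(6, 2) for the biased part, and the constraints default to Positive(). A hedged sketch of overriding them; the GreaterThan bound is illustrative only:

    from gpytorch.constraints import GreaterThan
    from gpytorch.priors import GammaPrior

    kernel = LinearTruncatedFidelityKernel(
        dimension=7,
        lengthscale_prior=GammaPrior(3.0, 6.0),     # unbiased part
        lengthscale_2_prior=GammaPrior(6.0, 2.0),   # biased part
        lengthscale_constraint=GreaterThan(1e-4),   # illustrative lower bound
        lengthscale_2_constraint=GreaterThan(1e-4),
    )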
@@ -146,10 +156,35 @@ def __init__(
                 lambda: self.power,
                 lambda v: self._set_power(v),
             )
-        self.nu = nu
-        self.register_constraint("raw_lengthscale_2", lengthscale_2_constraint)
         self.register_constraint("raw_power", power_constraint)

+        m = self.train_iteration_fidelity + self.train_data_fidelity
+
+        if self.active_dims is not None:
+            dimension = len(self.active_dims)
+
+        if covar_module_1 is None:
+            self.covar_module_1 = MaternKernel(
+                nu=nu,
+                batch_shape=self.batch_shape,
+                lengthscale_prior=lengthscale_prior,
+                ard_num_dims=dimension - m,
+                lengthscale_constraint=lengthscale_constraint,
+            )
+        else:
+            self.covar_module_1 = covar_module_1
+
+        if covar_module_2 is None:
+            self.covar_module_2 = MaternKernel(
+                nu=nu,
+                batch_shape=self.batch_shape,
+                lengthscale_prior=lengthscale_2_prior,
+                ard_num_dims=dimension - m,
+                lengthscale_constraint=lengthscale_2_constraint,
+            )
+        else:
+            self.covar_module_2 = covar_module_2
+
     @property
     def power(self) -> torch.Tensor:
         return self.raw_power_constraint.transform(self.raw_power)
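With m = train_iteration_fidelity + train_data_fidelity fidelity parameters, the default Matern modules are ARD kernels over the remaining dimension - m design columns, and active_dims (when given) overrides dimension. A sketch of supplying custom modules instead; RBFKernel is just an illustrative stand-in:

    from gpytorch.kernels import RBFKernel

    m = 2  # both fidelity parameters enabled
    kernel = LinearTruncatedFidelityKernel(
        dimension=7,
        covar_module_1=RBFKernel(ard_num_dims=7 - m),  # unbiased part
        covar_module_2=RBFKernel(ard_num_dims=7 - m),  # biased part
    )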
@@ -163,46 +198,20 @@ def _set_power(self, value: torch.Tensor) -> None:
         value = torch.as_tensor(value).to(self.raw_power)
         self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

-    @property
-    def lengthscale_2(self) -> torch.Tensor:
-        return self.raw_lengthscale_2_constraint.transform(self.raw_lengthscale_2)
-
-    @lengthscale_2.setter
-    def lengthscale_2(self, value: torch.Tensor) -> None:
-        self._set_lengthscale_2(value)
-
-    def _set_lengthscale_2(self, value: torch.Tensor) -> None:
-        if not torch.is_tensor(value):
-            value = torch.as_tensor(value).to(self.raw_lengthscale_2)
-        self.initialize(
-            raw_lengthscale_2=self.raw_lengthscale_2_constraint.inverse_transform(value)
-        )
-
     def forward(self, x1: torch.Tensor, x2: torch.Tensor, **params) -> torch.Tensor:
-        m = self.train_iteration_fidelity + self.train_data_fidelity
         power = self.power.view(*self.batch_shape, 1, 1)
-        active_dimsM = list(range(x1.size()[-1] - m))
-        covar_module_1 = MaternKernel(
-            nu=self.nu,
-            batch_shape=self.batch_shape,
-            lengthscale_prior=self.lengthscale_prior,
-            active_dims=active_dimsM,
-            ard_num_dims=x1.shape[-1] - m,
-        )
-        covar_module_2 = MaternKernel(
-            nu=self.nu,
-            batch_shape=self.batch_shape,
-            lengthscale_prior=self.lengthscale_2_prior,
-            active_dims=active_dimsM,
-            ard_num_dims=x1.shape[-1] - m,
-        )
-        covar_0 = covar_module_1(x1, x2)
+
+        m = self.train_iteration_fidelity + self.train_data_fidelity
+        active_dimsM = list(range(x1.shape[-1] - m))
+        x1_ = x1[..., active_dimsM]
+        x2_ = x2[..., active_dimsM]
+        covar_0 = self.covar_module_1(x1_, x2_)
+        covar_1 = self.covar_module_2(x1_, x2_)
         x11_ = x1[..., -1].unsqueeze(-1)
         x21t_ = x2[..., -1].unsqueeze(-1).transpose(-1, -2)
-        covar_1 = covar_module_2(x1, x2)
         if self.train_iteration_fidelity and self.train_data_fidelity:
-            covar_2 = covar_module_2(x1, x2)
-            covar_3 = covar_module_2(x1, x2)
+            covar_2 = self.covar_module_2(x1_, x2_)
+            covar_3 = self.covar_module_2(x1_, x2_)
             x12_ = x1[..., -2].unsqueeze(-1)
             x22t_ = x2[..., -2].unsqueeze(-1).transpose(-1, -2)
             res = (
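To make the slicing in forward concrete: the covariance modules see only the design columns x[..., :-m], while the trailing fidelity columns feed the bias terms. A toy shape walk-through under the two-fidelity case; the tensor sizes are assumptions:

    x1 = torch.rand(10, 7)
    x2 = torch.rand(8, 7)
    m = 2
    x1_ = x1[..., : 7 - m]            # design columns, shape (10, 5)
    x2_ = x2[..., : 7 - m]            # design columns, shape (8, 5)
    x11_ = x1[..., -1].unsqueeze(-1)  # iteration fidelity, shape (10, 1)
    x21t_ = x2[..., -1].unsqueeze(-1).transpose(-1, -2)  # shape (1, 8)
    # covar_module_1(x1_, x2_) and covar_module_2(x1_, x2_) are then 10 x 8,
    # broadcast-compatible with the (10, 1) x (1, 8) fidelity bias factors.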