@@ -150,10 +150,18 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
150150 raise UnsupportedError ("All models must have the same fidelity parameters." )
151151 kwargs .update (init_args )
152152
153+ # add batched kernel, except if the model type is SingleTaskMultiFidelityGP,
154+ # which does not have a `covar_module`
155+ if not isinstance (models [0 ], SingleTaskMultiFidelityGP ):
156+ batch_length = len (models )
157+ covar_module = _batched_kernel (models [0 ].covar_module , batch_length )
158+ kwargs ["covar_module" ] = covar_module
159+
153160 # construct the batched GP model
154161 input_transform = getattr (models [0 ], "input_transform" , None )
155162 if input_transform is not None :
156163 input_transform .train ()
164+
157165 batch_gp = models [0 ].__class__ (input_transform = input_transform , ** kwargs )
158166 adjusted_batch_keys , non_adjusted_batch_keys = _get_adjusted_batch_keys (
159167 batch_state_dict = batch_gp .state_dict (), input_transform = input_transform
@@ -196,6 +204,46 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
196204 return batch_gp
197205
198206
207+ def _batched_kernel (kernel , batch_length : int ):
208+ """Adds a batch dimension of size `batch_length` to all non-scalar
209+ Tensor parameters that govern the kernel function `kernel`.
210+ NOTE: prior or constraint parameters are excluded from batching.
211+ """
212+ # copy just in case there are non-tensor parameters that are passed by reference
213+ kernel = deepcopy (kernel )
214+ search_str = "raw_outputscale"
215+ for key , attr in kernel .state_dict ().items ():
216+ if isinstance (attr , Tensor ) and (
217+ attr .ndim > 0 or (search_str == key .rpartition ("." )[- 1 ])
218+ ):
219+ attr = attr .unsqueeze (0 ).expand (batch_length , * attr .shape ).clone ()
220+ set_attribute (kernel , key , torch .nn .Parameter (attr ))
221+ return kernel
222+
223+
224+ # two helper functions for `batched_kernel`
225+ # like `setattr` and `getattr` for object hierarchies
226+ def set_attribute (obj , attr : str , val ):
227+ """Like `setattr` but works with hierarchical attribute specification.
228+ E.g. if obj=Zoo(), and attr="tiger.age", set_attribute(obj, attr, 3),
229+ would set the Zoo's tiger's age to three.
230+ """
231+ path_to_leaf , _ , attr_name = attr .rpartition ("." )
232+ leaf = get_attribute (obj , path_to_leaf ) if path_to_leaf else obj
233+ setattr (leaf , attr_name , val )
234+
235+
236+ def get_attribute (obj , attr : str ):
237+ """Like `getattr` but works with hierarchical attribute specification.
238+ E.g. if obj=Zoo(), and attr="tiger.age", get_attribute(obj, attr),
239+ would return the Zoo's tiger's age.
240+ """
241+ attr_names = attr .split ("." )
242+ while attr_names :
243+ obj = getattr (obj , attr_names .pop (0 ))
244+ return obj
245+
246+
199247def batched_to_model_list (batch_model : BatchedMultiOutputGPyTorchModel ) -> ModelListGP :
200248 """Convert a BatchedMultiOutputGPyTorchModel to a ModelListGP.
201249
0 commit comments