2121
2222from functools import partial
2323
24- from typing import Any , Callable , List , Optional , Tuple , TypeVar , Union
24+ from typing import Callable , List , Optional , Tuple , TypeVar , Union
2525
2626import torch
2727from botorch .acquisition .cached_cholesky import CachedCholeskyMCSamplerMixin
@@ -275,7 +275,7 @@ def __init__(
275275 cache_root : bool = True ,
276276 tau_max : float = TAU_MAX ,
277277 tau_relu : float = TAU_RELU ,
278- ** kwargs : Any ,
278+ marginalize_dim : Optional [ int ] = None ,
279279 ) -> None :
280280 r"""q-Noisy Expected Improvement.
281281
@@ -314,7 +314,7 @@ def __init__(
314314 approximations to max.
315315 tau_relu: Temperature parameter controlling the sharpness of the smooth
316316 approximations to ReLU.
317- kwargs: Here for qNEI for compatibility .
317+ marginalize_dim: The dimension to marginalize over .
318318
319319 TODO: similar to qNEHVI, when we are using sequential greedy candidate
320320 selection, we could incorporate pending points X_baseline and compute
@@ -343,7 +343,7 @@ def __init__(
343343 posterior_transform = posterior_transform ,
344344 prune_baseline = prune_baseline ,
345345 cache_root = cache_root ,
346- ** kwargs ,
346+ marginalize_dim = marginalize_dim ,
347347 )
348348
349349 def _sample_forward (self , obj : Tensor ) -> Tensor :
@@ -372,7 +372,7 @@ def _init_baseline(
372372 posterior_transform : Optional [PosteriorTransform ] = None ,
373373 prune_baseline : bool = False ,
374374 cache_root : bool = True ,
375- ** kwargs : Any ,
375+ marginalize_dim : Optional [ int ] = None ,
376376 ) -> None :
377377 CachedCholeskyMCSamplerMixin .__init__ (
378378 self , model = model , cache_root = cache_root , sampler = sampler
@@ -383,7 +383,7 @@ def _init_baseline(
383383 X = X_baseline ,
384384 objective = objective ,
385385 posterior_transform = posterior_transform ,
386- marginalize_dim = kwargs . get ( " marginalize_dim" ) ,
386+ marginalize_dim = marginalize_dim ,
387387 constraints = self ._constraints ,
388388 )
389389 self .register_buffer ("X_baseline" , X_baseline )
0 commit comments