@@ -491,15 +491,16 @@ class CEBRA(TransformerMixin, BaseEstimator):
491491 hybrid (bool):
492492 If ``True``, the model will be trained using both the time-contrastive and the selected
493493 behavior-constrastive loss functions. |Default:| ``False``.
494- optimizer_kwargs (dict ):
494+ optimizer_kwargs (tuple ):
495495 Additional optimization parameters. These have the form ``((key, value), (key, value))`` and
496496 are passed to the PyTorch optimizer specified through the ``optimizer`` argument. Refer to the
497497 optimizer documentation in :py:mod:`torch.optim` for further information on how to format the
498498 arguments.
499499 |Default:| ``(('betas', (0.9, 0.999)), ('eps', 1e-08), ('weight_decay', 0), ('amsgrad', False))``
500- masking_kwargs (dict):
501- A dictionary of masking types and their corresponding required masking values. The keys are the
502- names of the Mask instances.
500+ masking_kwargs (tuple):
501+ A Tuple of masking types and their corresponding required masking values. The keys are the
502+ names of the Mask instances and formating should be ``((key, value), (key, value))``.
503+ |Default:| ``None``.
503504
504505 Example:
505506
@@ -573,8 +574,8 @@ def __init__(
573574 ("weight_decay" , 0 ),
574575 ("amsgrad" , False ),
575576 ),
576- masking_kwargs : Dict [ str , Union [float , List [float ], Tuple [ float ,
577- ...]]] = None ,
577+ masking_kwargs : Tuple [ Tuple [ str , Union [float , List [float ],
578+ Tuple [ float , ...]]], ... ] = None ,
578579 ):
579580 self .__dict__ .update (locals ())
580581
@@ -901,7 +902,8 @@ def _prepare_fit(
901902 self .offset_ = self ._compute_offset ()
902903 dataset , is_multisession = self ._prepare_data (X , y )
903904
904- dataset .set_masks (self .masking_kwargs )
905+ if self .masking_kwargs :
906+ dataset .set_masks (dict (self .masking_kwargs ))
905907
906908 loader , solver_name = self ._prepare_loader (
907909 dataset ,
0 commit comments