Skip to content

Commit de300a9

Browse files
committed
Change masking kwargs as tuple and not dict in sklearn impl
1 parent d91949f commit de300a9

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -491,15 +491,16 @@ class CEBRA(TransformerMixin, BaseEstimator):
491491
hybrid (bool):
492492
If ``True``, the model will be trained using both the time-contrastive and the selected
493493
behavior-constrastive loss functions. |Default:| ``False``.
494-
optimizer_kwargs (dict):
494+
optimizer_kwargs (tuple):
495495
Additional optimization parameters. These have the form ``((key, value), (key, value))`` and
496496
are passed to the PyTorch optimizer specified through the ``optimizer`` argument. Refer to the
497497
optimizer documentation in :py:mod:`torch.optim` for further information on how to format the
498498
arguments.
499499
|Default:| ``(('betas', (0.9, 0.999)), ('eps', 1e-08), ('weight_decay', 0), ('amsgrad', False))``
500-
masking_kwargs (dict):
501-
A dictionary of masking types and their corresponding required masking values. The keys are the
502-
names of the Mask instances.
500+
masking_kwargs (tuple):
501+
A Tuple of masking types and their corresponding required masking values. The keys are the
502+
names of the Mask instances and formating should be ``((key, value), (key, value))``.
503+
|Default:| ``None``.
503504
504505
Example:
505506
@@ -573,8 +574,8 @@ def __init__(
573574
("weight_decay", 0),
574575
("amsgrad", False),
575576
),
576-
masking_kwargs: Dict[str, Union[float, List[float], Tuple[float,
577-
...]]] = None,
577+
masking_kwargs: Tuple[Tuple[str, Union[float, List[float],
578+
Tuple[float, ...]]], ...] = None,
578579
):
579580
self.__dict__.update(locals())
580581

@@ -901,7 +902,8 @@ def _prepare_fit(
901902
self.offset_ = self._compute_offset()
902903
dataset, is_multisession = self._prepare_data(X, y)
903904

904-
dataset.set_masks(self.masking_kwargs)
905+
if self.masking_kwargs:
906+
dataset.set_masks(dict(self.masking_kwargs))
905907

906908
loader, solver_name = self._prepare_loader(
907909
dataset,

0 commit comments

Comments
 (0)