Skip to content

Commit 928d882

Browse files
gonlairostes
authored and committed
change argument position
1 parent 5e7a14c commit 928d882

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1200,18 +1200,12 @@ def fit(
12001200

12011201
def transform(self,
12021202
X: Union[npt.NDArray, torch.Tensor],
1203-
pad_before_transform: bool = True,
12041203
batch_size: Optional[int] = None,
12051204
session_id: Optional[int] = None) -> npt.NDArray:
12061205
"""Transform an input sequence and return the embedding.
12071206
12081207
Args:
12091208
X: A numpy array or torch tensor of size ``time x dimension``.
1210-
pad_before_transform: If ``False``, no padding is applied to the input sequence.
1211-
and the output sequence will be smaller than the input sequence due to the
1212-
receptive field of the model. If the input sequence is ``n`` steps long,
1213-
and a model with receptive field ``m`` is used, the output sequence would
1214-
only be ``n-m+1`` steps long.
12151209
batch_size:
12161210
session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
12171211
multisession, set to ``None`` for single session.
@@ -1244,7 +1238,7 @@ def transform(self,
12441238
with torch.no_grad():
12451239
output = self.solver_.transform(
12461240
inputs=X,
1247-
pad_before_transform=pad_before_transform,
1241+
pad_before_transform=self.pad_before_transform,
12481242
session_id=session_id,
12491243
batch_size=batch_size)
12501244

0 commit comments

Comments
 (0)