File tree Expand file tree Collapse file tree 1 file changed +1
-7
lines changed
cebra/integrations/sklearn Expand file tree Collapse file tree 1 file changed +1
-7
lines changed Original file line number Diff line number Diff line change @@ -1200,18 +1200,12 @@ def fit(
12001200
12011201 def transform (self ,
12021202 X : Union [npt .NDArray , torch .Tensor ],
1203- pad_before_transform : bool = True ,
12041203 batch_size : Optional [int ] = None ,
12051204 session_id : Optional [int ] = None ) -> npt .NDArray :
12061205 """Transform an input sequence and return the embedding.
12071206
12081207 Args:
12091208 X: A numpy array or torch tensor of size ``time x dimension``.
1210- pad_before_transform: If ``False``, no padding is applied to the input sequence.
1211- and the output sequence will be smaller than the input sequence due to the
1212- receptive field of the model. If the input sequence is ``n`` steps long,
1213- and a model with receptive field ``m`` is used, the output sequence would
1214- only be ``n-m+1`` steps long.
12151209 batch_size:
12161210 session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
12171211 multisession, set to ``None`` for single session.
@@ -1244,7 +1238,7 @@ def transform(self,
12441238 with torch .no_grad ():
12451239 output = self .solver_ .transform (
12461240 inputs = X ,
1247- pad_before_transform = pad_before_transform ,
1241+ pad_before_transform = self . pad_before_transform ,
12481242 session_id = session_id ,
12491243 batch_size = batch_size )
12501244
You can’t perform that action at this time.
0 commit comments