
Commit f0303e0

gonlairostes authored and committed
add argument to sklearn api
1 parent 87bebac commit f0303e0

File tree

1 file changed: 12 additions, 5 deletions


cebra/integrations/sklearn/cebra.py

Lines changed: 12 additions & 5 deletions
@@ -1201,16 +1201,18 @@ def fit(
     def transform(self,
                   X: Union[npt.NDArray, torch.Tensor],
                   pad_before_transform: bool = True,
+                  batch_size: Optional[int] = None,
                   session_id: Optional[int] = None) -> npt.NDArray:
         """Transform an input sequence and return the embedding.

         Args:
             X: A numpy array or torch tensor of size ``time x dimension``.
             pad_before_transform: If ``False``, no padding is applied to the input sequence.
-                and the output sequence will be smaller than the input sequence due to the
-                receptive field of the model. If the input sequence is ``n`` steps long,
-                and a model with receptive field ``m`` is used, the output sequence would
-                only be ``n-m+1`` steps long.
+                and the output sequence will be smaller than the input sequence due to the
+                receptive field of the model. If the input sequence is ``n`` steps long,
+                and a model with receptive field ``m`` is used, the output sequence would
+                only be ``n-m+1`` steps long.
+            batch_size:
             session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
                 multisession, set to ``None`` for single session.

@@ -1233,10 +1235,15 @@ def transform(self,
         # Input validation
         X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
         input_dtype = X.dtype
+        #print(type(X))
+        #print(X.dtype)

         with torch.no_grad():
             output = self.solver_.transform(
-                X, pad_before_transform=pad_before_transform)
+                inputs=X,
+                pad_before_transform=pad_before_transform,
+                session_id=session_id,
+                batch_size=batch_size)

         if input_dtype == "float64":
             return output.astype(input_dtype)
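The new ``batch_size`` argument is forwarded to ``self.solver_.transform``, whose implementation is not shown in this diff. A common semantics for such an argument, and a plausible reading here, is: when ``batch_size`` is ``None``, transform the whole input at once; otherwise, process the input in chunks of ``batch_size`` samples and concatenate the results. The sketch below illustrates that pattern with a hypothetical ``batched_transform`` helper (not part of the CEBRA API) and a stand-in transform function:

```python
import numpy as np

def batched_transform(X, transform_fn, batch_size=None):
    """Apply transform_fn to X all at once, or in chunks of batch_size rows.

    Hypothetical helper illustrating one plausible semantics of the new
    ``batch_size`` argument; it is not taken from the CEBRA source.
    """
    if batch_size is None:
        # No batching: transform the full input in a single call.
        return transform_fn(X)
    # Batching: slice along the time axis and stitch the outputs back together.
    chunks = [transform_fn(X[i:i + batch_size])
              for i in range(0, len(X), batch_size)]
    return np.concatenate(chunks, axis=0)

# Stand-in "embedding" function; a real model would map time x dimension
# input to time x latent-dimension output.
X = np.arange(10, dtype=np.float64).reshape(5, 2)
out = batched_transform(X, lambda x: x * 2.0, batch_size=2)
```

Assuming this reading, a caller would use the sklearn API as ``model.transform(X, batch_size=512)`` after this commit, trading one large forward pass for several smaller ones to bound memory use.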
