@@ -1201,16 +1201,18 @@ def fit(
12011201 def transform (self ,
12021202 X : Union [npt .NDArray , torch .Tensor ],
12031203 pad_before_transform : bool = True ,
1204+ batch_size : Optional [int ] = None ,
12041205 session_id : Optional [int ] = None ) -> npt .NDArray :
12051206 """Transform an input sequence and return the embedding.
12061207
12071208 Args:
12081209 X: A numpy array or torch tensor of size ``time x dimension``.
12091210 pad_before_transform: If ``False``, no padding is applied to the input sequence.
1210- and the output sequence will be smaller than the input sequence due to the
1211- receptive field of the model. If the input sequence is ``n`` steps long,
1212- and a model with receptive field ``m`` is used, the output sequence would
1213- only be ``n-m+1`` steps long.
1211+ and the output sequence will be smaller than the input sequence due to the
1212+ receptive field of the model. If the input sequence is ``n`` steps long,
1213+ and a model with receptive field ``m`` is used, the output sequence would
1214+ only be ``n-m+1`` steps long.
1215+ batch_size:
12141216 session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
12151217 multisession, set to ``None`` for single session.
12161218
@@ -1233,10 +1235,15 @@ def transform(self,
12331235 # Input validation
12341236 X = sklearn_utils .check_input_array (X , min_samples = len (self .offset_ ))
12351237 input_dtype = X .dtype
1238+ #print(type(X))
1239+ #print(X.dtype)
12361240
12371241 with torch .no_grad ():
12381242 output = self .solver_ .transform (
1239- X , pad_before_transform = pad_before_transform )
1243+ inputs = X ,
1244+ pad_before_transform = pad_before_transform ,
1245+ session_id = session_id ,
1246+ batch_size = batch_size )
12401247
12411248 if input_dtype == "float64" :
12421249 return output .astype (input_dtype )
0 commit comments