3737
3838import literate_dataclasses as dataclasses
3939import numpy as np
40+ import numpy .typing as npt
4041import torch
4142import torch .nn .functional as F
4243import tqdm
@@ -89,32 +90,6 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int,
8990 )
9091
9192
def _get_batch(inputs: torch.Tensor, offset: cebra.data.Offset,
               batch_start_idx: int, batch_end_idx: int) -> torch.Tensor:
    """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`.

    The batch is widened by the model ``offset`` on whichever sides have
    data available, so a convolutional model can produce outputs for the
    full ``[batch_start_idx, batch_end_idx)`` range.

    Args:
        inputs: Input data.
        offset: Model offset.
        batch_start_idx: Index of the first sample in the batch.
        batch_end_idx: End index (exclusive) of the batch.

    Returns:
        The batch.
    """
    if batch_start_idx == 0:  # First batch: no data to the left, extend right.
        indices = batch_start_idx, (batch_end_idx + offset.right - 1)
    elif batch_end_idx == len(inputs):  # Last batch: no data to the right.
        indices = (batch_start_idx - offset.left), batch_end_idx
    else:  # Middle batch: extend on both sides to cover the receptive field.
        indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1

    _check_indices(indices[0], indices[1], offset, len(inputs))
    batched_data = inputs[slice(*indices)]
    return batched_data
117-
11893def _add_batched_zero_padding (batched_data : torch .Tensor ,
11994 offset : cebra .data .Offset , batch_start_idx : int ,
12095 batch_end_idx : int ,
@@ -145,6 +120,45 @@ def _add_batched_zero_padding(batched_data: torch.Tensor,
145120 return batched_data
146121
147122
def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset],
               batch_start_idx: int, batch_end_idx: int,
               pad_before_transform: bool) -> torch.Tensor:
    """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`.

    The batch is widened by the model ``offset`` on whichever sides have
    data available, and optionally zero-padded so that the transformed
    output covers the full ``[batch_start_idx, batch_end_idx)`` range.

    Args:
        inputs: Input data.
        offset: Model offset.
        batch_start_idx: Index of the first sample in the batch.
        batch_end_idx: End index (exclusive) of the batch.
        pad_before_transform: If True, zero-pad the batched data.

    Returns:
        The batch.

    Raises:
        ValueError: If ``offset`` is None.
    """
    if offset is None:
        # Fixed: was an f-string with no placeholders ("null" is not Python).
        raise ValueError("offset cannot be None.")

    if batch_start_idx == 0:  # First batch: no data to the left, extend right.
        indices = batch_start_idx, (batch_end_idx + offset.right - 1)
    elif batch_end_idx == len(inputs):  # Last batch: no data to the right.
        indices = (batch_start_idx - offset.left), batch_end_idx
    else:  # Middle batch: extend on both sides to cover the receptive field.
        indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1

    _check_indices(indices[0], indices[1], offset, len(inputs))
    batched_data = inputs[slice(*indices)]

    if pad_before_transform:
        batched_data = _add_batched_zero_padding(
            batched_data=batched_data,
            offset=offset,
            batch_start_idx=batch_start_idx,
            batch_end_idx=batch_end_idx,
            num_samples=len(inputs))

    return batched_data
160+
161+
148162def _inference_transform (model : cebra .models .Model ,
149163 inputs : torch .Tensor ) -> torch .Tensor :
150164 """Compute the embedding on the inputs using the model provided.
@@ -156,9 +170,7 @@ def _inference_transform(model: cebra.models.Model,
156170 Returns:
157171 The embedding.
158172 """
159- #TODO(rodrigo): I am not sure what is the best way with dealing with the types and
160- # device when using batched inference. This works for now.
161- inputs = inputs .type (torch .FloatTensor ).to (next (model .parameters ()).device )
173+ inputs = inputs .float ().to (next (model .parameters ()).device )
162174
163175 if isinstance (model , cebra .models .ConvolutionalModelMixin ):
164176 # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
@@ -228,15 +240,8 @@ def __getitem__(self, idx):
228240 batched_data = _get_batch (inputs = inputs ,
229241 offset = offset ,
230242 batch_start_idx = batch_start_idx ,
231- batch_end_idx = batch_end_idx )
232-
233- if pad_before_transform :
234- batched_data = _add_batched_zero_padding (
235- batched_data = batched_data ,
236- offset = offset ,
237- batch_start_idx = batch_start_idx ,
238- batch_end_idx = batch_end_idx ,
239- num_samples = len (inputs ))
243+ batch_end_idx = batch_end_idx ,
244+ pad_before_transform = pad_before_transform )
240245
241246 output_batch = _inference_transform (model , batched_data )
242247 output .append (output_batch )
@@ -549,6 +554,15 @@ def transform(self,
549554 Returns:
550555 The output embedding.
551556 """
557+ if isinstance (inputs , list ):
558+ raise NotImplementedError (
559+ "Inputs to transform() should be the data for a single session."
560+ )
561+
562+ elif not isinstance (inputs , torch .Tensor ):
563+ raise ValueError (
564+ f"Inputs should be a torch.Tensor, not { type (inputs )} ." )
565+
552566 if not hasattr (self , "n_features" ):
553567 raise ValueError (
554568 f"This { type (self ).__name__ } instance is not fitted yet. Call 'fit' with "
0 commit comments