@@ -66,56 +66,32 @@ def _inference_transform(model, inputs):
6666 return output
6767
6868
69- def _pad_with_data (inputs : torch .Tensor , offset : cebra .data .Offset ,
70- start_batch_idx : int , end_batch_idx : int ) -> torch .Tensor :
71- """
72- Pads a batch of input data with its own data (maybe this is not called padding)
73-
74- Args:
75- inputs: The input data to be processed.
76- add_padding: Indicates whether padding should be applied before inference.
77- offset: Offset configuration for padding. If add_padding is True,
78- offset must be set. If add_padding is False, offset is not used and can be None.
79- start_batch_idx: The starting index of the current batch.
80- end_batch_idx: The last index of the current batch.
81-
82- Returns:
83- torch.Tensor: The (potentially) padded data.
84-
85- Raises:
86- ValueError: If add_padding is True and offset is not provided.
87- """
88-
89- def _check_indices (indices , inputs ):
90- if (indices [0 ] < 0 ) or (indices [1 ] > inputs .shape [0 ]):
91- raise ValueError (
92- f"offset { offset } is too big for the length of the inputs ({ len (inputs )} ) "
93- f"The indices { indices } do not match the inputs length { len (inputs )} ."
94- )
69+ def _check_indices (start_batch_idx , end_batch_idx , offset , num_samples ):
9570
9671 if start_batch_idx < 0 or end_batch_idx < 0 :
9772 raise ValueError (
9873 f"start_batch_idx ({ start_batch_idx } ) and end_batch_idx ({ end_batch_idx } ) must be non-negative."
9974 )
100-
10175 if start_batch_idx > end_batch_idx :
10276 raise ValueError (
10377 f"start_batch_idx ({ start_batch_idx } ) cannot be greater than end_batch_idx ({ end_batch_idx } )."
10478 )
79+ if end_batch_idx > num_samples :
80+ raise ValueError (
81+ f"end_batch_idx ({ end_batch_idx } ) cannot exceed the length of inputs ({ num_samples } )."
82+ )
10583
106- if end_batch_idx > len (inputs ):
84+ batch_size_lenght = end_batch_idx - start_batch_idx
85+ if batch_size_lenght <= len (offset ):
10786 raise ValueError (
108- f"end_batch_idx ({ end_batch_idx } ) cannot exceed the length of inputs ({ len (inputs )} )."
87+ f"The batch has length { batch_size_lenght } which "
88+ f"is smaller or equal than the required offset length { len (offset )} ."
89+ f"Either choose a model with smaller offset or the batch shoud contain more samples."
10990 )
11091
111- def _check_batch_size_length (indices_batch , offset ):
112- batch_size_lenght = indices_batch [1 ] - indices_batch [0 ]
113- if batch_size_lenght <= len (offset ):
114- raise ValueError (
115- f"The batch has length { batch_size_lenght } which "
116- f"is smaller or equal than the required offset length { len (offset )} ."
117- f"Either choose a model with smaller offset or the batch shoud contain more samples."
118- )
92+
93+ def _get_batch (inputs : torch .Tensor , offset : cebra .data .Offset ,
94+ start_batch_idx : int , end_batch_idx : int ) -> torch .Tensor :
11995
12096 if start_batch_idx == 0 : # First batch
12197 indices = start_batch_idx , (end_batch_idx + offset .right - 1 )
@@ -126,12 +102,25 @@ def _check_batch_size_length(indices_batch, offset):
126102 else : # Middle batches
127103 indices = start_batch_idx - offset .left , end_batch_idx + offset .right - 1
128104
129- #_check_batch_size_length(indices, offset)
130- #TODO: modify this check_batch_size to pass test.
105+ _check_indices (indices [0 ], indices [1 ], offset , len (inputs ))
131106 batched_data = inputs [slice (* indices )]
132107 return batched_data
133108
134109
110+ def _add_zero_padding (batched_data : torch .Tensor , offset : cebra .data .Offset ,
111+ start_batch_idx : int , end_batch_idx : int ,
112+ number_of_samples : int ):
113+
114+ if start_batch_idx == 0 : # First batch
115+ batched_data = F .pad (batched_data .T , (offset .left , 0 ), 'replicate' ).T
116+
117+ elif end_batch_idx == number_of_samples : # Last batch
118+ batched_data = F .pad (batched_data .T , (0 , offset .right - 1 ),
119+ 'replicate' ).T
120+
121+ return batched_data
122+
123+
135124def _batched_transform (model , inputs : torch .Tensor , batch_size : int ,
136125 pad_before_transform : bool ,
137126 offset : cebra .data .Offset ) -> torch .Tensor :
@@ -153,21 +142,17 @@ def __getitem__(self, idx):
153142 output = []
154143 for batch_id , index_batch in enumerate (index_dataloader ):
155144 start_batch_idx , end_batch_idx = index_batch [0 ], index_batch [- 1 ] + 1
156-
157- # This applies to all batches.
158- batched_data = _pad_with_data (inputs = inputs ,
159- offset = offset ,
160- start_batch_idx = start_batch_idx ,
161- end_batch_idx = end_batch_idx )
145+ batched_data = _get_batch (inputs = inputs ,
146+ offset = offset ,
147+ start_batch_idx = start_batch_idx ,
148+ end_batch_idx = end_batch_idx )
162149
163150 if pad_before_transform :
164- if start_batch_idx == 0 : # First batch
165- batched_data = F .pad (batched_data .T , (offset .left , 0 ),
166- 'replicate' ).T
167-
168- elif end_batch_idx == len (inputs ): # Last batch
169- batched_data = F .pad (batched_data .T , (0 , offset .right - 1 ),
170- 'replicate' ).T
151+ batched_data = _add_zero_padding (batched_data = batched_data ,
152+ offset = offset ,
153+ start_batch_idx = start_batch_idx ,
154+ end_batch_idx = end_batch_idx ,
155+ number_of_samples = len (inputs ))
171156
172157 output_batch = _inference_transform (model , batched_data )
173158 output .append (output_batch )
@@ -503,10 +488,11 @@ def transform(self,
503488 model , offset = self ._select_model (inputs , session_id )
504489 model .eval ()
505490
506- if len (offset ) < 2 and pad_before_transform :
507- raise ValueError (
508- "Padding does not make sense when the offset of the model is < 2"
509- )
491+ # TODO: decide whether the check below (currently disabled) should be restored.
492+ #if len(offset) < 2 and pad_before_transform:
493+ # raise ValueError(
494+ # "Padding does not make sense when the offset of the model is < 2"
495+ # )
510496
511497 if batch_size is not None :
512498 output = _batched_transform (
0 commit comments