Skip to content

Commit bc8ee25

Browse files
gonlairostes
authored and committed
differentiate between data padding and zero padding
1 parent 1aadc8b commit bc8ee25

File tree

2 files changed

+229
-253
lines changed

2 files changed

+229
-253
lines changed

cebra/solver/base.py

Lines changed: 42 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -66,56 +66,32 @@ def _inference_transform(model, inputs):
6666
return output
6767

6868

69-
def _pad_with_data(inputs: torch.Tensor, offset: cebra.data.Offset,
70-
start_batch_idx: int, end_batch_idx: int) -> torch.Tensor:
71-
"""
72-
Pads a batch of input data with its own data (maybe this is not called padding)
73-
74-
Args:
75-
inputs: The input data to be processed.
76-
add_padding: Indicates whether padding should be applied before inference.
77-
offset: Offset configuration for padding. If add_padding is True,
78-
offset must be set. If add_padding is False, offset is not used and can be None.
79-
start_batch_idx: The starting index of the current batch.
80-
end_batch_idx: The last index of the current batch.
81-
82-
Returns:
83-
torch.Tensor: The (potentially) padded data.
84-
85-
Raises:
86-
ValueError: If add_padding is True and offset is not provided.
87-
"""
88-
89-
def _check_indices(indices, inputs):
90-
if (indices[0] < 0) or (indices[1] > inputs.shape[0]):
91-
raise ValueError(
92-
f"offset {offset} is too big for the length of the inputs ({len(inputs)}) "
93-
f"The indices {indices} do not match the inputs length {len(inputs)}."
94-
)
69+
def _check_indices(start_batch_idx, end_batch_idx, offset, num_samples):
9570

9671
if start_batch_idx < 0 or end_batch_idx < 0:
9772
raise ValueError(
9873
f"start_batch_idx ({start_batch_idx}) and end_batch_idx ({end_batch_idx}) must be non-negative."
9974
)
100-
10175
if start_batch_idx > end_batch_idx:
10276
raise ValueError(
10377
f"start_batch_idx ({start_batch_idx}) cannot be greater than end_batch_idx ({end_batch_idx})."
10478
)
79+
if end_batch_idx > num_samples:
80+
raise ValueError(
81+
f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({num_samples})."
82+
)
10583

106-
if end_batch_idx > len(inputs):
84+
batch_size_lenght = end_batch_idx - start_batch_idx
85+
if batch_size_lenght <= len(offset):
10786
raise ValueError(
108-
f"end_batch_idx ({end_batch_idx}) cannot exceed the length of inputs ({len(inputs)})."
87+
f"The batch has length {batch_size_lenght} which "
88+
f"is smaller or equal than the required offset length {len(offset)}."
89+
f"Either choose a model with smaller offset or the batch shoud contain more samples."
10990
)
11091

111-
def _check_batch_size_length(indices_batch, offset):
112-
batch_size_lenght = indices_batch[1] - indices_batch[0]
113-
if batch_size_lenght <= len(offset):
114-
raise ValueError(
115-
f"The batch has length {batch_size_lenght} which "
116-
f"is smaller or equal than the required offset length {len(offset)}."
117-
f"Either choose a model with smaller offset or the batch shoud contain more samples."
118-
)
92+
93+
def _get_batch(inputs: torch.Tensor, offset: cebra.data.Offset,
94+
start_batch_idx: int, end_batch_idx: int) -> torch.Tensor:
11995

12096
if start_batch_idx == 0: # First batch
12197
indices = start_batch_idx, (end_batch_idx + offset.right - 1)
@@ -126,12 +102,25 @@ def _check_batch_size_length(indices_batch, offset):
126102
else: # Middle batches
127103
indices = start_batch_idx - offset.left, end_batch_idx + offset.right - 1
128104

129-
#_check_batch_size_length(indices, offset)
130-
#TODO: modify this check_batch_size to pass test.
105+
_check_indices(indices[0], indices[1], offset, len(inputs))
131106
batched_data = inputs[slice(*indices)]
132107
return batched_data
133108

134109

110+
def _add_zero_padding(batched_data: torch.Tensor, offset: cebra.data.Offset,
111+
start_batch_idx: int, end_batch_idx: int,
112+
number_of_samples: int):
113+
114+
if start_batch_idx == 0: # First batch
115+
batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T
116+
117+
elif end_batch_idx == number_of_samples: # Last batch
118+
batched_data = F.pad(batched_data.T, (0, offset.right - 1),
119+
'replicate').T
120+
121+
return batched_data
122+
123+
135124
def _batched_transform(model, inputs: torch.Tensor, batch_size: int,
136125
pad_before_transform: bool,
137126
offset: cebra.data.Offset) -> torch.Tensor:
@@ -153,21 +142,17 @@ def __getitem__(self, idx):
153142
output = []
154143
for batch_id, index_batch in enumerate(index_dataloader):
155144
start_batch_idx, end_batch_idx = index_batch[0], index_batch[-1] + 1
156-
157-
# This applies to all batches.
158-
batched_data = _pad_with_data(inputs=inputs,
159-
offset=offset,
160-
start_batch_idx=start_batch_idx,
161-
end_batch_idx=end_batch_idx)
145+
batched_data = _get_batch(inputs=inputs,
146+
offset=offset,
147+
start_batch_idx=start_batch_idx,
148+
end_batch_idx=end_batch_idx)
162149

163150
if pad_before_transform:
164-
if start_batch_idx == 0: # First batch
165-
batched_data = F.pad(batched_data.T, (offset.left, 0),
166-
'replicate').T
167-
168-
elif end_batch_idx == len(inputs): # Last batch
169-
batched_data = F.pad(batched_data.T, (0, offset.right - 1),
170-
'replicate').T
151+
batched_data = _add_zero_padding(batched_data=batched_data,
152+
offset=offset,
153+
start_batch_idx=start_batch_idx,
154+
end_batch_idx=end_batch_idx,
155+
number_of_samples=len(inputs))
171156

172157
output_batch = _inference_transform(model, batched_data)
173158
output.append(output_batch)
@@ -503,10 +488,11 @@ def transform(self,
503488
model, offset = self._select_model(inputs, session_id)
504489
model.eval()
505490

506-
if len(offset) < 2 and pad_before_transform:
507-
raise ValueError(
508-
"Padding does not make sense when the offset of the model is < 2"
509-
)
491+
#TODO: should we add this error?
492+
#if len(offset) < 2 and pad_before_transform:
493+
# raise ValueError(
494+
# "Padding does not make sense when the offset of the model is < 2"
495+
# )
510496

511497
if batch_size is not None:
512498
output = _batched_transform(

0 commit comments

Comments
 (0)