Skip to content

Commit 8c8be85

Browse files
gonlairostes
authored andcommitted
add torch padding to _transform
1 parent f0303e0 commit 8c8be85

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

cebra/solver/base.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _inference_transform(model, inputs):
5656
#TODO: I am not sure what is the best way with dealing with the types and
5757
# device when using batched inference. This works for now.
5858
inputs = inputs.type(torch.FloatTensor).to(next(model.parameters()).device)
59+
5960
if isinstance(model, cebra.models.ConvolutionalModelMixin):
6061
# Fully convolutional evaluation, switch (T, C) -> (1, C, T)
6162
inputs = inputs.transpose(1, 0).unsqueeze(0)
@@ -110,8 +111,6 @@ def _check_indices(indices, inputs):
110111

111112
def _check_batch_size_length(indices_batch, offset):
112113
batch_size_lenght = indices_batch[1] - indices_batch[0]
113-
print("batch_size ll", add_padding, indices, batch_size_lenght,
114-
len(offset))
115114
if batch_size_lenght <= len(offset):
116115
raise ValueError(
117116
f"The batch has length {batch_size_lenght} which "
@@ -489,10 +488,8 @@ def _transform(self, model, inputs, offset,
489488
pad_before_transform) -> torch.Tensor:
490489

491490
if pad_before_transform:
492-
inputs = np.pad(inputs, ((offset.left, offset.right - 1), (0, 0)),
493-
mode="edge")
494-
inputs = torch.from_numpy(inputs)
495-
491+
inputs = F.pad(inputs.T, (offset.left, offset.right - 1),
492+
'replicate').T
496493
output = _inference_transform(model, inputs)
497494
return output
498495

0 commit comments

Comments
 (0)