Skip to content

Commit 9fe3af3

Browse files
CeliaBenquetstes
authored andcommitted
Fix warning
1 parent 0823b54 commit 9fe3af3

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

cebra/solver/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,18 @@ def _add_zero_padding(batched_data: torch.Tensor, offset: cebra.data.Offset,
111111
start_batch_idx: int, end_batch_idx: int,
112112
number_of_samples: int):
113113

114+
reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1)
115+
114116
if start_batch_idx == 0: # First batch
115-
batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T
117+
batched_data = F.pad(batched_data.permute(*reversed_dims),
118+
(offset.left, 0), 'replicate').permute(*reversed_dims)
119+
#batched_data = F.pad(batched_data.T, (offset.left, 0), 'replicate').T
116120

117121
elif end_batch_idx == number_of_samples: # Last batch
118-
batched_data = F.pad(batched_data.T, (0, offset.right - 1),
119-
'replicate').T
122+
batched_data = F.pad(batched_data.permute(*reversed_dims),
123+
(0, offset.right - 1), 'replicate').permute(*reversed_dims)
124+
#batched_data = F.pad(batched_data.T, (0, offset.right - 1), 'replicate').T
125+
120126

121127
return batched_data
122128

0 commit comments

Comments
 (0)