Skip to content

Commit 7009e92

Browse files
authored
Merge pull request tensorly#589 from JeanKossaifi/main
FIX partial_unfold + add test
2 parents 2f1aabb + 97ae008 commit 7009e92

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tensorly/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def partial_unfold(tensor, mode=0, skip_begin=1, skip_end=0, ravel_tensors=False
113113
new_shape = [tensor.shape[i] for i in range(skip_begin)] + new_shape
114114

115115
if skip_end:
116-
new_shape += [tensor.shape[-i] for i in range(1, 1 + skip_end)]
116+
new_shape += [tensor.shape[-i] for i in range(skip_end, 0, -1)]
117117

118118
return tl.reshape(tl.moveaxis(tensor, mode + skip_begin, skip_begin), new_shape)
119119

tensorly/tests/test_core.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from .. import backend as T
3+
import tensorly as T
44
from ..base import fold, unfold
55
from ..base import partial_fold, partial_unfold
66
from ..base import tensor_to_vec, vec_to_tensor
@@ -320,6 +320,12 @@ def test_partial_tensor_to_vec():
320320
for j in range(n_samples): # test for each sample
321321
assert_array_equal(T.transpose(vectorised)[j], vec_X + j)
322322

323+
tensor = T.randn((2, 3, 4, 5))
324+
TT = partial_tensor_to_vec(tensor, skip_begin=0, skip_end=2)
325+
assert T.shape(TT) == (6, 4, 5)
326+
rec = partial_vec_to_tensor(TT, T.shape(tensor), skip_begin=0, skip_end=2)
327+
assert T.shape(rec) == T.shape(tensor)
328+
323329

324330
def test_partial_vec_to_tensor():
325331
"""Test for partial_vec_to_tensor"""

0 commit comments

Comments
 (0)