Skip to content

Commit 36de3f1

Browse files
author
Mithrillion
committed
updated formatting of changed sections
1 parent e04cf28 commit 36de3f1

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

tensorly/tests/test_tucker_tensor.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,23 @@ def test_tucker_to_tensor_with_partial_modes():
108108
for (R, s) in zip(ranks, tl.shape(X)[1:])
109109
]
110110
true_res = np.array(
111-
[[[ 120., 456., 792., 1128.],
112-
[ 400., 1472., 2544., 3616.],
113-
[ 680., 2488., 4296., 6104.]],
114-
115-
[[ 126., 486., 846., 1206.],
116-
[ 422., 1582., 2742., 3902.],
117-
[ 718., 2678., 4638., 6598.]],
118-
119-
[[ 132., 516., 900., 1284.],
120-
[ 444., 1692., 2940., 4188.],
121-
[ 756., 2868., 4980., 7092.]]]
111+
[
112+
[
113+
[120.0, 456.0, 792.0, 1128.0],
114+
[400.0, 1472.0, 2544.0, 3616.0],
115+
[680.0, 2488.0, 4296.0, 6104.0],
116+
],
117+
[
118+
[126.0, 486.0, 846.0, 1206.0],
119+
[422.0, 1582.0, 2742.0, 3902.0],
120+
[718.0, 2678.0, 4638.0, 6598.0],
121+
],
122+
[
123+
[132.0, 516.0, 900.0, 1284.0],
124+
[444.0, 1692.0, 2940.0, 4188.0],
125+
[756.0, 2868.0, 4980.0, 7092.0],
126+
],
127+
]
122128
)
123129
res = tucker_to_tensor((X, U), modes=[1, 2])
124130
assert_array_equal(true_res, res)

tensorly/tucker_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ def _validate_tucker_tensor(tucker_tensor):
4747
return tuple(shape), tuple(rank)
4848

4949

50-
def tucker_to_tensor(tucker_tensor, skip_factor=None, transpose_factors=False, modes=None):
50+
def tucker_to_tensor(
51+
tucker_tensor, skip_factor=None, transpose_factors=False, modes=None
52+
):
5153
"""Converts the Tucker tensor into a full tensor
5254
5355
Parameters
@@ -68,7 +70,9 @@ def tucker_to_tensor(tucker_tensor, skip_factor=None, transpose_factors=False, m
6870
full tensor of shape ``(factors[0].shape[0], ..., factors[-1].shape[0])``
6971
"""
7072
core, factors = tucker_tensor
71-
return multi_mode_dot(core, factors, skip=skip_factor, transpose=transpose_factors, modes=modes)
73+
return multi_mode_dot(
74+
core, factors, skip=skip_factor, transpose=transpose_factors, modes=modes
75+
)
7276

7377

7478
def tucker_normalize(tucker_tensor):

0 commit comments

Comments
 (0)