Skip to content

Commit 2896e47

Browse files
committed
Resolve some comments
1 parent f3809d5 commit 2896e47

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

tensorly/utils/jointdiag.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def joint_matrix_diagonalization(
2222
and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
2323
2424
Args:
25-
X (_type_): n matrices, organized in tensor of dimension (k,k,n), for joint diagonalization.
25+
X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
2626
max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
2727
threshold (float, optional): Threshold for decrease in error indicating convergence. Defaults to 1e-10.
2828
verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
@@ -34,10 +34,10 @@ def joint_matrix_diagonalization(
3434
Tensor: X after joint diagonalization.
3535
Tensor: The transformation matrix resulting in the diagonalization.
3636
"""
37-
X = tl.tensor(X, **tl.context(X))
38-
D = tl.shape(X)[0] # Dimension of square matrix slices
37+
X = tl.copy()
38+
matrix_dimension = tl.shape(X)[0] # Dimension of square matrix slices
3939
assert tl.ndim(X) == 3, "Input must be a 3D tensor"
40-
assert D == X.shape[1], "All slices must be square"
40+
assert matrix_dimension == X.shape[1], "All matrices must be square."
4141

4242
# Initial error calculation
4343
# Transpose is because np.tril operates on the last two dimensions
@@ -47,17 +47,17 @@ def joint_matrix_diagonalization(
4747
print(f"Sweep # 0: e = {e:.3e}")
4848

4949
# Additional output parameters
50-
Q_total = tl.eye(D)
50+
Q_total = tl.eye(matrix_dimension)
5151

5252
for k in range(max_n_iter):
5353
# loop over all pairs of slices
54-
for p, q in combinations(range(D), 2):
54+
for p, q in combinations(range(matrix_dimension), 2):
5555
# Finds matrix slice with greatest variability among diagonal elements
5656
d_ = X[p, p, :] - X[q, q, :]
5757
h = tl.argmax(tl.abs(d_))
5858

5959
# List of indices
60-
all_but_pq = list(set(range(D)) - set([p, q]))
60+
all_but_pq = list(set(range(matrix_dimension)) - set([p, q]))
6161

6262
# Compute certain quantities
6363
dh = d_[h]

0 commit comments

Comments
 (0)