Skip to content

Commit 57b8faa

Browse files
author
Aaron Meyer
committed
Refactor jointdiag; fix tl.copy and tests
- Rename X to matrices_tensor and update internals - Correct tl.copy to use the input tensor - Improve docstring and add References section - Update tests to import from jointdiag module
1 parent 731a373 commit 57b8faa

File tree

2 files changed

+44
-29
lines changed

2 files changed

+44
-29
lines changed

tensorly/utils/jointdiag.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import combinations
2+
23
import tensorly as tl
34

45
# Authors: Aaron Meyer <[email protected]>
@@ -8,18 +9,14 @@
89

910

1011
def joint_matrix_diagonalization(
11-
X,
12+
matrices_tensor,
1213
max_n_iter: int = 50,
1314
threshold: float = 1e-10,
1415
verbose: bool = False,
1516
):
1617
"""
1718
Jointly diagonalizes n matrices, organized in tensor of dimension (k,k,n).
18-
Returns the diagonalized matrices, along with the transformation matrix.
19-
20-
T. Fu and X. Gao, “Simultaneous diagonalization with similarity transformation for
21-
non-defective matrices”, in Proc. IEEE International Conference on Acoustics, Speech
22-
and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
19+
Returns the diagonalized matrices, along with the transformation matrix [1]_ .
2320
2421
Args:
2522
X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
@@ -33,15 +30,24 @@ def joint_matrix_diagonalization(
3330
Returns:
3431
Tensor: X after joint diagonalization.
3532
Tensor: The transformation matrix resulting in the diagonalization.
33+
34+
References
35+
----------
36+
.. [1] T. Fu and X. Gao, “Simultaneous diagonalization with similarity transformation for
37+
non-defective matrices”, in Proc. IEEE International Conference on Acoustics, Speech
38+
and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
3639
"""
37-
X = tl.copy()
38-
matrix_dimension = tl.shape(X)[0] # Dimension of square matrix slices
39-
assert tl.ndim(X) == 3, "Input must be a 3D tensor"
40-
assert matrix_dimension == X.shape[1], "All matrices must be square."
40+
matrices_tensor = tl.copy(matrices_tensor)
41+
matrix_dimension = tl.shape(matrices_tensor)[0] # Dimension of square matrix slices
42+
assert tl.ndim(matrices_tensor) == 3, "Input must be a 3D tensor"
43+
assert matrix_dimension == matrices_tensor.shape[1], "All matrices must be square."
4144

4245
# Initial error calculation
4346
# Transpose is because np.tril operates on the last two dimensions
44-
e = tl.norm(X) ** 2.0 - tl.norm(tl.diagonal(X, axis1=1, axis2=2)) ** 2.0
47+
e = (
48+
tl.norm(matrices_tensor) ** 2.0
49+
- tl.norm(tl.diagonal(matrices_tensor, axis1=1, axis2=2)) ** 2.0
50+
)
4551

4652
if verbose:
4753
print(f"Sweep # 0: e = {e:.3e}")
@@ -52,36 +58,40 @@ def joint_matrix_diagonalization(
5258
for k in range(max_n_iter):
5359
# loop over all pairs of slices
5460
for p, q in combinations(range(matrix_dimension), 2):
55-
# Finds matrix slice with greatest variability among diagonal elements
56-
d_ = X[p, p, :] - X[q, q, :]
61+
# Comparing the p and q chords across matrices, identifies the
62+
# position h with the largest difference
63+
d_ = matrices_tensor[p, p, :] - matrices_tensor[q, q, :]
5764
h = tl.argmax(tl.abs(d_))
5865

59-
# List of indices
66+
# List of non-selected indices
6067
all_but_pq = list(set(range(matrix_dimension)) - set([p, q]))
6168

6269
# Compute certain quantities
6370
dh = d_[h]
64-
Xh = X[:, :, h]
65-
Kh = tl.dot(Xh[p, all_but_pq], Xh[q, all_but_pq]) - tl.dot(
66-
Xh[all_but_pq, p], Xh[all_but_pq, q]
71+
matrix_h = matrices_tensor[:, :, h]
72+
Kh = tl.dot(matrix_h[p, all_but_pq], matrix_h[q, all_but_pq]) - tl.dot(
73+
matrix_h[all_but_pq, p], matrix_h[all_but_pq, q]
6774
)
6875
Gh = (
69-
tl.norm(Xh[p, all_but_pq]) ** 2
70-
+ tl.norm(Xh[q, all_but_pq]) ** 2
71-
+ tl.norm(Xh[all_but_pq, p]) ** 2
72-
+ tl.norm(Xh[all_but_pq, q]) ** 2
76+
tl.norm(matrix_h[p, all_but_pq]) ** 2
77+
+ tl.norm(matrix_h[q, all_but_pq]) ** 2
78+
+ tl.norm(matrix_h[all_but_pq, p]) ** 2
79+
+ tl.norm(matrix_h[all_but_pq, q]) ** 2
7380
)
74-
xih = Xh[p, q] - Xh[q, p]
81+
matrix_h_pq_diff = matrix_h[p, q] - matrix_h[q, p]
7582

7683
# Build shearing matrix out of these quantities
77-
yk = tl.arctanh((Kh - xih * dh) / (2 * (dh**2 + xih**2) + Gh))
84+
yk = tl.arctanh(
85+
(Kh - matrix_h_pq_diff * dh) / (2 * (dh**2 + matrix_h_pq_diff**2) + Gh)
86+
)
7887

7988
# Inverse of Sk on left side
80-
pvec = tl.copy(X[p, :, :])
89+
pvec = tl.copy(matrices_tensor[p, :, :])
8190
X = tl.index_update(
82-
X,
91+
matrices_tensor,
8392
tl.index[p, :, :],
84-
X[p, :, :] * tl.cosh(yk) - X[q, :, :] * tl.sinh(yk),
93+
matrices_tensor[p, :, :] * tl.cosh(yk)
94+
- matrices_tensor[q, :, :] * tl.sinh(yk),
8595
)
8696
X = tl.index_update(
8797
X, tl.index[q, :, :], -pvec * tl.sinh(yk) + X[q, :, :] * tl.cosh(yk)
@@ -170,7 +180,10 @@ def joint_matrix_diagonalization(
170180

171181
# Error computation, check if loop needed...
172182
old_e = e
173-
e = tl.norm(X) ** 2.0 - tl.norm(tl.diagonal(X, axis1=1, axis2=2)) ** 2.0
183+
e = (
184+
tl.norm(matrices_tensor) ** 2.0
185+
- tl.norm(tl.diagonal(matrices_tensor, axis1=1, axis2=2)) ** 2.0
186+
)
174187

175188
if verbose:
176189
print(f"Sweep # {k + 1}: e = {e:.3e}")
@@ -179,4 +192,4 @@ def joint_matrix_diagonalization(
179192
if old_e - e < threshold and k > 2:
180193
break
181194

182-
return X, Q_total
195+
return matrices_tensor, Q_total

tensorly/utils/tests/test_jointdiag.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
2+
23
import tensorly as tl
3-
from ..joint_matrix_diagonalization import joint_matrix_diagonalization
4+
5+
from ..jointdiag import joint_matrix_diagonalization
46

57

68
def test_joint_matrix_diagonalization():

0 commit comments

Comments
 (0)