11from itertools import combinations
2+
23import tensorly as tl
34
45# Authors: Aaron Meyer <[email protected] > 89
910
1011def joint_matrix_diagonalization (
11- X ,
12+ matrices_tensor ,
1213 max_n_iter : int = 50 ,
1314 threshold : float = 1e-10 ,
1415 verbose : bool = False ,
1516):
1617 """
1718 Jointly diagonalizes n matrices, organized in tensor of dimension (k,k,n).
18- Returns the diagonalized matrices, along with the transformation matrix.
19-
20- T. Fu and X. Gao, “Simultaneous diagonalization with similarity transformation for
21- non-defective matrices”, in Proc. IEEE International Conference on Acoustics, Speech
22- and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
19+ Returns the diagonalized matrices, along with the transformation matrix [1]_ .
2320
2421 Args:
2522 X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
@@ -33,15 +30,24 @@ def joint_matrix_diagonalization(
3330 Returns:
3431 Tensor: X after joint diagonalization.
3532 Tensor: The transformation matrix resulting in the diagonalization.
33+
34+ References
35+ ----------
36+ .. [1] T. Fu and X. Gao, “Simultaneous diagonalization with similarity transformation for
37+ non-defective matrices”, in Proc. IEEE International Conference on Acoustics, Speech
38+ and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
3639 """
37- X = tl .copy ()
38- matrix_dimension = tl .shape (X )[0 ] # Dimension of square matrix slices
39- assert tl .ndim (X ) == 3 , "Input must be a 3D tensor"
40- assert matrix_dimension == X .shape [1 ], "All matrices must be square."
40+ matrices_tensor = tl .copy (matrices_tensor )
41+ matrix_dimension = tl .shape (matrices_tensor )[0 ] # Dimension of square matrix slices
42+ assert tl .ndim (matrices_tensor ) == 3 , "Input must be a 3D tensor"
43+ assert matrix_dimension == matrices_tensor .shape [1 ], "All matrices must be square."
4144
4245 # Initial error calculation
4346 # Transpose is because np.tril operates on the last two dimensions
44- e = tl .norm (X ) ** 2.0 - tl .norm (tl .diagonal (X , axis1 = 1 , axis2 = 2 )) ** 2.0
47+ e = (
48+ tl .norm (matrices_tensor ) ** 2.0
49+ - tl .norm (tl .diagonal (matrices_tensor , axis1 = 1 , axis2 = 2 )) ** 2.0
50+ )
4551
4652 if verbose :
4753 print (f"Sweep # 0: e = { e :.3e} " )
@@ -52,36 +58,40 @@ def joint_matrix_diagonalization(
5258 for k in range (max_n_iter ):
5359 # loop over all pairs of slices
5460 for p , q in combinations (range (matrix_dimension ), 2 ):
55- # Finds matrix slice with greatest variability among diagonal elements
56- d_ = X [p , p , :] - X [q , q , :]
61+ # Comparing the p and q chords across matrices, identifies the
62+ # position h with the largest difference
63+ d_ = matrices_tensor [p , p , :] - matrices_tensor [q , q , :]
5764 h = tl .argmax (tl .abs (d_ ))
5865
59- # List of indices
66+ # List of non-selected indices
6067 all_but_pq = list (set (range (matrix_dimension )) - set ([p , q ]))
6168
6269 # Compute certain quantities
6370 dh = d_ [h ]
64- Xh = X [:, :, h ]
65- Kh = tl .dot (Xh [p , all_but_pq ], Xh [q , all_but_pq ]) - tl .dot (
66- Xh [all_but_pq , p ], Xh [all_but_pq , q ]
71+ matrix_h = matrices_tensor [:, :, h ]
72+ Kh = tl .dot (matrix_h [p , all_but_pq ], matrix_h [q , all_but_pq ]) - tl .dot (
73+ matrix_h [all_but_pq , p ], matrix_h [all_but_pq , q ]
6774 )
6875 Gh = (
69- tl .norm (Xh [p , all_but_pq ]) ** 2
70- + tl .norm (Xh [q , all_but_pq ]) ** 2
71- + tl .norm (Xh [all_but_pq , p ]) ** 2
72- + tl .norm (Xh [all_but_pq , q ]) ** 2
76+ tl .norm (matrix_h [p , all_but_pq ]) ** 2
77+ + tl .norm (matrix_h [q , all_but_pq ]) ** 2
78+ + tl .norm (matrix_h [all_but_pq , p ]) ** 2
79+ + tl .norm (matrix_h [all_but_pq , q ]) ** 2
7380 )
74- xih = Xh [p , q ] - Xh [q , p ]
81+ matrix_h_pq_diff = matrix_h [p , q ] - matrix_h [q , p ]
7582
7683 # Build shearing matrix out of these quantities
77- yk = tl .arctanh ((Kh - xih * dh ) / (2 * (dh ** 2 + xih ** 2 ) + Gh ))
84+ yk = tl .arctanh (
85+ (Kh - matrix_h_pq_diff * dh ) / (2 * (dh ** 2 + matrix_h_pq_diff ** 2 ) + Gh )
86+ )
7887
7988 # Inverse of Sk on left side
80- pvec = tl .copy (X [p , :, :])
89+ pvec = tl .copy (matrices_tensor [p , :, :])
8190 X = tl .index_update (
82- X ,
91+ matrices_tensor ,
8392 tl .index [p , :, :],
84- X [p , :, :] * tl .cosh (yk ) - X [q , :, :] * tl .sinh (yk ),
93+ matrices_tensor [p , :, :] * tl .cosh (yk )
94+ - matrices_tensor [q , :, :] * tl .sinh (yk ),
8595 )
8696 X = tl .index_update (
8797 X , tl .index [q , :, :], - pvec * tl .sinh (yk ) + X [q , :, :] * tl .cosh (yk )
@@ -170,7 +180,10 @@ def joint_matrix_diagonalization(
170180
171181 # Error computation, check if loop needed...
172182 old_e = e
173- e = tl .norm (X ) ** 2.0 - tl .norm (tl .diagonal (X , axis1 = 1 , axis2 = 2 )) ** 2.0
183+ e = (
184+ tl .norm (matrices_tensor ) ** 2.0
185+ - tl .norm (tl .diagonal (matrices_tensor , axis1 = 1 , axis2 = 2 )) ** 2.0
186+ )
174187
175188 if verbose :
176189 print (f"Sweep # { k + 1 } : e = { e :.3e} " )
@@ -179,4 +192,4 @@ def joint_matrix_diagonalization(
179192 if old_e - e < threshold and k > 2 :
180193 break
181194
182- return X , Q_total
195+ return matrices_tensor , Q_total
0 commit comments