@@ -22,7 +22,7 @@ def joint_matrix_diagonalization(
2222 and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
2323
2424 Args:
25- X (_type_): n matrices, organized in tensor of dimension (k,k,n), for joint diagonalization .
25+ X (_type_): n matrices, organized in a single tensor of dimension (k, k, n) .
2626 max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
2727 threshold (float, optional): Threshold for decrease in error indicating convergence. Defaults to 1e-10.
2828 verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
@@ -34,10 +34,10 @@ def joint_matrix_diagonalization(
3434 Tensor: X after joint diagonalization.
3535 Tensor: The transformation matrix resulting in the diagonalization.
3636 """
37- X = tl .tensor ( X , ** tl . context ( X ) )
38- D = tl .shape (X )[0 ] # Dimension of square matrix slices
37+ X = tl .copy ( )
38+ matrix_dimension = tl .shape (X )[0 ] # Dimension of square matrix slices
3939 assert tl .ndim (X ) == 3 , "Input must be a 3D tensor"
40- assert D == X .shape [1 ], "All slices must be square"
40+ assert matrix_dimension == X .shape [1 ], "All matrices must be square. "
4141
4242 # Initial error calculation
4343 # Transpose is because np.tril operates on the last two dimensions
@@ -47,17 +47,17 @@ def joint_matrix_diagonalization(
4747 print (f"Sweep # 0: e = { e :.3e} " )
4848
4949 # Additional output parameters
50- Q_total = tl .eye (D )
50+ Q_total = tl .eye (matrix_dimension )
5151
5252 for k in range (max_n_iter ):
5353 # loop over all pairs of slices
54- for p , q in combinations (range (D ), 2 ):
54+ for p , q in combinations (range (matrix_dimension ), 2 ):
5555 # Finds matrix slice with greatest variability among diagonal elements
5656 d_ = X [p , p , :] - X [q , q , :]
5757 h = tl .argmax (tl .abs (d_ ))
5858
5959 # List of indices
60- all_but_pq = list (set (range (D )) - set ([p , q ]))
60+ all_but_pq = list (set (range (matrix_dimension )) - set ([p , q ]))
6161
6262 # Compute certain quantities
6363 dh = d_ [h ]
0 commit comments