77import tensorly as tl
88from tensorly .tenalg import khatri_rao
99from copy import deepcopy
10- from tensorly .decomposition ._cp import initialize_cp , parafac
10+ from tensorly .decomposition ._cp import initialize_cp
1111from tqdm import tqdm
1212from .SVD_impute import IterativeSVD
1313
@@ -116,7 +116,7 @@ def delete_component(tFac, compNum):
116116 return tensor
117117
118118
119- def censored_lstsq (A : np .ndarray , B : np .ndarray , uniqueInfo ) -> np .ndarray :
119+ def censored_lstsq (A : np .ndarray , B : np .ndarray , uniqueInfo = None ) -> np .ndarray :
120120 """Solves least squares problem subject to missing data.
121121 Note: uses a for loop over the missing patterns of B, leading to a
122122 slower but more numerically stable algorithm
@@ -130,7 +130,10 @@ def censored_lstsq(A: np.ndarray, B: np.ndarray, uniqueInfo) -> np.ndarray:
130130 """
131131 X = np .empty ((A .shape [1 ], B .shape [1 ]))
132132 # Missingness patterns
133- unique , uIDX = uniqueInfo
133+ if uniqueInfo is None :
134+ unique , uIDX = np .unique (np .isfinite (B ), axis = 1 , return_inverse = True )
135+ else :
136+ unique , uIDX = uniqueInfo
134137
135138 for i in range (unique .shape [1 ]):
136139 uI = uIDX == i
@@ -268,8 +271,9 @@ def perform_CMTF(tOrig, mOrig, r=9, tol=1e-6, maxiter=50, progress=True):
268271
269272 tq = tqdm (range (maxiter ), disable = (not progress ))
270273 for _ in tq :
271- tensor = np .nan_to_num (tOrig ) + tl .cp_to_tensor (tFac ) * np .isnan (tOrig )
272- tFac = parafac (tensor , r , 2000 , init = tFac , verbose = False , fixed_modes = [0 ], mask = np .isfinite (tOrig ), linesearch = True , tol = 1e-9 )
274+ for m in [1 , 2 ]:
275+ kr = khatri_rao (tFac .factors , skip_matrix = m )
276+ tFac .factors [m ] = censored_lstsq (kr , tl .unfold (tOrig , m ).T )
273277
274278 # Solve for the glycan matrix fit
275279 tFac .mFactor = np .linalg .lstsq (tFac .factors [0 ][missingM , :], mOrig [missingM , :], rcond = - 1 )[0 ].T
0 commit comments