1616
1717def buildMat (tFac ):
1818 """ Build the matrix in CMTF from the factors. """
19+ if hasattr (tFac , 'mWeights' ):
20+ return tFac .factors [0 ] @ (tFac .mFactor * tFac .mWeights ).T
1921 return tFac .factors [0 ] @ tFac .mFactor .T
2022
2123
@@ -78,7 +80,7 @@ def sort_factors(tFac):
7880
7981 # Add the variance of the matrix
8082 if hasattr (tFac , 'mFactor' ):
81- norm += np .sum (np .square (tFac .factors [0 ]), axis = 0 ) * np .sum (np .square (tFac .mFactor ), axis = 0 )
83+ norm += np .sum (np .square (tFac .factors [0 ]), axis = 0 ) * np .sum (np .square (tFac .mFactor ), axis = 0 ) * tFac . mWeights
8284
8385 order = np .flip (np .argsort (norm ))
8486 tensor .weights = tensor .weights [order ]
@@ -87,6 +89,7 @@ def sort_factors(tFac):
8789
8890 if hasattr (tFac , 'mFactor' ):
8991 tensor .mFactor = tensor .mFactor [:, order ]
92+ tensor .mWeights = tensor .mWeights [order ]
9093 np .testing .assert_allclose (buildMat (tFac ), buildMat (tensor ), atol = 1e-9 )
9194
9295 return tensor
@@ -106,6 +109,7 @@ def delete_component(tFac, compNum):
106109
107110 if hasattr (tFac , 'mFactor' ):
108111 tensor .mFactor = np .delete (tensor .mFactor , compNum , axis = 1 )
112+ tensor .mWeights = np .delete (tensor .mWeights , compNum )
109113
110114 tensor .factors = [np .delete (fac , compNum , axis = 1 ) for fac in tensor .factors ]
111115 return tensor
@@ -142,7 +146,9 @@ def cp_normalize(tFac):
142146 scales = np .linalg .norm (factor , ord = np .inf , axis = 0 )
143147 tFac .weights *= scales
144148 if i == 0 and hasattr (tFac , 'mFactor' ):
145- tFac .mFactor *= scales
149+ mScales = np .linalg .norm (tFac .mFactor , ord = np .inf , axis = 0 )
150+ tFac .mWeights = scales * mScales
151+ tFac .mFactor /= mScales
146152
147153 tFac .factors [i ] /= scales
148154
@@ -245,16 +251,16 @@ def perform_CP(tOrig, r=6, tol=1e-6):
245251 return tFac
246252
247253
248- def perform_CMTF (tOrig , mOrig = None , r = 9 , tol = 1e-6 , maxiter = 50 ):
254+ def perform_CMTF (tOrig , mOrig , r = 9 , tol = 1e-6 , maxiter = 50 ):
249255 """ Perform CMTF decomposition. """
250256 assert tOrig .dtype == float
251- if mOrig is not None :
252- assert mOrig .dtype == float
257+ assert mOrig .dtype == float
253258 tFac = initialize_cmtf (tOrig , mOrig , r )
254259
255260 # Pre-unfold
256261 unfolded = np .hstack ((tl .unfold (tOrig , 0 ), mOrig ))
257262 missingM = np .all (np .isfinite (mOrig ), axis = 1 )
263+ assert np .sum (missingM ) >= 1 , "mOrig must contain at least one complete row"
258264 R2X = - np .inf
259265
260266 # Precalculate the missingness patterns
@@ -279,6 +285,7 @@ def perform_CMTF(tOrig, mOrig=None, r=9, tol=1e-6, maxiter=50):
279285 if R2X - R2X_last < tol :
280286 break
281287
288+ assert not np .all (tFac .mFactor == 0.0 )
282289 tFac = cp_normalize (tFac )
283290 tFac = reorient_factors (tFac )
284291 tFac = sort_factors (tFac )
0 commit comments