Skip to content

Commit 800410e

Browse files
authored
Censored least squares is faster given missingness (#35)
1 parent 50b6de7 commit 800410e

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

tensorpack/cmtf.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import tensorly as tl
88
from tensorly.tenalg import khatri_rao
99
from copy import deepcopy
10-
from tensorly.decomposition._cp import initialize_cp, parafac
10+
from tensorly.decomposition._cp import initialize_cp
1111
from tqdm import tqdm
1212
from .SVD_impute import IterativeSVD
1313

@@ -116,7 +116,7 @@ def delete_component(tFac, compNum):
116116
return tensor
117117

118118

119-
def censored_lstsq(A: np.ndarray, B: np.ndarray, uniqueInfo) -> np.ndarray:
119+
def censored_lstsq(A: np.ndarray, B: np.ndarray, uniqueInfo=None) -> np.ndarray:
120120
"""Solves least squares problem subject to missing data.
121121
Note: uses a for loop over the missing patterns of B, leading to a
122122
slower but more numerically stable algorithm
@@ -130,7 +130,10 @@ def censored_lstsq(A: np.ndarray, B: np.ndarray, uniqueInfo) -> np.ndarray:
130130
"""
131131
X = np.empty((A.shape[1], B.shape[1]))
132132
# Missingness patterns
133-
unique, uIDX = uniqueInfo
133+
if uniqueInfo is None:
134+
unique, uIDX = np.unique(np.isfinite(B), axis=1, return_inverse=True)
135+
else:
136+
unique, uIDX = uniqueInfo
134137

135138
for i in range(unique.shape[1]):
136139
uI = uIDX == i
@@ -268,8 +271,9 @@ def perform_CMTF(tOrig, mOrig, r=9, tol=1e-6, maxiter=50, progress=True):
268271

269272
tq = tqdm(range(maxiter), disable=(not progress))
270273
for _ in tq:
271-
tensor = np.nan_to_num(tOrig) + tl.cp_to_tensor(tFac) * np.isnan(tOrig)
272-
tFac = parafac(tensor, r, 2000, init=tFac, verbose=False, fixed_modes=[0], mask=np.isfinite(tOrig), linesearch=True, tol=1e-9)
274+
for m in [1, 2]:
275+
kr = khatri_rao(tFac.factors, skip_matrix=m)
276+
tFac.factors[m] = censored_lstsq(kr, tl.unfold(tOrig, m).T)
273277

274278
# Solve for the glycan matrix fit
275279
tFac.mFactor = np.linalg.lstsq(tFac.factors[0][missingM, :], mOrig[missingM, :], rcond=-1)[0].T

0 commit comments

Comments
 (0)