Skip to content

Commit 5a0f27c

Browse files
committed
Missing matrix imputation with TruncatedSVD; add and edit test cases
1 parent b351770 commit 5a0f27c

File tree

3 files changed

+61
-24
lines changed

3 files changed

+61
-24
lines changed

tensorpack/decomposition.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
11
import pickle
22
import numpy as np
33
import pandas as pd
4+
from numpy.linalg import norm
45
from sklearn.decomposition import TruncatedSVD
56
from .cmtf import perform_CP, calcR2X
67

78

9+
def impute_missing_mat(dat):
10+
miss_idx = np.where(~np.isfinite(dat))
11+
if len(miss_idx[0]) <= 0:
12+
return dat
13+
assert np.all(np.any(np.isfinite(dat), axis=0)), "Cannot impute if an entire column is empty"
14+
assert np.all(np.any(np.isfinite(dat), axis=1)), "Cannot impute if an entire row is empty"
15+
16+
imp = np.copy(dat)
17+
col_mean = np.nanmean(dat, axis=0, keepdims=True)
18+
imp[miss_idx] = np.take(col_mean, miss_idx[1])
19+
20+
diff = 1.0
21+
while diff > 1e-3:
22+
tsvd = TruncatedSVD(n_components=min(dat.shape)-1)
23+
scores = tsvd.fit_transform(imp)
24+
loadings = tsvd.components_
25+
recon = scores @ loadings
26+
new_diff = norm(imp[miss_idx] - recon[miss_idx]) / norm(recon[miss_idx])
27+
assert new_diff < diff, "Matrix imputation difference is not decreasing"
28+
diff = new_diff
29+
imp[miss_idx] = recon[miss_idx]
30+
return imp
31+
32+
833
class Decomposition():
934
def __init__(self, data, max_rr=6):
1035
self.data = data
@@ -20,6 +45,8 @@ def perform_tfac(self):
2045
def perform_PCA(self, flattenon=0):
2146
dataShape = self.data.shape
2247
flatData = np.reshape(np.moveaxis(self.data, flattenon, 0), (dataShape[flattenon], -1))
48+
if not np.all(np.isfinite(flatData)):
49+
flatData = impute_missing_mat(flatData)
2350

2451
tsvd = TruncatedSVD(n_components=max(self.rrs))
2552
scores = tsvd.fit_transform(flatData)

tensorpack/test/test_decomposition.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
Testing Decomposition
33
"""
44

5+
import numpy as np
6+
import os
7+
import tensorly as tl
8+
from tensorly.random import random_cp
59
from ..decomposition import Decomposition
610
from .atyeo import createCube
7-
import os
11+
from ..cmtf import perform_CP, calcR2X
812

913
def test_decomp_obj():
1014
a = Decomposition(createCube())
@@ -21,3 +25,32 @@ def test_decomp_obj():
2125
assert len(b.PCAR2X) == len(b.sizePCA)
2226
assert len(b.TR2X) == len(b.sizeT)
2327
os.remove(fname)
28+
29+
30+
def test_missing_obj():
31+
for miss_rate in [0.1, 0.2]:
32+
dat = createCube()
33+
filter = np.random.rand(*dat.shape) > 1-miss_rate
34+
dat[filter] = np.nan
35+
a = Decomposition(dat)
36+
a.perform_tfac()
37+
a.perform_PCA()
38+
assert len(a.PCAR2X) == len(a.sizePCA)
39+
assert len(a.TR2X) == len(a.sizeT)
40+
41+
42+
def test_known_rank():
43+
shape = (100, 80, 60)
44+
tFacOrig = random_cp(shape, 10, full=False)
45+
tOrig = tl.cp_to_tensor(tFacOrig)
46+
assert calcR2X(tFacOrig, tOrig) >= 1.0
47+
48+
newtFac = [calcR2X(perform_CP(tOrig, r=rr), tOrig) for rr in [1,3,5,7,9]]
49+
assert np.all([newtFac[ii+1] > newtFac[ii] for ii in range(len(newtFac)-1)])
50+
assert newtFac[0] > 0.0
51+
assert newtFac[-1] < 1.0
52+
53+
filter = np.random.rand(*shape) > 0.8
54+
missT = np.copy(tOrig)
55+
missT[filter] = np.nan
56+
pass

tensorpack/test/testcase.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

0 commit comments

Comments
 (0)