Skip to content

Commit 4cb6048

Browse files
authored
Merge pull request #29 from meyer-lab/normalize
Normalize matrix weights to a mWeights in CMTF
2 parents bb57856 + be40ee7 commit 4cb6048

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

tensorpack/cmtf.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
def buildMat(tFac):
1818
""" Build the matrix in CMTF from the factors. """
19+
if hasattr(tFac, 'mWeights'):
20+
return tFac.factors[0] @ (tFac.mFactor * tFac.mWeights).T
1921
return tFac.factors[0] @ tFac.mFactor.T
2022

2123

@@ -78,7 +80,7 @@ def sort_factors(tFac):
7880

7981
# Add the variance of the matrix
8082
if hasattr(tFac, 'mFactor'):
81-
norm += np.sum(np.square(tFac.factors[0]), axis=0) * np.sum(np.square(tFac.mFactor), axis=0)
83+
norm += np.sum(np.square(tFac.factors[0]), axis=0) * np.sum(np.square(tFac.mFactor), axis=0) * tFac.mWeights
8284

8385
order = np.flip(np.argsort(norm))
8486
tensor.weights = tensor.weights[order]
@@ -87,6 +89,7 @@ def sort_factors(tFac):
8789

8890
if hasattr(tFac, 'mFactor'):
8991
tensor.mFactor = tensor.mFactor[:, order]
92+
tensor.mWeights = tensor.mWeights[order]
9093
np.testing.assert_allclose(buildMat(tFac), buildMat(tensor), atol=1e-9)
9194

9295
return tensor
@@ -106,6 +109,7 @@ def delete_component(tFac, compNum):
106109

107110
if hasattr(tFac, 'mFactor'):
108111
tensor.mFactor = np.delete(tensor.mFactor, compNum, axis=1)
112+
tensor.mWeights = np.delete(tensor.mWeights, compNum)
109113

110114
tensor.factors = [np.delete(fac, compNum, axis=1) for fac in tensor.factors]
111115
return tensor
@@ -142,7 +146,9 @@ def cp_normalize(tFac):
142146
scales = np.linalg.norm(factor, ord=np.inf, axis=0)
143147
tFac.weights *= scales
144148
if i == 0 and hasattr(tFac, 'mFactor'):
145-
tFac.mFactor *= scales
149+
mScales = np.linalg.norm(tFac.mFactor, ord=np.inf, axis=0)
150+
tFac.mWeights = scales * mScales
151+
tFac.mFactor /= mScales
146152

147153
tFac.factors[i] /= scales
148154

@@ -245,16 +251,16 @@ def perform_CP(tOrig, r=6, tol=1e-6):
245251
return tFac
246252

247253

248-
def perform_CMTF(tOrig, mOrig=None, r=9, tol=1e-6, maxiter=50):
254+
def perform_CMTF(tOrig, mOrig, r=9, tol=1e-6, maxiter=50):
249255
""" Perform CMTF decomposition. """
250256
assert tOrig.dtype == float
251-
if mOrig is not None:
252-
assert mOrig.dtype == float
257+
assert mOrig.dtype == float
253258
tFac = initialize_cmtf(tOrig, mOrig, r)
254259

255260
# Pre-unfold
256261
unfolded = np.hstack((tl.unfold(tOrig, 0), mOrig))
257262
missingM = np.all(np.isfinite(mOrig), axis=1)
263+
assert np.sum(missingM) >= 1, "mOrig must contain at least one complete row"
258264
R2X = -np.inf
259265

260266
# Precalculate the missingness patterns
@@ -279,6 +285,7 @@ def perform_CMTF(tOrig, mOrig=None, r=9, tol=1e-6, maxiter=50):
279285
if R2X - R2X_last < tol:
280286
break
281287

288+
assert not np.all(tFac.mFactor == 0.0)
282289
tFac = cp_normalize(tFac)
283290
tFac = reorient_factors(tFac)
284291
tFac = sort_factors(tFac)

tensorpack/test/test_cmtf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_cmtf_R2X():
2020
""" Test to ensure R2X for higher components is larger. """
2121
arr = []
2222
tensor = createCube(missing=0.2, size=(10, 20, 25))
23-
matrix = createCube(missing=0.2, size=(10, 15))
23+
matrix = createCube(missing=0.05, size=(10, 15))
2424
for i in range(1, 5):
2525
facT = perform_CMTF(tensor, matrix, r=i)
2626
assert np.all(np.isfinite(facT.factors[0]))
@@ -58,7 +58,7 @@ def test_cp():
5858
def test_delete():
5959
""" Test deleting a component results in a valid tensor. """
6060
tOrig = createCube(missing=0.2, size=(10, 20, 25))
61-
mOrig = createCube(missing=0.2, size=(10, 15))
61+
mOrig = createCube(missing=0.05, size=(10, 15))
6262
facT = perform_CMTF(tOrig, mOrig, r=4)
6363

6464
fullR2X = calcR2X(facT, tOrig, mOrig)
@@ -79,6 +79,7 @@ def test_sort():
7979

8080
tFac = random_cp(tOrig.shape, 3)
8181
tFac.mFactor = np.random.randn(mOrig.shape[1], 3)
82+
tFac.mWeights = np.ones(3)
8283

8384
R2X = calcR2X(tFac, tOrig, mOrig)
8485
tRec = tl.cp_to_tensor(tFac)

0 commit comments

Comments
 (0)