Skip to content

Commit fb479a4

Browse files
authored
Merge pull request #19 from meyer-lab/tSVD
Use truncated SVD
2 parents 191093e + e845601 commit fb479a4

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ authors = ["Your Name <[email protected]>"]
66
license = "MIT"
77

88
[tool.poetry.dependencies]
9-
python = "^3.9"
9+
python = ">=3.9,<3.11"
1010
numpy = "^1.21"
1111
scipy = "^1.7"
1212
statsmodels = "^0.13"
1313
tensorly = "^0.7"
14+
scikit-learn = "^1.0.1"
1415

1516
[tool.poetry.dev-dependencies]
1617
pytest = "^6.2"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup, find_packages
22

33
setup(name='tensorpack',
4-
version='0.1',
4+
version='0.0.2',
55
description='A collection of tensor methods from the Meyer lab.',
66
url='https://github.com/meyer-lab/tensorpack',
77
license='MIT',

tensorpack/cmtf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import numpy as np
6+
from sklearn.decomposition import TruncatedSVD
67
import tensorly as tl
78
from tensorly.tenalg import khatri_rao
89
from copy import deepcopy
@@ -192,6 +193,8 @@ def initialize_cp(tensor: np.ndarray, rank: int):
192193
factors = [np.ones((tensor.shape[i], rank)) for i in range(tensor.ndim)]
193194
contain_missing = (np.sum(~np.isfinite(tensor)) > 0)
194195

196+
tsvd = TruncatedSVD(n_components=rank)
197+
195198
# SVD init mode whose size is larger than rank
196199
for mode in range(tensor.ndim):
197200
if tensor.shape[mode] >= rank:
@@ -200,7 +203,7 @@ def initialize_cp(tensor: np.ndarray, rank: int):
200203
si = SoftImpute(max_rank=rank)
201204
unfold = si.fit_transform(unfold)
202205

203-
factors[mode] = np.linalg.svd(unfold)[0][:, :rank]
206+
factors[mode] = tsvd.fit_transform(unfold)
204207

205208
return tl.cp_tensor.CPTensor((None, factors))
206209

0 commit comments

Comments
 (0)