Skip to content

Commit 3912bb9

Browse files
authored
Merge pull request tensorly#563 from cchatzis/parafac2-als-EM
Parafac2 ALS EM
2 parents 29ad039 + 2c51714 commit 3912bb9

File tree

2 files changed

+280
-27
lines changed

2 files changed

+280
-27
lines changed

tensorly/decomposition/_parafac2.py

Lines changed: 164 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..parafac2_tensor import (
1010
Parafac2Tensor,
1111
_validate_parafac2_tensor,
12+
parafac2_to_slices,
1213
)
1314
from ..cp_tensor import CPTensor, cp_normalize
1415
from ..tenalg.svd import svd_interface, SVD_TYPES
@@ -17,8 +18,60 @@
1718
# Yngve Mardal Moe
1819

1920

21+
def _update_imputed(tensor_slices, mask, decomposition, method):
    """
    Update missing values of tensor slices according to method.

    Parameters
    ----------
    tensor_slices : Iterable of ndarray
    mask : ndarray
        An array with the same shape as the tensor. It should be 0 where there are
        missing values and 1 everywhere else.
    decomposition : Parafac2Tensor, optional
        Required (and only used) when ``method == 'factors'``.
    method : string
        One of 'mode-2' or 'factors'. 'mode-2' updates imputed values according to
        mean of each mode-2 slice. If 'factors' is chosen, set missing entries
        according to reconstructed tensor given from 'decomposition'.
        'mode-2' is used (by default) for initializing missing entries while
        'factors' is used for updating imputations during optimization. If an
        initial decomposition is specified, 'factors' is used at initialization.

    Returns
    -------
    tensor_slices : list of ndarray
    """
    # Cast once so per-slice assignment works for any iterable input and for
    # immutable-container backends (e.g. jax); previously only the 'factors'
    # branch did this cast.
    tensor_slices = list(tensor_slices)

    if method == "mode-2":
        for slice_no, (t_slice, slice_mask) in enumerate(zip(tensor_slices, mask)):
            # Mean of the observed entries of this slice.
            # Fix: original divided by ``T.sum(slice_mask)`` — ``T`` is an
            # undefined name; the backend module is ``tl``.
            slice_mean = tl.sum(t_slice * slice_mask) / tl.sum(slice_mask)

            tensor_slices[slice_no] = tl.where(
                slice_mask == 0, slice_mean, tensor_slices[slice_no]
            )

    else:  # factors
        reconstructed_slices = parafac2_to_slices(decomposition)

        for slice_no, (t_slice, rec_slice, slice_mask) in enumerate(
            zip(tensor_slices, reconstructed_slices, mask)
        ):
            # Keep observed entries; fill missing ones from the reconstruction.
            tensor_slices[slice_no] = tl.where(slice_mask == 0, rec_slice, t_slice)

    return tensor_slices
67+
68+
2069
def initialize_decomposition(
21-
tensor_slices, rank, init="random", svd="truncated_svd", random_state=None
70+
tensor_slices,
71+
rank,
72+
init="random",
73+
svd="truncated_svd",
74+
random_state=None,
2275
):
2376
r"""Initiate a random PARAFAC2 decomposition given rank and tensor slices.
2477
@@ -44,6 +97,9 @@ def initialize_decomposition(
4497
rank : int
4598
init : {'random', 'svd', CPTensor, Parafac2Tensor}, optional
4699
random_state : `np.random.RandomState`
100+
mask : ndarray, optional
101+
An array with the same shape as the tensor. It should be 0 where there are
102+
missing values and 1 everywhere else.
47103
48104
Returns
49105
-------
@@ -56,10 +112,12 @@ def initialize_decomposition(
56112
concat_shape = sum(shape[0] for shape in shapes)
57113

58114
if init == "random":
115+
59116
return random_parafac2(
60117
shapes, rank, full=False, random_state=random_state, **context
61118
)
62119
elif init == "svd":
120+
63121
if shapes[0][1] < rank:
64122
raise ValueError(
65123
f"Cannot perform SVD init if rank ({rank}) is greater than the number of columns in each tensor slice ({shapes[0][1]})"
@@ -92,6 +150,7 @@ def initialize_decomposition(
92150
)
93151
if decomposition.rank != rank:
94152
raise ValueError("Cannot init with a decomposition of different rank")
153+
95154
return decomposition
96155
raise ValueError(f'Initialization method "{init}" not recognized')
97156

@@ -130,6 +189,7 @@ def __init__(
130189
nn_modes=None,
131190
acc_pow: float = 2.0,
132191
max_fail: int = 4,
192+
mask=None,
133193
):
134194
"""The line search strategy defined within Rasmus Bro's thesis [1, 2].
135195
@@ -150,6 +210,10 @@ def __init__(
150210
Line search steps are defined as `iteration ** (1.0 / acc_pow)`.
151211
max_fail : int
152212
The number of line search failures before increasing `acc_pow`.
213+
mask : ndarray, optional
214+
An array with the same shape as the tensor. It should be 0 where there are
215+
missing values and 1 everywhere else.
216+
153217
154218
References
155219
----------
@@ -166,6 +230,7 @@ def __init__(
166230
self.max_fail = max_fail # Increase acc_pow with one after max_fail failure
167231
self.acc_fail = 0 # How many times acceleration have failed
168232
self.nn_modes = nn_modes
233+
self.mask = mask # mask for missing values
169234

170235
def line_step(
171236
self,
@@ -223,7 +288,10 @@ def line_step(
223288
projections_ls = _compute_projections(tensor_slices, factors_ls, self.svd)
224289

225290
ls_rec_error = _parafac2_reconstruction_error(
226-
tensor_slices, (weights, factors_ls, projections_ls), self.norm_tensor
291+
tensor_slices=tensor_slices,
292+
decomposition=(weights, factors_ls, projections_ls),
293+
norm_matrices=self.norm_tensor,
294+
mask=self.mask,
227295
)
228296
ls_rec_error /= self.norm_tensor
229297

@@ -251,7 +319,7 @@ def line_step(
251319

252320

253321
def _parafac2_reconstruction_error(
254-
tensor_slices, decomposition, norm_matrices=None, projected_tensor=None
322+
tensor_slices, decomposition, norm_matrices=None, projected_tensor=None, mask=None
255323
):
256324
"""Calculates the reconstruction error of the PARAFAC2 decomposition. This implementation
257325
uses the inner product with each matrix for efficiency, as this avoids needing to
@@ -277,6 +345,9 @@ def _parafac2_reconstruction_error(
277345
projected_tensor : ndarray, optional
278346
The projections of X into an aligned tensor for CP decomposition. This can be optionally
279347
provided to avoid recalculating it.
348+
mask : ndarray, optional
349+
An array with the same shape as the tensor. It should be 0 where there are
350+
missing values and 1 everywhere else.
280351
281352
Returns
282353
-------
@@ -285,32 +356,44 @@ def _parafac2_reconstruction_error(
285356
"""
286357
_validate_parafac2_tensor(decomposition)
287358

288-
if norm_matrices is None:
289-
norm_X_sq = sum(tl.norm(t_slice, 2) ** 2 for t_slice in tensor_slices)
290-
else:
291-
norm_X_sq = norm_matrices**2
359+
if mask is None: # In fully observed data, we can utilize pre-computations
360+
361+
if norm_matrices is None:
362+
norm_X_sq = sum(tl.norm(t_slice, 2) ** 2 for t_slice in tensor_slices)
363+
else:
364+
norm_X_sq = norm_matrices**2
292365

293-
weights, (A, B, C), projections = decomposition
294-
if weights is not None:
295-
A = A * weights
366+
weights, (A, B, C), projections = decomposition
367+
if weights is not None:
368+
A = A * weights
296369

297-
norm_cmf_sq = 0
298-
inner_product = 0
299-
CtC = tl.dot(tl.transpose(C), C)
370+
norm_cmf_sq = 0
371+
inner_product = 0
372+
CtC = tl.dot(tl.transpose(C), C)
300373

301-
for i, t_slice in enumerate(tensor_slices):
302-
B_i = (projections[i] @ B) * A[i]
374+
for i, t_slice in enumerate(tensor_slices):
375+
B_i = (projections[i] @ B) * A[i]
303376

304-
if projected_tensor is None:
305-
tmp = tl.dot(tl.transpose(B_i), t_slice)
306-
else:
307-
tmp = tl.reshape(A[i], (-1, 1)) * tl.transpose(B) @ projected_tensor[i]
377+
if projected_tensor is None:
378+
tmp = tl.dot(tl.transpose(B_i), t_slice)
379+
else:
380+
tmp = tl.reshape(A[i], (-1, 1)) * tl.transpose(B) @ projected_tensor[i]
381+
382+
inner_product += tl.trace(tl.dot(tmp, C))
383+
384+
norm_cmf_sq += tl.sum((tl.transpose(B_i) @ B_i) * CtC)
308385

309-
inner_product += tl.trace(tl.dot(tmp, C))
386+
return tl.sqrt(norm_X_sq - 2 * inner_product + norm_cmf_sq)
310387

311-
norm_cmf_sq += tl.sum((tl.transpose(B_i) @ B_i) * CtC)
388+
else:
312389

313-
return tl.sqrt(norm_X_sq - 2 * inner_product + norm_cmf_sq)
390+
reconstructed_tensor = parafac2_to_slices(decomposition)
391+
total_error = 0
392+
for i, (slice, slice_mask) in enumerate(zip(tensor_slices, mask)):
393+
total_error += (
394+
tl.norm(slice_mask * slice - slice_mask * reconstructed_tensor[i]) ** 2
395+
)
396+
return tl.sqrt(total_error)
314397

315398

316399
def parafac2(
@@ -327,6 +410,7 @@ def parafac2(
327410
return_errors: bool = False,
328411
n_iter_parafac: int = 5,
329412
linesearch: bool = True,
413+
mask=None,
330414
):
331415
r"""PARAFAC2 decomposition [1]_ of a third order tensor via alternating least squares (ALS)
332416
@@ -410,6 +494,9 @@ def parafac2(
410494
linesearch : bool, default is False
411495
Whether to perform line search as proposed by Bro in his PhD dissertation [2]_
412496
(similar to the PLSToolbox line search described in [3]_).
497+
mask : ndarray, optional
498+
An array with the same shape as the tensor. It should be 0 where there are
499+
missing values and 1 everywhere else.
413500
414501
Returns
415502
-------
@@ -457,19 +544,52 @@ def parafac2(
457544
tensor_slices[0].shape[1] == tensor_slices[ii].shape[1]
458545
), "All tensor slices must have the same number of columns."
459546

460-
weights, factors, projections = initialize_decomposition(
547+
(weights, factors, projections) = initialize_decomposition(
461548
tensor_slices, rank, init=init, svd=svd, random_state=random_state
462549
)
463550
factors = list(factors)
464551

552+
# Initial missing imputation
553+
if mask is not None:
554+
555+
if init == "random" or init == "svd":
556+
557+
tensor_slices = _update_imputed(
558+
tensor_slices=list(tensor_slices), # required casting for jax
559+
mask=mask,
560+
decomposition=None,
561+
method="mode-2",
562+
)
563+
564+
else: # if factors are provided, we can impute missing values using the decomposition
565+
566+
tensor_slices = _update_imputed(
567+
tensor_slices=tensor_slices,
568+
mask=mask,
569+
decomposition=(weights, factors, projections),
570+
method="factors",
571+
)
572+
465573
rec_errors = []
466-
norm_tensor = tl.sqrt(
467-
sum(tl.norm(tensor_slice, 2) ** 2 for tensor_slice in tensor_slices)
468-
)
574+
575+
if mask is not None:
576+
577+
norm_tensor = tl.sqrt(
578+
sum(
579+
tl.norm(tensor_slice * slice_mask, 2) ** 2
580+
for tensor_slice, slice_mask in zip(tensor_slices, mask)
581+
)
582+
)
583+
584+
else:
585+
586+
norm_tensor = tl.sqrt(
587+
sum(tl.norm(tensor_slice, 2) ** 2 for tensor_slice in tensor_slices)
588+
)
469589

470590
if linesearch and not isinstance(linesearch, _BroThesisLineSearch):
471591
linesearch = _BroThesisLineSearch(
472-
norm_tensor, svd, verbose=verbose, nn_modes=nn_modes
592+
norm_tensor, svd, verbose=verbose, nn_modes=nn_modes, mask=mask
473593
)
474594

475595
# If nn_modes is set, we use HALS, otherwise, we use the standard parafac implementation.
@@ -540,6 +660,16 @@ def parafac_updates(X, w, f):
540660
rec_errors[-1],
541661
)
542662

663+
# Update imputations
664+
if mask is not None:
665+
666+
tensor_slices = _update_imputed(
667+
tensor_slices=tensor_slices,
668+
decomposition=(weights, factors, projections),
669+
mask=mask,
670+
method="factors",
671+
)
672+
543673
if normalize_factors:
544674
weights, factors = cp_normalize((weights, factors))
545675

@@ -549,6 +679,7 @@ def parafac_updates(X, w, f):
549679
(weights, factors, projections),
550680
norm_tensor,
551681
projected_tensor,
682+
mask,
552683
)
553684
rec_error /= norm_tensor
554685
rec_errors.append(rec_error)
@@ -651,6 +782,9 @@ class Parafac2(DecompositionMixin):
651782
Activate return of iteration errors
652783
n_iter_parafac : int, optional
653784
Number of PARAFAC iterations to perform for each PARAFAC2 iteration
785+
mask : ndarray, optional
786+
An array with the same shape as the tensor. It should be 0 where there are
787+
missing values and 1 everywhere else.
654788
655789
Returns
656790
-------
@@ -691,6 +825,7 @@ def __init__(
691825
return_errors=False,
692826
n_iter_parafac=5,
693827
linesearch=False,
828+
mask=None,
694829
):
695830
self.rank = rank
696831
self.n_iter_max = n_iter_max
@@ -704,6 +839,7 @@ def __init__(
704839
self.return_errors = return_errors
705840
self.n_iter_parafac = n_iter_parafac
706841
self.linesearch = linesearch
842+
self.mask = mask
707843

708844
def fit_transform(self, tensor):
709845
"""Decompose an input tensor
@@ -730,5 +866,6 @@ def fit_transform(self, tensor):
730866
return_errors=self.return_errors,
731867
n_iter_parafac=self.n_iter_parafac,
732868
linesearch=self.linesearch,
869+
mask=self.mask,
733870
)
734871
return self.decomposition_

0 commit comments

Comments
 (0)