from ..parafac2_tensor import (
    Parafac2Tensor,
    _validate_parafac2_tensor,
+    parafac2_to_slices,
)
from ..cp_tensor import CPTensor, cp_normalize
from ..tenalg.svd import svd_interface, SVD_TYPES
# Yngve Mardal Moe


+def _update_imputed(tensor_slices, mask, decomposition, method):
+    """
+    Update missing values of tensor slices according to method.
+
+    Parameters
+    ----------
+    tensor_slices : Iterable of ndarray
+    mask : ndarray
+        An array with the same shape as the tensor. It should be 0 where there are
+        missing values and 1 everywhere else.
+    decomposition : Parafac2Tensor, optional
+    method : string
+        One of 'mode-2' or 'factors'. With 'mode-2', each missing entry is set to
+        the mean of the observed entries in its mode-2 slice. With 'factors',
+        missing entries are set to the corresponding entries of the tensor
+        reconstructed from `decomposition`. 'mode-2' is used (by default) to
+        initialize missing entries, while 'factors' is used to update the
+        imputations during optimization. If an initial decomposition is
+        specified, 'factors' is used at initialization as well.
+
+    Returns
+    -------
+    tensor_slices : Iterable of ndarray
+    """
+
+    if method == "mode-2":
+
+        for slice_no, (slice, slice_mask) in enumerate(zip(tensor_slices, mask)):
+
+            slice_mean = tl.sum(slice * slice_mask) / tl.sum(slice_mask)
+
+            tensor_slices[slice_no] = tl.where(
+                slice_mask == 0, slice_mean, tensor_slices[slice_no]
+            )
+
+    else:  # factors
+
+        reconstructed_slices = parafac2_to_slices(decomposition)
+        tensor_slices = list(tensor_slices)
+
+        for slice_no, (slice, rec_slice, slice_mask) in enumerate(
+            zip(tensor_slices, reconstructed_slices, mask)
+        ):
+
+            tensor_slices[slice_no] = tl.where(slice_mask == 0, rec_slice, slice)
+
+    return tensor_slices
+
+
def initialize_decomposition(
-    tensor_slices, rank, init="random", svd="truncated_svd", random_state=None
+    tensor_slices,
+    rank,
+    init="random",
+    svd="truncated_svd",
+    random_state=None,
):
    r"""Initiate a random PARAFAC2 decomposition given rank and tensor slices.

@@ -44,6 +97,9 @@ def initialize_decomposition(
    rank : int
    init : {'random', 'svd', CPTensor, Parafac2Tensor}, optional
    random_state : `np.random.RandomState`
+    mask : ndarray, optional
+        An array with the same shape as the tensor. It should be 0 where there are
+        missing values and 1 everywhere else.

    Returns
    -------
@@ -56,10 +112,12 @@ def initialize_decomposition(
    concat_shape = sum(shape[0] for shape in shapes)

    if init == "random":
+
        return random_parafac2(
            shapes, rank, full=False, random_state=random_state, **context
        )
    elif init == "svd":
+
        if shapes[0][1] < rank:
            raise ValueError(
                f"Cannot perform SVD init if rank ({rank}) is greater than the number of columns in each tensor slice ({shapes[0][1]})"
@@ -92,6 +150,7 @@ def initialize_decomposition(
            )
        if decomposition.rank != rank:
            raise ValueError("Cannot init with a decomposition of different rank")
+
        return decomposition
    raise ValueError(f'Initialization method "{init}" not recognized')

@@ -130,6 +189,7 @@ def __init__(
        nn_modes=None,
        acc_pow: float = 2.0,
        max_fail: int = 4,
+        mask=None,
    ):
        """The line search strategy defined within Rasmus Bro's thesis [1, 2].

@@ -150,6 +210,10 @@ def __init__(
            Line search steps are defined as `iteration ** (1.0 / acc_pow)`.
        max_fail : int
            The number of line search failures before increasing `acc_pow`.
+        mask : ndarray, optional
+            An array with the same shape as the tensor. It should be 0 where there are
+            missing values and 1 everywhere else.
+

        References
        ----------
@@ -166,6 +230,7 @@ def __init__(
        self.max_fail = max_fail  # Increase acc_pow by one after max_fail failures
        self.acc_fail = 0  # How many times acceleration has failed
        self.nn_modes = nn_modes
+        self.mask = mask  # mask for missing values

    def line_step(
        self,
@@ -223,7 +288,10 @@ def line_step(
        projections_ls = _compute_projections(tensor_slices, factors_ls, self.svd)

        ls_rec_error = _parafac2_reconstruction_error(
-            tensor_slices, (weights, factors_ls, projections_ls), self.norm_tensor
+            tensor_slices=tensor_slices,
+            decomposition=(weights, factors_ls, projections_ls),
+            norm_matrices=self.norm_tensor,
+            mask=self.mask,
        )
        ls_rec_error /= self.norm_tensor

@@ -251,7 +319,7 @@ def line_step(


def _parafac2_reconstruction_error(
-    tensor_slices, decomposition, norm_matrices=None, projected_tensor=None
+    tensor_slices, decomposition, norm_matrices=None, projected_tensor=None, mask=None
):
    """Calculates the reconstruction error of the PARAFAC2 decomposition. This implementation
    uses the inner product with each matrix for efficiency, as this avoids needing to
@@ -277,6 +345,9 @@ def _parafac2_reconstruction_error(
    projected_tensor : ndarray, optional
        The projections of X into an aligned tensor for CP decomposition. This can be optionally
        provided to avoid recalculating it.
+    mask : ndarray, optional
+        An array with the same shape as the tensor. It should be 0 where there are
+        missing values and 1 everywhere else.

    Returns
    -------
@@ -285,32 +356,44 @@ def _parafac2_reconstruction_error(
    """
    _validate_parafac2_tensor(decomposition)

-    if norm_matrices is None:
-        norm_X_sq = sum(tl.norm(t_slice, 2) ** 2 for t_slice in tensor_slices)
-    else:
-        norm_X_sq = norm_matrices**2
+    if mask is None:  # In fully observed data, we can utilize pre-computations
+
+        if norm_matrices is None:
+            norm_X_sq = sum(tl.norm(t_slice, 2) ** 2 for t_slice in tensor_slices)
+        else:
+            norm_X_sq = norm_matrices**2

-    weights, (A, B, C), projections = decomposition
-    if weights is not None:
-        A = A * weights
+        weights, (A, B, C), projections = decomposition
+        if weights is not None:
+            A = A * weights

-    norm_cmf_sq = 0
-    inner_product = 0
-    CtC = tl.dot(tl.transpose(C), C)
+        norm_cmf_sq = 0
+        inner_product = 0
+        CtC = tl.dot(tl.transpose(C), C)

-    for i, t_slice in enumerate(tensor_slices):
-        B_i = (projections[i] @ B) * A[i]
+        for i, t_slice in enumerate(tensor_slices):
+            B_i = (projections[i] @ B) * A[i]

-        if projected_tensor is None:
-            tmp = tl.dot(tl.transpose(B_i), t_slice)
-        else:
-            tmp = tl.reshape(A[i], (-1, 1)) * tl.transpose(B) @ projected_tensor[i]
+            if projected_tensor is None:
+                tmp = tl.dot(tl.transpose(B_i), t_slice)
+            else:
+                tmp = tl.reshape(A[i], (-1, 1)) * tl.transpose(B) @ projected_tensor[i]
+
+            inner_product += tl.trace(tl.dot(tmp, C))
+
+            norm_cmf_sq += tl.sum((tl.transpose(B_i) @ B_i) * CtC)

-        inner_product += tl.trace(tl.dot(tmp, C))
+        return tl.sqrt(norm_X_sq - 2 * inner_product + norm_cmf_sq)

-        norm_cmf_sq += tl.sum((tl.transpose(B_i) @ B_i) * CtC)
+    else:

-    return tl.sqrt(norm_X_sq - 2 * inner_product + norm_cmf_sq)
+        reconstructed_slices = parafac2_to_slices(decomposition)
+        total_error = 0
+        for i, (slice, slice_mask) in enumerate(zip(tensor_slices, mask)):
+            total_error += (
+                tl.norm(slice_mask * slice - slice_mask * reconstructed_slices[i]) ** 2
+            )
+        return tl.sqrt(total_error)


def parafac2(
@@ -327,6 +410,7 @@ def parafac2(
    return_errors: bool = False,
    n_iter_parafac: int = 5,
    linesearch: bool = True,
+    mask=None,
):
    r"""PARAFAC2 decomposition [1]_ of a third order tensor via alternating least squares (ALS)

@@ -410,6 +494,9 @@ def parafac2(
    linesearch : bool, default is False
        Whether to perform line search as proposed by Bro in his PhD dissertation [2]_
        (similar to the PLSToolbox line search described in [3]_).
+    mask : ndarray, optional
+        An array with the same shape as the tensor. It should be 0 where there are
+        missing values and 1 everywhere else.

    Returns
    -------
@@ -457,19 +544,52 @@ def parafac2(
            tensor_slices[0].shape[1] == tensor_slices[ii].shape[1]
        ), "All tensor slices must have the same number of columns."

-    weights, factors, projections = initialize_decomposition(
+    (weights, factors, projections) = initialize_decomposition(
        tensor_slices, rank, init=init, svd=svd, random_state=random_state
    )
    factors = list(factors)

+    # Initial missing imputation
+    if mask is not None:
+
+        if init == "random" or init == "svd":
+
+            tensor_slices = _update_imputed(
+                tensor_slices=list(tensor_slices),  # required casting for jax
+                mask=mask,
+                decomposition=None,
+                method="mode-2",
+            )
+
+        else:  # if factors are provided, we can impute missing values using the decomposition
+
+            tensor_slices = _update_imputed(
+                tensor_slices=tensor_slices,
+                mask=mask,
+                decomposition=(weights, factors, projections),
+                method="factors",
+            )
+
    rec_errors = []
-    norm_tensor = tl.sqrt(
-        sum(tl.norm(tensor_slice, 2) ** 2 for tensor_slice in tensor_slices)
-    )
+
+    if mask is not None:
+
+        norm_tensor = tl.sqrt(
+            sum(
+                tl.norm(tensor_slice * slice_mask, 2) ** 2
+                for tensor_slice, slice_mask in zip(tensor_slices, mask)
+            )
+        )
+
+    else:
+
+        norm_tensor = tl.sqrt(
+            sum(tl.norm(tensor_slice, 2) ** 2 for tensor_slice in tensor_slices)
+        )

    if linesearch and not isinstance(linesearch, _BroThesisLineSearch):
        linesearch = _BroThesisLineSearch(
-            norm_tensor, svd, verbose=verbose, nn_modes=nn_modes
+            norm_tensor, svd, verbose=verbose, nn_modes=nn_modes, mask=mask
        )

    # If nn_modes is set, we use HALS, otherwise, we use the standard parafac implementation.
@@ -540,6 +660,16 @@ def parafac_updates(X, w, f):
                    rec_errors[-1],
                )

+        # Update imputations
+        if mask is not None:
+
+            tensor_slices = _update_imputed(
+                tensor_slices=tensor_slices,
+                decomposition=(weights, factors, projections),
+                mask=mask,
+                method="factors",
+            )
+
        if normalize_factors:
            weights, factors = cp_normalize((weights, factors))

@@ -549,6 +679,7 @@ def parafac_updates(X, w, f):
                (weights, factors, projections),
                norm_tensor,
                projected_tensor,
+                mask,
            )
            rec_error /= norm_tensor
            rec_errors.append(rec_error)
@@ -651,6 +782,9 @@ class Parafac2(DecompositionMixin):
        Activate return of iteration errors
    n_iter_parafac : int, optional
        Number of PARAFAC iterations to perform for each PARAFAC2 iteration
+    mask : ndarray, optional
+        An array with the same shape as the tensor. It should be 0 where there are
+        missing values and 1 everywhere else.

    Returns
    -------
@@ -691,6 +825,7 @@ def __init__(
        return_errors=False,
        n_iter_parafac=5,
        linesearch=False,
+        mask=None,
    ):
        self.rank = rank
        self.n_iter_max = n_iter_max
@@ -704,6 +839,7 @@ def __init__(
        self.return_errors = return_errors
        self.n_iter_parafac = n_iter_parafac
        self.linesearch = linesearch
+        self.mask = mask

    def fit_transform(self, tensor):
        """Decompose an input tensor
@@ -730,5 +866,6 @@ def fit_transform(self, tensor):
            return_errors=self.return_errors,
            n_iter_parafac=self.n_iter_parafac,
            linesearch=self.linesearch,
+            mask=self.mask,
        )
        return self.decomposition_
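
A minimal usage sketch of the mask-aware `parafac2` introduced in this diff (not part of the patch itself). It assumes the new `mask` keyword shown above; the data, shapes, and variable names are illustrative only, and the mask is passed per slice (1 = observed, 0 = missing), matching how it is zipped with the tensor slices in `_update_imputed` and `_parafac2_reconstruction_error`.

import numpy as np
from tensorly.decomposition import parafac2

rng = np.random.default_rng(0)

# Three slices with the same number of columns, as PARAFAC2 requires.
tensor_slices = [rng.standard_normal((12, 6)) for _ in range(3)]

# Per-slice masks: 1 where observed, 0 where missing (~10% of entries hidden at random).
mask = [np.where(rng.random((12, 6)) < 0.1, 0.0, 1.0) for _ in range(3)]

# Missing entries are first imputed with the observed mode-2 slice means, then
# refined from the reconstructed tensor after every outer iteration.
weights, factors, projections = parafac2(
    tensor_slices, rank=3, mask=mask, n_iter_max=200, random_state=0
)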