|
1 | 1 | import warnings |
2 | 2 | import tensorly as tl |
3 | 3 | from ._base_decomposition import DecompositionMixin |
4 | | -from ._cp import initialize_cp |
| 4 | +from ._cp import initialize_cp, error_calc |
5 | 5 | from ..solvers.nnls import hals_nnls |
6 | 6 | from ..cp_tensor import ( |
7 | 7 | CPTensor, |
|
10 | 10 | cp_normalize, |
11 | 11 | validate_cp_rank, |
12 | 12 | ) |
13 | | -from ..tenalg.svd import svd_interface |
14 | 13 |
|
15 | 14 | # Authors: Jean Kossaifi <jean.kossaifi+tensors@gmail.com> |
16 | 15 | # Chris Swierczewski <csw@amazon.com> |
@@ -139,17 +138,19 @@ def non_negative_parafac( |
139 | 138 | weights, factors = cp_normalize((weights, factors)) |
140 | 139 |
|
141 | 140 | if tol: |
142 | | - # ||tensor - rec||^2 = ||tensor||^2 + ||rec||^2 - 2*<tensor, rec> |
143 | | - factors_norm = cp_norm((weights, factors)) |
144 | | - |
145 | | - # mttkrp and factor for the last mode. This is equivalent to the |
146 | | - # inner product <tensor, factorization> |
147 | | - iprod = tl.sum(tl.sum(mttkrp * factor, axis=0)) |
148 | | - rec_error = ( |
149 | | - tl.sqrt(tl.abs(norm_tensor**2 + factors_norm**2 - 2 * iprod)) |
150 | | - / norm_tensor |
| 141 | + # Calculate the reconstruction error. We can use the same method from CP. |
| 142 | + unnorml_rec_error, tensor, norm_tensor = error_calc( |
| 143 | + tensor, |
| 144 | + norm_tensor, |
| 145 | + weights, |
| 146 | + factors, |
| 147 | + sparsity=None, |
| 148 | + mask=mask, |
| 149 | + mttkrp=mttkrp, |
151 | 150 | ) |
| 151 | + rec_error = unnorml_rec_error / norm_tensor |
152 | 152 | rec_errors.append(rec_error) |
| 153 | + |
153 | 154 | if iteration >= 1: |
154 | 155 | rec_error_decrease = rec_errors[-2] - rec_errors[-1] |
155 | 156 |
|
@@ -346,6 +347,7 @@ def non_negative_parafac_hals( |
346 | 347 | / norm_tensor |
347 | 348 | ) |
348 | 349 | rec_errors.append(rec_error) |
| 350 | + |
349 | 351 | if iteration >= 1: |
350 | 352 | rec_error_decrease = rec_errors[-2] - rec_errors[-1] |
351 | 353 |
|
@@ -462,6 +464,7 @@ def __init__( |
462 | 464 | cvg_criterion="abs_rec_error", |
463 | 465 | fixed_modes=None, |
464 | 466 | ): |
| 467 | + self.rank = rank |
465 | 468 | self.n_iter_max = n_iter_max |
466 | 469 | self.init = init |
467 | 470 | self.svd = svd |
@@ -489,6 +492,7 @@ def fit_transform(self, tensor): |
489 | 492 |
|
490 | 493 | cp_tensor, errors = non_negative_parafac( |
491 | 494 | tensor, |
| 495 | + rank=self.rank, |
492 | 496 | n_iter_max=self.n_iter_max, |
493 | 497 | init=self.init, |
494 | 498 | svd=self.svd, |
|
0 commit comments