Skip to content

Commit 013ad81

Browse files
authored
Merge pull request tensorly#595 from meyer-lab/masked-error-cp
Share reconstruction error between CP and NN CP to fix masking
2 parents c571df8 + 1a5ab0f commit 013ad81

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

tensorly/decomposition/_nn_cp.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
import tensorly as tl
33
from ._base_decomposition import DecompositionMixin
4-
from ._cp import initialize_cp
4+
from ._cp import initialize_cp, error_calc
55
from ..solvers.nnls import hals_nnls
66
from ..cp_tensor import (
77
CPTensor,
@@ -10,7 +10,6 @@
1010
cp_normalize,
1111
validate_cp_rank,
1212
)
13-
from ..tenalg.svd import svd_interface
1413

1514
# Authors: Jean Kossaifi <jean.kossaifi+tensors@gmail.com>
1615
# Chris Swierczewski <csw@amazon.com>
@@ -139,17 +138,19 @@ def non_negative_parafac(
139138
weights, factors = cp_normalize((weights, factors))
140139

141140
if tol:
142-
# ||tensor - rec||^2 = ||tensor||^2 + ||rec||^2 - 2*<tensor, rec>
143-
factors_norm = cp_norm((weights, factors))
144-
145-
# mttkrp and factor for the last mode. This is equivalent to the
146-
# inner product <tensor, factorization>
147-
iprod = tl.sum(tl.sum(mttkrp * factor, axis=0))
148-
rec_error = (
149-
tl.sqrt(tl.abs(norm_tensor**2 + factors_norm**2 - 2 * iprod))
150-
/ norm_tensor
141+
# Calculate the reconstruction error. We can use the same method from CP.
142+
unnorml_rec_error, tensor, norm_tensor = error_calc(
143+
tensor,
144+
norm_tensor,
145+
weights,
146+
factors,
147+
sparsity=None,
148+
mask=mask,
149+
mttkrp=mttkrp,
151150
)
151+
rec_error = unnorml_rec_error / norm_tensor
152152
rec_errors.append(rec_error)
153+
153154
if iteration >= 1:
154155
rec_error_decrease = rec_errors[-2] - rec_errors[-1]
155156

@@ -346,6 +347,7 @@ def non_negative_parafac_hals(
346347
/ norm_tensor
347348
)
348349
rec_errors.append(rec_error)
350+
349351
if iteration >= 1:
350352
rec_error_decrease = rec_errors[-2] - rec_errors[-1]
351353

@@ -462,6 +464,7 @@ def __init__(
462464
cvg_criterion="abs_rec_error",
463465
fixed_modes=None,
464466
):
467+
self.rank = rank
465468
self.n_iter_max = n_iter_max
466469
self.init = init
467470
self.svd = svd
@@ -489,6 +492,7 @@ def fit_transform(self, tensor):
489492

490493
cp_tensor, errors = non_negative_parafac(
491494
tensor,
495+
rank=self.rank,
492496
n_iter_max=self.n_iter_max,
493497
init=self.init,
494498
svd=self.svd,

0 commit comments

Comments
 (0)