Skip to content

Commit 855fe42

Browse files
committed
Add solver
This solver decomposes the product of a THC times a CP (think THC * LT denominator) into a new THC
1 parent 7c94fd7 commit 855fe42

File tree

2 files changed

+507
-0
lines changed

2 files changed

+507
-0
lines changed

src/TiledArray/math/solvers/cp/cp.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,62 @@ class CP {
467467
}
468468
return false;
469469
}
470+
471+
// This function is for a 2 core CP i.e. a THC approximation
472+
virtual bool check_thc_fit(bool verbose = false) {
473+
// Compute the inner product T * T_CP
474+
// The MTtKRP is for the core tensor, so only need to dot with the core which isn't normalized
475+
const auto ref_dot_cp = MTtKRP("r,n").dot(cp_factors[2]("r,n"));
476+
// compute the square of the CP tensor (can use the grammian)
477+
auto factor_norm = [&]() {
478+
auto gram_ptr = partial_grammian.begin();
479+
Array WL, WR;
480+
WL("P,L") = (*(gram_ptr))("P,L") * (*(gram_ptr + 1))("P,L");
481+
WR("Q,M") = (*(gram_ptr + 2))("Q,M") * (*(gram_ptr + 3))("Q,M");
482+
483+
auto result = TA::dot(cp_factors[2]("P,Q") * WL("P,L"), WR("Q,M") * cp_factors[2]("L,M"));
484+
// not sure why need to fence here, but hang periodically without it
485+
WL.world().gop.fence();
486+
487+
return result;
488+
};
489+
// compute the error in the loss function and find the fit
490+
const auto norm_cp = factor_norm(); // ||T_CP||_2
491+
const auto squared_norm_error = norm_ref_sq +
492+
norm_cp -
493+
2.0 * ref_dot_cp; // ||T - T_CP||_2^2
494+
// N.B. squared_norm_error is very noisy
495+
// TA_ASSERT(squared_norm_error >= - 1e-8);
496+
const auto norm_error = sqrt(abs(squared_norm_error));
497+
const auto fit = 1.0 - (norm_error / norm_reference);
498+
const auto fit_change = fit - prev_fit;
499+
prev_fit = fit;
500+
// print fit data if required
501+
if (verbose) {
502+
std::cout << MTtKRP.world().rank() << ": fit=" << fit
503+
<< " fit_change=" << fit_change << std::endl;
504+
}
505+
506+
// if the change in fit is less than the tolerance try to return true.
507+
if (abs(fit_change) < fit_tol) {
508+
converged_num++;
509+
if (converged_num == 2) {
510+
converged_num = 0;
511+
final_fit = prev_fit;
512+
prev_fit = 1.0;
513+
if (verbose)
514+
std::cout << MTtKRP.world().rank() << ": converged" << std::endl;
515+
return true;
516+
} else {
517+
TA_ASSERT(converged_num == 1);
518+
if (verbose)
519+
std::cout << MTtKRP.world().rank() << ": pre-converged" << std::endl;
520+
}
521+
} else {
522+
converged_num = 0;
523+
}
524+
return false;
525+
}
470526
};
471527

472528
} // namespace TiledArray::math::cp

0 commit comments

Comments
 (0)