@@ -467,6 +467,62 @@ class CP {
467467 }
468468 return false ;
469469 }
470+
471+ // This function is for a 2 core CP i.e. a THC approximation
472+ virtual bool check_thc_fit (bool verbose = false ) {
473+ // Compute the inner product T * T_CP
474+ // The MTtKRP is for the core tensor, so only need to dot with the core which isn't normalized
475+ const auto ref_dot_cp = MTtKRP (" r,n" ).dot (cp_factors[2 ](" r,n" ));
476+ // compute the square of the CP tensor (can use the grammian)
477+ auto factor_norm = [&]() {
478+ auto gram_ptr = partial_grammian.begin ();
479+ Array WL, WR;
480+ WL (" P,L" ) = (*(gram_ptr))(" P,L" ) * (*(gram_ptr + 1 ))(" P,L" );
481+ WR (" Q,M" ) = (*(gram_ptr + 2 ))(" Q,M" ) * (*(gram_ptr + 3 ))(" Q,M" );
482+
483+ auto result = TA::dot (cp_factors[2 ](" P,Q" ) * WL (" P,L" ), WR (" Q,M" ) * cp_factors[2 ](" L,M" ));
484+ // not sure why need to fence here, but hang periodically without it
485+ WL.world ().gop .fence ();
486+
487+ return result;
488+ };
489+ // compute the error in the loss function and find the fit
490+ const auto norm_cp = factor_norm (); // ||T_CP||_2
491+ const auto squared_norm_error = norm_ref_sq +
492+ norm_cp -
493+ 2.0 * ref_dot_cp; // ||T - T_CP||_2^2
494+ // N.B. squared_norm_error is very noisy
495+ // TA_ASSERT(squared_norm_error >= - 1e-8);
496+ const auto norm_error = sqrt (abs (squared_norm_error));
497+ const auto fit = 1.0 - (norm_error / norm_reference);
498+ const auto fit_change = fit - prev_fit;
499+ prev_fit = fit;
500+ // print fit data if required
501+ if (verbose) {
502+ std::cout << MTtKRP.world ().rank () << " : fit=" << fit
503+ << " fit_change=" << fit_change << std::endl;
504+ }
505+
506+ // if the change in fit is less than the tolerance try to return true.
507+ if (abs (fit_change) < fit_tol) {
508+ converged_num++;
509+ if (converged_num == 2 ) {
510+ converged_num = 0 ;
511+ final_fit = prev_fit;
512+ prev_fit = 1.0 ;
513+ if (verbose)
514+ std::cout << MTtKRP.world ().rank () << " : converged" << std::endl;
515+ return true ;
516+ } else {
517+ TA_ASSERT (converged_num == 1 );
518+ if (verbose)
519+ std::cout << MTtKRP.world ().rank () << " : pre-converged" << std::endl;
520+ }
521+ } else {
522+ converged_num = 0 ;
523+ }
524+ return false ;
525+ }
470526};
471527
472528} // namespace TiledArray::math::cp
0 commit comments