44#include < vector>
55#include < iomanip>
66#include < btas/generic/dot_impl.h>
7+ #include < btas/generic/contract.h>
8+ #include < btas/generic/reconstruct.h>
9+ #include < btas/generic/scal_impl.h>
710#include < btas/varray/varray.h>
811
912namespace btas {
@@ -23,7 +26,7 @@ namespace btas {
2326 public:
2427 // / constructor for the base convergence test object
2528 // / \param[in] tol tolerance for ALS convergence
26- explicit NormCheck (double tol = 1e-3 ) : tol_(tol) {
29+ explicit NormCheck (double tol = 1e-3 ) : tol_(tol), iter_( 0 ) {
2730 }
2831
2932 ~NormCheck () = default ;
@@ -50,17 +53,34 @@ namespace btas {
5053 prev[r] = btas_factors[r];
5154 }
5255
56+ if (verbose_) {
57+ std::cout << rank_ << " \t " << iter_ << " \t " << std::setprecision (16 ) << diff << std::endl;
58+ }
5359 if (diff < this ->tol_ ) {
5460 return true ;
5561 }
62+ ++iter_;
63+
5664 return false ;
5765 }
5866
67+ // / Option to print fit and change in fit in the () operator call
68+ // / \param[in] verb bool which turns off/on fit printing.
69+ void verbose (bool verb) {
70+ verbose_ = verb;
71+ }
72+
73+ double get_fit (bool hit_max_iters = false ){
74+
75+ }
76+
5977 private:
6078 double tol_;
6179 std::vector<Tensor> prev; // Set of previous factor matrices
6280 size_t ndim; // Number of factor matrices
6381 ind_t rank_; // Rank of the CP problem
82+ bool verbose_ = false ;
83+ size_t iter_;
6484 };
6585
6686 /* *
@@ -105,7 +125,7 @@ namespace btas {
105125 }
106126
107127 double normFactors = norm (btas_factors, V);
108- double normResidual = sqrt (abs (normT_ * normT_ + normFactors * normFactors - 2 * abs (iprod)));
128+ double normResidual = sqrt (abs (normT_ * normT_ + normFactors - 2 * abs (iprod)));
109129 double fit = 1 . - (normResidual / normT_);
110130
111131 double fitChange = abs (fitOld_ - fit);
@@ -208,11 +228,11 @@ namespace btas {
208228 }
209229 }
210230
211- dtype nrm = 0.0 ;
231+ RT nrm = 0.0 ;
212232 for (auto &i : coeffMat) {
213- nrm += i ;
233+ nrm += std::real (i) ;
214234 }
215- return sqrt ( abs ( nrm)) ;
235+ return nrm;
216236 }
217237 };
218238
@@ -421,5 +441,315 @@ namespace btas {
421441 return sqrt (abs (nrm));
422442 }
423443 };
444+
445+ /* *
446+ \breif Class used to decide when ALS problem is converged.
447+ The fit is not computed and the optimization just runs until nALS is
448+ reached.
449+ **/
450+ template <typename Tensor>
451+ class NoCheck {
452+ using ind_t = typename Tensor::range_type::index_type::value_type;
453+ using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
454+
455+ public:
456+ // / constructor for the base convergence test object
457+ // / \param[in] tol tolerance for ALS convergence
458+ explicit NoCheck (double tol = 1e-3 ) : iter_(0 ){
459+ }
460+
461+ ~NoCheck () = default ;
462+
463+ // / Function to check convergence of the ALS problem
464+ // / convergence when \f$ \sum_n^{ndim} \frac{\|A^{i}_n - A^{i+1}_n\|}{dim(A^{i}_n} \leq \epsilon \f$
465+ // / \param[in] btas_factors Current set of factor matrices
466+ bool operator () (const std::vector<Tensor> &btas_factors,
467+ const std::vector<Tensor> & V = std::vector<Tensor>()){
468+ auto rank_ = btas_factors[1 ].extent (1 );
469+ if (verbose_) {
470+ std::cout << rank_ << " \t " << iter_ << std::endl;
471+ }
472+ ++iter_;
473+
474+ return false ;
475+ }
476+
477+ // / Option to print fit and change in fit in the () operator call
478+ // / \param[in] verb bool which turns off/on fit printing.
479+ void verbose (bool verb) {
480+ verbose_ = verb;
481+ }
482+
483+ double get_fit (bool hit_max_iters = false ){
484+
485+ }
486+
487+ private:
488+ double tol_;
489+ bool verbose_ = false ;
490+ size_t iter_;
491+ Tensor prevT_;
492+ };
493+
494+ // / This class is going to take a tensor approximation
495+ // / and compare it to the previous tensor approximation
496+ // / Skipping the total fit and directly computing the relative fit
497+ template <typename Tensor>
498+ class ApproxFitCheck {
499+ using ind_t = typename Tensor::range_type::index_type::value_type;
500+ using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
501+
502+ public:
503+ // / constructor for the base convergence test object
504+ // / \param[in] tol tolerance for ALS convergence
505+ explicit ApproxFitCheck (double tol = 1e-3 ) : iter_(0 ), tol_(tol){
506+ }
507+
508+ ~ApproxFitCheck () = default ;
509+
510+ // / Function to check convergence of the ALS problem
511+ // / convergence when \f$ \sum_n^{ndim} \frac{\|A^{i}_n - A^{i+1}_n\|}{dim(A^{i}_n} \leq \epsilon \f$
512+ // / \param[in] btas_factors Current set of factor matrices
513+
514+ bool operator () (std::vector<Tensor> & btas_factors,
515+ const std::vector<Tensor> & V = std::vector<Tensor>()) {
516+ auto rank_ = btas_factors[1 ].extent (1 );
517+
518+ auto fit = 0.0 ;
519+ if (iter_ == 0 ) {
520+ fit_prev_ = (norm (btas_factors, btas_factors, rank_));
521+ norm_prev_ = sqrt (fit_prev_);
522+ prev_factors = btas_factors;
523+ // diff = reconstruct(btas_factors, orders);
524+ if (verbose_) {
525+ std::cout << rank_ << " \t " << iter_ << " \t " << 1.0 << std::endl;
526+ }
527+ ++iter_;
528+ return false ;
529+ }
530+
531+ auto curr_norm = norm (btas_factors, btas_factors, rank_);
532+ fit = sqrt (fit_prev_ - 2 * norm (prev_factors, btas_factors, rank_) + curr_norm) / norm_prev_;
533+ // fit = norm(diff);
534+ // diff = tnew;
535+ fit_prev_ = curr_norm;
536+ norm_prev_ = sqrt (curr_norm);
537+ prev_factors = btas_factors;
538+
539+ if (verbose_) {
540+ std::cout << rank_ << " \t " << iter_ << " \t " << fit << std::endl;
541+ }
542+ ++iter_;
543+ if (fit < tol_) {
544+ ++converged_num;
545+ if (converged_num > 1 ) {
546+ iter_ = 0 ;
547+ return true ;
548+ }
549+ }
550+ return false ;
551+ }
552+
553+ void verbose (bool verb){
554+ verbose_ = verb;
555+ }
556+
557+ private:
558+ double tol_;
559+ bool verbose_ = false ;
560+ double fit_prev_, norm_prev_;
561+ std::vector<size_t > orders;
562+ std::vector<Tensor> prev_factors;
563+ // Tensor diff;
564+ size_t converged_num = 0 ;
565+ size_t iter_;
566+
567+ double norm (Tensor& a){
568+ auto n = 0.0 ;
569+ for (auto & i : a)
570+ n += i * i;
571+ return sqrt (n);
572+ }
573+
574+ double norm (std::vector<Tensor> & facs1, std::vector<Tensor>& facs2, ind_t rank_){
575+ BTAS_ASSERT (facs1.size () == facs2.size ());
576+ ind_t num_factors = facs1.size ();
577+ Tensor RRp;
578+ Tensor T1 = facs1[0 ], T2 = facs2[0 ];
579+ auto lam_ptr1 = facs1[num_factors - 1 ].data (),
580+ lam_ptr2 = facs2[num_factors - 1 ].data ();
581+ for (ind_t i = 0 ; i < rank_; i++) {
582+ scal (T1.extent (0 ), *(lam_ptr1 + i), std::begin (T1) + i, rank_);
583+ scal (T2.extent (0 ), *(lam_ptr2 + i), std::begin (T2) + i, rank_);
584+ }
585+
586+ contract (1.0 , T1, {1 ,2 }, T2, {1 ,3 }, 0.0 , RRp, {2 ,3 });
587+
588+ for (ind_t i = 0 ; i < rank_; i++) {
589+ auto val1 = *(lam_ptr1 + i),
590+ val2 = *(lam_ptr2 + i);
591+ scal (T1.extent (0 ), (abs (val1) > 1e-12 ? 1.0 /val1 : 1.0 ), std::begin (T1) + i, rank_);
592+ scal (T2.extent (0 ), (abs (val2) > 1e-12 ? 1.0 /val2 : 1.0 ), std::begin (T2) + i, rank_);
593+ }
594+
595+ auto * ptr_RRp = RRp.data ();
596+ for (ind_t i = 1 ; i < num_factors - 3 ; ++i) {
597+ Tensor temp;
598+ contract (1.0 , facs1[i], {1 ,2 }, facs2[i], {1 ,3 }, 0.0 , temp, {2 ,3 });
599+ auto * ptr_temp = temp.data ();
600+ for (ord_t r = 0 ; r < rank_ * rank_; ++r)
601+ *(ptr_RRp + r) *= *(ptr_temp + r);
602+ }
603+ Tensor temp;
604+ auto last = num_factors - 2 ;
605+ contract (1.0 , facs1[last], {1 ,2 }, facs2[last], {1 ,3 }, 0.0 , temp, {2 ,3 });
606+ return btas::dot (RRp, temp);
607+ }
608+
609+ };
610+
611+ /* *
612+ \breif This is a class that computes the difference in two fits
613+ /| T - T^{i} \|^2 - \| T - T^{i + 1}\|^2 = T^{i}^2 - 2 TT^{i} + 2 TT^{i+1} - T^{i+1}^2
614+ **/
615+ template <typename Tensor>
616+ class DiffFitCheck {
617+ using ind_t = typename Tensor::range_type::index_type::value_type;
618+ using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
619+ using dtype = typename Tensor::value_type;
620+
621+ public:
622+ // / constructor for the base convergence test object
623+ // / \param[in] tol tolerance for ALS convergence
624+ explicit DiffFitCheck (double tol = 1e-3 ) : iter_(0 ), tol_(tol){
625+ }
626+
627+ ~DiffFitCheck () = default ;
628+
629+ // / Function to check convergence of the ALS problem
630+ // / convergence when \f$ \sum_n^{ndim} \frac{\|A^{i}_n - A^{i+1}_n\|}{dim(A^{i}_n} \leq \epsilon \f$
631+ // / \param[in] btas_factors Current set of factor matrices
632+
633+ bool operator () (std::vector<Tensor> & btas_factors,
634+ const std::vector<Tensor> & V = std::vector<Tensor>()) {
635+ auto rank_ = btas_factors[1 ].extent (1 );
636+ auto n = btas_factors.size () - 1 ;
637+ auto & lambda = btas_factors[n];
638+ auto fit = 0.0 ;
639+ if (iter_ == 0 ) {
640+ fit_prev_ = sqrt (abs (norm (V, lambda, rank_) - 2.0 * abs (compute_inner_product (btas_factors[n - 1 ], lambda))));
641+ if (verbose_) {
642+ std::cout << rank_ << " \t " << iter_ << " \t " << 1.0 << std::endl;
643+ }
644+ ++iter_;
645+ return false ;
646+ }
647+
648+ auto curr_norm = sqrt (abs (norm (V, lambda, rank_) - 2.0 * abs (compute_inner_product (btas_factors[n - 1 ], lambda))));
649+ fit = sqrt (abs (fit_prev_ * fit_prev_ - curr_norm * curr_norm));
650+ fit_prev_ = curr_norm;
651+
652+ if (verbose_) {
653+ std::cout << rank_ << " \t " << iter_ << " \t " << fit << std::endl;
654+ }
655+ ++iter_;
656+ if (fit < tol_) {
657+ ++converged_num;
658+ if (converged_num > 1 ) {
659+ return true ;
660+ }
661+ }
662+ return false ;
663+ }
664+
665+ void verbose (bool verb){
666+ verbose_ = verb;
667+ }
668+
669+ void set_MtKRP (Tensor & MtKRP){
670+ MtKRP_ = MtKRP;
671+ }
672+
673+ private:
674+ double tol_;
675+ bool verbose_ = false ;
676+ double fit_prev_;
677+ Tensor MtKRP_;
678+ size_t converged_num = 0 ;
679+ size_t iter_;
680+
681+ dtype compute_inner_product (Tensor &last_factor, Tensor& lambda){
682+ ord_t size = last_factor.size ();
683+ ind_t rank = last_factor.extent (1 );
684+ auto *ptr_A = last_factor.data ();
685+ auto *ptr_MtKRP = MtKRP_.data ();
686+ auto lam_ptr = lambda.data ();
687+ dtype iprod = 0.0 ;
688+ for (ord_t i = 0 ; i < size; ++i) {
689+ iprod += *(ptr_MtKRP + i) * btas::impl::conj (*(ptr_A + i)) * *(lam_ptr + i % rank);
690+ }
691+ return iprod;
692+ }
693+
694+ double norm (const std::vector<Tensor> &V, Tensor & lambda, ind_t rank_) {
695+ auto n = V.size ();
696+ Tensor coeffMat;
697+ typename Tensor::value_type one = 1.0 ;
698+ ger (one, lambda.conj (), lambda, coeffMat);
699+
700+ auto rank2 = rank_ * (ord_t )rank_;
701+ Tensor temp (rank_, rank_);
702+
703+ auto *ptr_coeff = coeffMat.data ();
704+ for (size_t i = 0 ; i < n; ++i) {
705+ auto *ptr_V = V[i].data ();
706+ for (ord_t j = 0 ; j < rank2; ++j) {
707+ *(ptr_coeff + j) *= *(ptr_V + j);
708+ }
709+ }
710+
711+ dtype nrm = 0.0 ;
712+ for (auto &i : coeffMat) {
713+ nrm += i;
714+ }
715+ return nrm;
716+ }
717+
718+ double norm (std::vector<Tensor> & facs1, std::vector<Tensor>& facs2, ind_t rank_){
719+ BTAS_ASSERT (facs1.size () == facs2.size ());
720+ ind_t num_factors = facs1.size ();
721+ Tensor RRp;
722+ Tensor T1 = facs1[0 ], T2 = facs2[0 ];
723+ auto lam_ptr1 = facs1[num_factors - 1 ].data (),
724+ lam_ptr2 = facs2[num_factors - 1 ].data ();
725+ for (ind_t i = 0 ; i < rank_; i++) {
726+ scal (T1.extent (0 ), *(lam_ptr1 + i), std::begin (T1) + i, rank_);
727+ scal (T2.extent (0 ), *(lam_ptr2 + i), std::begin (T2) + i, rank_);
728+ }
729+
730+ contract (1.0 , T1, {1 ,2 }, T2, {1 ,3 }, 0.0 , RRp, {2 ,3 });
731+
732+ for (ind_t i = 0 ; i < rank_; i++) {
733+ auto val1 = *(lam_ptr1 + i),
734+ val2 = *(lam_ptr2 + i);
735+ scal (T1.extent (0 ), (abs (val1) > 1e-12 ? 1.0 /val1 : 1.0 ), std::begin (T1) + i, rank_);
736+ scal (T2.extent (0 ), (abs (val2) > 1e-12 ? 1.0 /val2 : 1.0 ), std::begin (T2) + i, rank_);
737+ }
738+
739+ auto * ptr_RRp = RRp.data ();
740+ for (ind_t i = 1 ; i < num_factors - 3 ; ++i) {
741+ Tensor temp;
742+ contract (1.0 , facs1[i], {1 ,2 }, facs2[i], {1 ,3 }, 0.0 , temp, {2 ,3 });
743+ auto * ptr_temp = temp.data ();
744+ for (ord_t r = 0 ; r < rank_ * rank_; ++r)
745+ *(ptr_RRp + r) *= *(ptr_temp + r);
746+ }
747+ Tensor temp;
748+ auto last = num_factors - 2 ;
749+ contract (1.0 , facs1[last], {1 ,2 }, facs2[last], {1 ,3 }, 0.0 , temp, {2 ,3 });
750+ return btas::dot (RRp, temp);
751+ }
752+
753+ };
424754} // namespace btas
425755#endif // BTAS_GENERIC_CONV_BASE_CLASS
0 commit comments