Skip to content

Commit 9c8c8f6

Browse files
authored
Merge pull request #182 from ValeevGroup/kmp5/feature/cpB
Small updates to BTAS
2 parents 62d57d9 + 1c3307a commit 9c8c8f6

File tree

7 files changed

+497
-75
lines changed

7 files changed

+497
-75
lines changed

btas/generic/converge_class.h

Lines changed: 335 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include <vector>
55
#include <iomanip>
66
#include <btas/generic/dot_impl.h>
7+
#include <btas/generic/contract.h>
8+
#include <btas/generic/reconstruct.h>
9+
#include <btas/generic/scal_impl.h>
710
#include <btas/varray/varray.h>
811

912
namespace btas {
@@ -23,7 +26,7 @@ namespace btas {
2326
public:
2427
/// constructor for the base convergence test object
2528
/// \param[in] tol tolerance for ALS convergence
26-
explicit NormCheck(double tol = 1e-3) : tol_(tol) {
29+
explicit NormCheck(double tol = 1e-3) : tol_(tol), iter_(0) {
2730
}
2831

2932
~NormCheck() = default;
@@ -50,17 +53,34 @@ namespace btas {
5053
prev[r] = btas_factors[r];
5154
}
5255

56+
if (verbose_) {
57+
std::cout << rank_ << "\t" << iter_ << "\t" << std::setprecision(16) << diff << std::endl;
58+
}
5359
if (diff < this->tol_) {
5460
return true;
5561
}
62+
++iter_;
63+
5664
return false;
5765
}
5866

67+
/// Option to print fit and change in fit in the () operator call
68+
/// \param[in] verb bool which turns off/on fit printing.
69+
void verbose(bool verb) {
70+
verbose_ = verb;
71+
}
72+
73+
/// NormCheck does not track a fit value; this placeholder returns 0.0 so
/// the interface matches the fit-based convergence classes.  The original
/// body was empty, which is undefined behavior for a value-returning
/// function.
/// \param[in] hit_max_iters unused.
double get_fit(bool hit_max_iters = false){
  return 0.0;
}
76+
5977
private:
6078
double tol_;
6179
std::vector<Tensor> prev; // Set of previous factor matrices
6280
size_t ndim; // Number of factor matrices
6381
ind_t rank_; // Rank of the CP problem
82+
bool verbose_ = false;
83+
size_t iter_;
6484
};
6585

6686
/**
@@ -105,7 +125,7 @@ namespace btas {
105125
}
106126

107127
double normFactors = norm(btas_factors, V);
108-
double normResidual = sqrt(abs(normT_ * normT_ + normFactors * normFactors - 2 * abs(iprod)));
128+
double normResidual = sqrt(abs(normT_ * normT_ + normFactors - 2 * abs(iprod)));
109129
double fit = 1. - (normResidual / normT_);
110130

111131
double fitChange = abs(fitOld_ - fit);
@@ -208,11 +228,11 @@ namespace btas {
208228
}
209229
}
210230

211-
dtype nrm = 0.0;
231+
RT nrm = 0.0;
212232
for (auto &i : coeffMat) {
213-
nrm += i;
233+
nrm += std::real(i);
214234
}
215-
return sqrt(abs(nrm));
235+
return nrm;
216236
}
217237
};
218238

@@ -421,5 +441,315 @@ namespace btas {
421441
return sqrt(abs(nrm));
422442
}
423443
};
444+
445+
/**
446+
\breif Class used to decide when ALS problem is converged.
447+
The fit is not computed and the optimization just runs until nALS is
448+
reached.
449+
**/
450+
template <typename Tensor>
451+
class NoCheck {
452+
using ind_t = typename Tensor::range_type::index_type::value_type;
453+
using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
454+
455+
public:
456+
/// constructor for the base convergence test object
457+
/// \param[in] tol tolerance for ALS convergence
458+
explicit NoCheck(double tol = 1e-3) : iter_(0){
459+
}
460+
461+
~NoCheck() = default;
462+
463+
/// Function to check convergence of the ALS problem
464+
/// convergence when \f$ \sum_n^{ndim} \frac{\|A^{i}_n - A^{i+1}_n\|}{dim(A^{i}_n} \leq \epsilon \f$
465+
/// \param[in] btas_factors Current set of factor matrices
466+
bool operator () (const std::vector<Tensor> &btas_factors,
467+
const std::vector<Tensor> & V = std::vector<Tensor>()){
468+
auto rank_ = btas_factors[1].extent(1);
469+
if (verbose_) {
470+
std::cout << rank_ << "\t" << iter_ << std::endl;
471+
}
472+
++iter_;
473+
474+
return false;
475+
}
476+
477+
/// Option to print fit and change in fit in the () operator call
478+
/// \param[in] verb bool which turns off/on fit printing.
479+
void verbose(bool verb) {
480+
verbose_ = verb;
481+
}
482+
483+
double get_fit(bool hit_max_iters = false){
484+
485+
}
486+
487+
private:
488+
double tol_;
489+
bool verbose_ = false;
490+
size_t iter_;
491+
Tensor prevT_;
492+
};
493+
494+
/// This class is going to take a tensor approximation
495+
/// and compare it to the previous tensor approximation
496+
/// Skipping the total fit and directly computing the relative fit
497+
template <typename Tensor>
498+
class ApproxFitCheck{
499+
using ind_t = typename Tensor::range_type::index_type::value_type;
500+
using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
501+
502+
public:
503+
/// constructor for the base convergence test object
504+
/// \param[in] tol tolerance for ALS convergence
505+
explicit ApproxFitCheck(double tol = 1e-3) : iter_(0), tol_(tol){
506+
}
507+
508+
~ApproxFitCheck() = default;
509+
510+
/// Function to check convergence of the ALS problem
511+
/// convergence when \f$ \sum_n^{ndim} \frac{\|A^{i}_n - A^{i+1}_n\|}{dim(A^{i}_n} \leq \epsilon \f$
512+
/// \param[in] btas_factors Current set of factor matrices
513+
514+
bool operator () (std::vector<Tensor> & btas_factors,
515+
const std::vector<Tensor> & V = std::vector<Tensor>()) {
516+
auto rank_ = btas_factors[1].extent(1);
517+
518+
auto fit = 0.0;
519+
if(iter_ == 0) {
520+
fit_prev_ = (norm(btas_factors, btas_factors, rank_));
521+
norm_prev_ = sqrt(fit_prev_);
522+
prev_factors = btas_factors;
523+
// diff = reconstruct(btas_factors, orders);
524+
if (verbose_) {
525+
std::cout << rank_ << "\t" << iter_ << "\t" << 1.0 << std::endl;
526+
}
527+
++iter_;
528+
return false;
529+
}
530+
531+
auto curr_norm = norm(btas_factors, btas_factors, rank_);
532+
fit = sqrt(fit_prev_ - 2 * norm(prev_factors, btas_factors, rank_) + curr_norm) / norm_prev_;
533+
// fit = norm(diff);
534+
// diff = tnew;
535+
fit_prev_ = curr_norm;
536+
norm_prev_ = sqrt(curr_norm);
537+
prev_factors = btas_factors;
538+
539+
if (verbose_) {
540+
std::cout << rank_ << "\t" << iter_ << "\t" << fit << std::endl;
541+
}
542+
++iter_;
543+
if (fit < tol_) {
544+
++converged_num;
545+
if(converged_num > 1) {
546+
iter_ = 0;
547+
return true;
548+
}
549+
}
550+
return false;
551+
}
552+
553+
void verbose(bool verb){
554+
verbose_ = verb;
555+
}
556+
557+
private:
558+
double tol_;
559+
bool verbose_ = false;
560+
double fit_prev_, norm_prev_;
561+
std::vector<size_t> orders;
562+
std::vector<Tensor> prev_factors;
563+
// Tensor diff;
564+
size_t converged_num = 0;
565+
size_t iter_;
566+
567+
double norm(Tensor& a){
568+
auto n = 0.0;
569+
for(auto & i : a)
570+
n += i * i;
571+
return sqrt(n);
572+
}
573+
574+
double norm(std::vector<Tensor> & facs1, std::vector<Tensor>& facs2, ind_t rank_){
575+
BTAS_ASSERT(facs1.size() == facs2.size());
576+
ind_t num_factors = facs1.size();
577+
Tensor RRp;
578+
Tensor T1 = facs1[0], T2 = facs2[0];
579+
auto lam_ptr1 = facs1[num_factors - 1].data(),
580+
lam_ptr2 = facs2[num_factors - 1].data();
581+
for (ind_t i = 0; i < rank_; i++) {
582+
scal(T1.extent(0), *(lam_ptr1 + i), std::begin(T1) + i, rank_);
583+
scal(T2.extent(0), *(lam_ptr2 + i), std::begin(T2) + i, rank_);
584+
}
585+
586+
contract(1.0, T1, {1,2}, T2, {1,3}, 0.0, RRp, {2,3});
587+
588+
for (ind_t i = 0; i < rank_; i++) {
589+
auto val1 = *(lam_ptr1 + i),
590+
val2 = *(lam_ptr2 + i);
591+
scal(T1.extent(0), (abs(val1) > 1e-12 ? 1.0/val1 : 1.0), std::begin(T1) + i, rank_);
592+
scal(T2.extent(0), (abs(val2) > 1e-12 ? 1.0/val2 : 1.0), std::begin(T2) + i, rank_);
593+
}
594+
595+
auto * ptr_RRp = RRp.data();
596+
for (ind_t i = 1; i < num_factors - 3; ++i) {
597+
Tensor temp;
598+
contract(1.0, facs1[i], {1,2}, facs2[i], {1,3}, 0.0, temp, {2,3});
599+
auto * ptr_temp = temp.data();
600+
for(ord_t r = 0; r < rank_ * rank_; ++r)
601+
*(ptr_RRp + r) *= *(ptr_temp + r);
602+
}
603+
Tensor temp;
604+
auto last = num_factors - 2;
605+
contract(1.0, facs1[last], {1,2}, facs2[last], {1,3}, 0.0, temp, {2,3});
606+
return btas::dot(RRp, temp);
607+
}
608+
609+
};
610+
611+
/**
612+
\breif This is a class that computes the difference in two fits
613+
/| T - T^{i} \|^2 - \| T - T^{i + 1}\|^2 = T^{i}^2 - 2 TT^{i} + 2 TT^{i+1} - T^{i+1}^2
614+
**/
615+
template <typename Tensor>
616+
class DiffFitCheck{
617+
using ind_t = typename Tensor::range_type::index_type::value_type;
618+
using ord_t = typename range_traits<typename Tensor::range_type>::ordinal_type;
619+
using dtype = typename Tensor::value_type;
620+
621+
public:
622+
/// constructor for the base convergence test object
623+
/// \param[in] tol tolerance for ALS convergence
624+
explicit DiffFitCheck(double tol = 1e-3) : iter_(0), tol_(tol){
625+
}
626+
627+
~DiffFitCheck() = default;
628+
629+
/// Function to check convergence of the ALS problem
630+
/// convergence when \f$ \sum_n^{ndim} \frac{\|A^{i}_n - A^{i+1}_n\|}{dim(A^{i}_n} \leq \epsilon \f$
631+
/// \param[in] btas_factors Current set of factor matrices
632+
633+
bool operator () (std::vector<Tensor> & btas_factors,
634+
const std::vector<Tensor> & V = std::vector<Tensor>()) {
635+
auto rank_ = btas_factors[1].extent(1);
636+
auto n = btas_factors.size() - 1;
637+
auto & lambda = btas_factors[n];
638+
auto fit = 0.0;
639+
if(iter_ == 0) {
640+
fit_prev_ = sqrt(abs(norm(V, lambda, rank_) - 2.0 * abs(compute_inner_product(btas_factors[n - 1], lambda))));
641+
if (verbose_) {
642+
std::cout << rank_ << "\t" << iter_ << "\t" << 1.0 << std::endl;
643+
}
644+
++iter_;
645+
return false;
646+
}
647+
648+
auto curr_norm = sqrt(abs(norm(V, lambda, rank_) - 2.0 * abs(compute_inner_product(btas_factors[n - 1], lambda))));
649+
fit = sqrt(abs(fit_prev_ * fit_prev_ - curr_norm * curr_norm));
650+
fit_prev_ = curr_norm;
651+
652+
if (verbose_) {
653+
std::cout << rank_ << "\t" << iter_ << "\t" << fit << std::endl;
654+
}
655+
++iter_;
656+
if (fit < tol_) {
657+
++converged_num;
658+
if(converged_num > 1) {
659+
return true;
660+
}
661+
}
662+
return false;
663+
}
664+
665+
void verbose(bool verb){
666+
verbose_ = verb;
667+
}
668+
669+
void set_MtKRP(Tensor & MtKRP){
670+
MtKRP_ = MtKRP;
671+
}
672+
673+
private:
674+
double tol_;
675+
bool verbose_ = false;
676+
double fit_prev_;
677+
Tensor MtKRP_;
678+
size_t converged_num = 0;
679+
size_t iter_;
680+
681+
dtype compute_inner_product(Tensor &last_factor, Tensor& lambda){
682+
ord_t size = last_factor.size();
683+
ind_t rank = last_factor.extent(1);
684+
auto *ptr_A = last_factor.data();
685+
auto *ptr_MtKRP = MtKRP_.data();
686+
auto lam_ptr = lambda.data();
687+
dtype iprod = 0.0;
688+
for (ord_t i = 0; i < size; ++i) {
689+
iprod += *(ptr_MtKRP + i) * btas::impl::conj(*(ptr_A + i)) * *(lam_ptr + i % rank);
690+
}
691+
return iprod;
692+
}
693+
694+
double norm(const std::vector<Tensor> &V, Tensor & lambda, ind_t rank_) {
695+
auto n = V.size();
696+
Tensor coeffMat;
697+
typename Tensor::value_type one = 1.0;
698+
ger(one, lambda.conj(), lambda, coeffMat);
699+
700+
auto rank2 = rank_ * (ord_t)rank_;
701+
Tensor temp(rank_, rank_);
702+
703+
auto *ptr_coeff = coeffMat.data();
704+
for (size_t i = 0; i < n; ++i) {
705+
auto *ptr_V = V[i].data();
706+
for (ord_t j = 0; j < rank2; ++j) {
707+
*(ptr_coeff + j) *= *(ptr_V + j);
708+
}
709+
}
710+
711+
dtype nrm = 0.0;
712+
for (auto &i : coeffMat) {
713+
nrm += i;
714+
}
715+
return nrm;
716+
}
717+
718+
double norm(std::vector<Tensor> & facs1, std::vector<Tensor>& facs2, ind_t rank_){
719+
BTAS_ASSERT(facs1.size() == facs2.size());
720+
ind_t num_factors = facs1.size();
721+
Tensor RRp;
722+
Tensor T1 = facs1[0], T2 = facs2[0];
723+
auto lam_ptr1 = facs1[num_factors - 1].data(),
724+
lam_ptr2 = facs2[num_factors - 1].data();
725+
for (ind_t i = 0; i < rank_; i++) {
726+
scal(T1.extent(0), *(lam_ptr1 + i), std::begin(T1) + i, rank_);
727+
scal(T2.extent(0), *(lam_ptr2 + i), std::begin(T2) + i, rank_);
728+
}
729+
730+
contract(1.0, T1, {1,2}, T2, {1,3}, 0.0, RRp, {2,3});
731+
732+
for (ind_t i = 0; i < rank_; i++) {
733+
auto val1 = *(lam_ptr1 + i),
734+
val2 = *(lam_ptr2 + i);
735+
scal(T1.extent(0), (abs(val1) > 1e-12 ? 1.0/val1 : 1.0), std::begin(T1) + i, rank_);
736+
scal(T2.extent(0), (abs(val2) > 1e-12 ? 1.0/val2 : 1.0), std::begin(T2) + i, rank_);
737+
}
738+
739+
auto * ptr_RRp = RRp.data();
740+
for (ind_t i = 1; i < num_factors - 3; ++i) {
741+
Tensor temp;
742+
contract(1.0, facs1[i], {1,2}, facs2[i], {1,3}, 0.0, temp, {2,3});
743+
auto * ptr_temp = temp.data();
744+
for(ord_t r = 0; r < rank_ * rank_; ++r)
745+
*(ptr_RRp + r) *= *(ptr_temp + r);
746+
}
747+
Tensor temp;
748+
auto last = num_factors - 2;
749+
contract(1.0, facs1[last], {1,2}, facs2[last], {1,3}, 0.0, temp, {2,3});
750+
return btas::dot(RRp, temp);
751+
}
752+
753+
};
424754
} //namespace btas
425755
#endif // BTAS_GENERIC_CONV_BASE_CLASS

0 commit comments

Comments
 (0)