Skip to content

Commit a83b78a

Browse files
committed
linalg test: test TA::heig by reconstruction, and eigenvalues against non_dist::heig
1 parent 152dad4 commit a83b78a

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

tests/linalg.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <tiledarray.h>
22
#include <random>
33
#include "TiledArray/config.h"
4-
//#include "range_fixture.h"
54
#include "unit_test_config.h"
65

76
#include "TiledArray/math/linalg/non-distributed/cholesky.h"
@@ -469,26 +468,34 @@ BOOST_AUTO_TEST_CASE(heig_same_tiling) {
469468
return this->make_ta_reference(t, range);
470469
});
471470

472-
auto [evals, evecs] = non_dist::heig(ref_ta);
471+
auto [evals, evecs] = heig(ref_ta);
473472
auto [evals_non_dist, evecs_non_dist] = non_dist::heig(ref_ta);
474-
// auto evals = heig( ref_ta );
475473

476474
BOOST_CHECK(evecs.trange() == ref_ta.trange());
477475

478-
// check eigenvectors against non_dist only, for now ...
479-
decltype(evecs) evecs_error;
480-
evecs_error("i,j") = evecs_non_dist("i,j") - evecs("i,j");
481-
// TODO need to fix phases of the eigenvectors to be able to compare ...
482-
// BOOST_CHECK_SMALL(evecs_error("i,j").norm().get(),
483-
// N * N * std::numeric_limits<double>::epsilon());
484-
485476
// Check eigenvalue correctness
486477
double tol = N * N * std::numeric_limits<double>::epsilon();
487478
for (int64_t i = 0; i < N; ++i) {
488479
BOOST_CHECK_SMALL(std::abs(evals[i] - exact_evals[i]), tol);
489480
BOOST_CHECK_SMALL(std::abs(evals_non_dist[i] - exact_evals[i]), tol);
490481
}
491482

483+
// check eigenvectors by reconstruction
484+
auto reconstruction_check = [&](const auto& s, const auto& U,
485+
const auto str) {
486+
using Array = TA::TArray<double>;
487+
auto S =
488+
TA::diagonal_array<Array>(U.world(), U.trange(), s.begin(), s.end());
489+
Array err;
490+
err("i,j") = U("i,k") * S("k,l") * U("j,l").conj() - ref_ta("i,j");
491+
auto err_l2 = TA::norm2(err);
492+
const double epsilon = N * N * std::numeric_limits<double>::epsilon();
493+
BOOST_CHECK(err_l2 < epsilon);
494+
// std::cout << str << " ||U s U† - A||_2 = " << err_l2 << std::endl;
495+
};
496+
reconstruction_check(evals, evecs, "heig");
497+
reconstruction_check(evals_non_dist, evecs_non_dist, "non_dist::heig");
498+
492499
GlobalFixture::world->gop.fence();
493500
}
494501

0 commit comments

Comments
 (0)