|
1 | 1 | #include <tiledarray.h> |
2 | 2 | #include <random> |
3 | 3 | #include "TiledArray/config.h" |
4 | | -//#include "range_fixture.h" |
5 | 4 | #include "unit_test_config.h" |
6 | 5 |
|
7 | 6 | #include "TiledArray/math/linalg/non-distributed/cholesky.h" |
@@ -469,26 +468,34 @@ BOOST_AUTO_TEST_CASE(heig_same_tiling) { |
469 | 468 | return this->make_ta_reference(t, range); |
470 | 469 | }); |
471 | 470 |
|
472 | | - auto [evals, evecs] = non_dist::heig(ref_ta); |
| 471 | + auto [evals, evecs] = heig(ref_ta); |
473 | 472 | auto [evals_non_dist, evecs_non_dist] = non_dist::heig(ref_ta); |
474 | | - // auto evals = heig( ref_ta ); |
475 | 473 |
|
476 | 474 | BOOST_CHECK(evecs.trange() == ref_ta.trange()); |
477 | 475 |
|
478 | | - // check eigenvectors against non_dist only, for now ... |
479 | | - decltype(evecs) evecs_error; |
480 | | - evecs_error("i,j") = evecs_non_dist("i,j") - evecs("i,j"); |
481 | | - // TODO need to fix phases of the eigenvectors to be able to compare ... |
482 | | - // BOOST_CHECK_SMALL(evecs_error("i,j").norm().get(), |
483 | | - // N * N * std::numeric_limits<double>::epsilon()); |
484 | | - |
485 | 476 | // Check eigenvalue correctness |
486 | 477 | double tol = N * N * std::numeric_limits<double>::epsilon(); |
487 | 478 | for (int64_t i = 0; i < N; ++i) { |
488 | 479 | BOOST_CHECK_SMALL(std::abs(evals[i] - exact_evals[i]), tol); |
489 | 480 | BOOST_CHECK_SMALL(std::abs(evals_non_dist[i] - exact_evals[i]), tol); |
490 | 481 | } |
491 | 482 |
|
| 483 | + // check eigenvectors by reconstruction |
| 484 | + auto reconstruction_check = [&](const auto& s, const auto& U, |
| 485 | + const auto str) { |
| 486 | + using Array = TA::TArray<double>; |
| 487 | + auto S = |
| 488 | + TA::diagonal_array<Array>(U.world(), U.trange(), s.begin(), s.end()); |
| 489 | + Array err; |
| 490 | + err("i,j") = U("i,k") * S("k,l") * U("j,l").conj() - ref_ta("i,j"); |
| 491 | + auto err_l2 = TA::norm2(err); |
| 492 | + const double epsilon = N * N * std::numeric_limits<double>::epsilon(); |
| 493 | + BOOST_CHECK(err_l2 < epsilon); |
| 494 | + // std::cout << str << " ||U s U† - A||_2 = " << err_l2 << std::endl; |
| 495 | + }; |
| 496 | + reconstruction_check(evals, evecs, "heig"); |
| 497 | + reconstruction_check(evals_non_dist, evecs_non_dist, "non_dist::heig"); |
| 498 | + |
492 | 499 | GlobalFixture::world->gop.fence(); |
493 | 500 | } |
494 | 501 |
|
|
0 commit comments