Skip to content

Commit f4c5c40

Browse files
committed
ConjugateGradientSolver works with Eigen types
1 parent 4765265 commit f4c5c40

File tree

6 files changed

+122
-35
lines changed

6 files changed

+122
-35
lines changed

src/TiledArray/math/linalg/basic.h

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,17 @@ inline void vec_multiply(
155155
a1.array() *= a2.array();
156156
}
157157

158+
template <typename Derived>
159+
inline auto clone(const Eigen::MatrixBase<Derived>& a) {
160+
return a.eval();
161+
}
162+
163+
template <typename XprType1, int BlockRows1, int BlockCols1, bool InnerPanel1>
164+
inline auto clone(
165+
const Eigen::Block<XprType1, BlockRows1, BlockCols1, InnerPanel1>& a) {
166+
return a.eval();
167+
}
168+
158169
template <typename Derived, typename S>
159170
inline void scale(Eigen::MatrixBase<Derived>& a, S scaling_factor) {
160171
using numeric_type = typename Eigen::MatrixBase<Derived>::value_type;
@@ -239,6 +250,21 @@ inline auto norm2(
239250
return m.template lpNorm<2>();
240251
}
241252

253+
template <typename Derived>
254+
inline auto volume(const Eigen::MatrixBase<Derived>& m) {
255+
return m.size();
256+
}
257+
258+
template <typename Derived>
259+
inline auto abs_min(const Eigen::MatrixBase<Derived>& m) {
260+
return m.array().abs().minCoeff();
261+
}
262+
263+
template <typename Derived>
264+
inline auto abs_max(const Eigen::MatrixBase<Derived>& m) {
265+
return m.array().abs().maxCoeff();
266+
}
267+
242268
} // namespace Eigen
243269

244270
#ifndef TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG
@@ -253,12 +279,12 @@ inline auto norm2(
253279
return scalapack::FN; \
254280
return non_distributed::FN;
255281
#elif (TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK)
256-
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
257-
TA_MAX_THREADS; \
258-
if (get_linalg_backend() == LinearAlgebraBackend::TTG || \
259-
TiledArray::math::linalg::detail::prefer_distributed(MATRIX)) \
260-
return TiledArray::math::linalg::ttg::FN; \
261-
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
282+
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
283+
TA_MAX_THREADS; \
284+
if (get_linalg_backend() == LinearAlgebraBackend::TTG || \
285+
TiledArray::math::linalg::detail::prefer_distributed(MATRIX)) \
286+
return TiledArray::math::linalg::ttg::FN; \
287+
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
262288
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
263289
return non_distributed::FN;
264290
#elif !TILEDARRAY_HAS_TTG && TILEDARRAY_HAS_SCALAPACK
@@ -271,11 +297,11 @@ inline auto norm2(
271297
return scalapack::FN; \
272298
return non_distributed::FN;
273299
#else // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
274-
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
275-
TA_MAX_THREADS; \
276-
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
277-
TA_EXCEPTION("TTG linear algebra backend is not available"); \
278-
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
300+
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
301+
TA_MAX_THREADS; \
302+
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
303+
TA_EXCEPTION("TTG linear algebra backend is not available"); \
304+
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
279305
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
280306
return non_distributed::FN;
281307
#endif // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
@@ -297,12 +323,12 @@ inline auto norm2(
297323
return scalapack::FN; \
298324
return non_distributed::FN;
299325
#elif TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
300-
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
301-
TA_MAX_THREADS; \
302-
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
303-
TA_EXCEPTION(TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG_STRINGIFY( \
304-
FN) " is not provided by the TTG backend"); \
305-
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
326+
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
327+
TA_MAX_THREADS; \
328+
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
329+
TA_EXCEPTION(TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG_STRINGIFY( \
330+
FN) " is not provided by the TTG backend"); \
331+
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
306332
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
307333
return non_distributed::FN;
308334
#elif !TILEDARRAY_HAS_TTG && TILEDARRAY_HAS_SCALAPACK
@@ -315,11 +341,11 @@ inline auto norm2(
315341
return scalapack::FN; \
316342
return non_distributed::FN;
317343
#else // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
318-
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
319-
TA_MAX_THREADS; \
320-
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
321-
TA_EXCEPTION("TTG linear algebra backend is not available"); \
322-
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
344+
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
345+
TA_MAX_THREADS; \
346+
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
347+
TA_EXCEPTION("TTG linear algebra backend is not available"); \
348+
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
323349
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
324350
return non_distributed::FN;
325351
#endif // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK

src/TiledArray/math/solvers/conjgrad.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <TiledArray/math/linalg/basic.h>
3030
#include <TiledArray/math/solvers/diis.h>
3131
#include "TiledArray/dist_array.h"
32+
#include "TiledArray/type_traits.h"
3233

3334
namespace TiledArray::math {
3435

@@ -44,8 +45,8 @@ namespace TiledArray::math {
4445
/// stand-alone functions:
4546
/// \li <tt> std::size_t volume(const D&) </tt> (returns the total number of elements)
4647
/// \li <tt> D clone(const D&) </tt>, returns a deep copy
47-
/// \li <tt> value_type minabs_value(const D&) </tt>
48-
/// \li <tt> value_type maxabs_value(const D&) </tt>
48+
/// \li <tt> value_type abs_min(const D&) </tt>
49+
/// \li <tt> value_type abs_max(const D&) </tt>
4950
/// \li <tt> void vec_multiply(D& a, const D& b) </tt> (element-wise multiply
5051
/// of \c a by \c b )
5152
/// \li <tt> value_type inner_product(const D& a, const D& b) </tt>
@@ -60,7 +61,7 @@ namespace TiledArray::math {
6061
// clang-format on
6162
template <typename D, typename F>
6263
struct ConjugateGradientSolver {
63-
typedef typename D::numeric_type value_type;
64+
typedef TiledArray::detail::numeric_t<D> value_type;
6465

6566
/// \param a object of type F
6667
/// \param b RHS
@@ -73,8 +74,8 @@ struct ConjugateGradientSolver {
7374
value_type convergence_target = -1.0) {
7475
std::size_t n = volume(preconditioner);
7576

76-
const bool use_diis = false;
77-
DIIS<D> diis;
77+
constexpr bool use_diis = false;
78+
std::conditional_t<use_diis, DIIS<D>, char> diis{};
7879

7980
// solution vector
8081
D XX_i;
@@ -120,7 +121,7 @@ struct ConjugateGradientSolver {
120121
scale(RR_i, -1.0);
121122
axpy(RR_i, 1.0, b); // RR_i = b - a(XX_i)
122123

123-
if (use_diis) diis.extrapolate(XX_i, RR_i, true);
124+
if constexpr (use_diis) diis.extrapolate(XX_i, RR_i, true);
124125

125126
// z_0 = D^-1 . r_0
126127
ZZ_i = RR_i;
@@ -144,7 +145,7 @@ struct ConjugateGradientSolver {
144145
// r_i -= alpha_i Ap_i
145146
axpy(RR_i, -alpha_i, APP_i);
146147

147-
if (use_diis) diis.extrapolate(XX_i, RR_i, true);
148+
if constexpr (use_diis) diis.extrapolate(XX_i, RR_i, true);
148149

149150
const value_type r_ip1_norm = norm2(RR_i) / rhs_size;
150151
if (r_ip1_norm < convergence_target) {

src/TiledArray/math/solvers/diis.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <TiledArray/math/linalg/basic.h>
3030
#include "TiledArray/dist_array.h"
3131
#include "TiledArray/external/eigen.h"
32+
#include "TiledArray/type_traits.h"
3233

3334
#include <Eigen/QR>
3435
#include <deque>
@@ -82,7 +83,7 @@ namespace TiledArray::math {
8283
template <typename D>
8384
class DIIS {
8485
public:
85-
typedef typename D::numeric_type value_type;
86+
typedef TiledArray::detail::numeric_t<D> value_type;
8687
typedef typename TiledArray::detail::scalar_t<value_type> scalar_type;
8788
typedef Eigen::Matrix<value_type, Eigen::Dynamic, Eigen::Dynamic,
8889
Eigen::RowMajor>

src/TiledArray/tile_interface/clone.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace TiledArray {
3636
/// \tparam Arg The tile argument type
3737
/// \param arg The tile argument to be permuted
3838
/// \return A (deep) copy of \c arg
39-
template <typename Arg>
39+
template <typename Arg, typename = decltype(std::declval<const Arg&>().clone())>
4040
inline auto clone(const Arg& arg) {
4141
return arg.clone();
4242
}

src/TiledArray/tile_op/tile_interface.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,8 @@ inline auto min(const Arg& arg) {
969969
/// \tparam Arg The tile argument type
970970
/// \param arg The argument to find the maximum
971971
/// \return A scalar that is equal to <tt>abs(max(arg))</tt>
972-
template <typename Arg>
972+
template <typename Arg,
973+
typename = decltype(std::declval<const Arg&>().abs_max())>
973974
inline auto abs_max(const Arg& arg) {
974975
return arg.abs_max();
975976
}
@@ -979,7 +980,8 @@ inline auto abs_max(const Arg& arg) {
979980
/// \tparam Arg The tile argument type
980981
/// \param arg The argument to find the minimum
981982
/// \return A scalar that is equal to <tt>abs(min(arg))</tt>
982-
template <typename Arg>
983+
template <typename Arg,
984+
typename = decltype(std::declval<const Arg&>().abs_min())>
983985
inline auto abs_min(const Arg& arg) {
984986
return arg.abs_min();
985987
}
@@ -991,7 +993,9 @@ inline auto abs_min(const Arg& arg) {
991993
/// \param left The left-hand argument tile
992994
/// \param right The right-hand argument tile
993995
/// \return A scalar that is equal to <tt>sum_i left[i] * right[i]</tt>
994-
template <typename Left, typename Right>
996+
template <typename Left, typename Right,
997+
typename = decltype(std::declval<const Left&>().dot(
998+
std::declval<const Right&>()))>
995999
inline auto dot(const Left& left, const Right& right) {
9961000
return left.dot(right);
9971001
}
@@ -1003,7 +1007,9 @@ inline auto dot(const Left& left, const Right& right) {
10031007
/// \param left The left-hand argument tile
10041008
/// \param right The right-hand argument tile
10051009
/// \return A scalar that is equal to <tt>sum_i conj(left[i]) * right[i]</tt>
1006-
template <typename Left, typename Right>
1010+
template <typename Left, typename Right,
1011+
typename = decltype(std::declval<const Left&>().inner_product(
1012+
std::declval<const Right&>()))>
10071013
inline auto inner_product(const Left& left, const Right& right) {
10081014
return left.inner_product(right);
10091015
}

tests/solvers.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,49 @@ struct validate<DistArray<Tile, Policy>> {
167167
}
168168
};
169169

170+
// Eigen specializations
171+
172+
template <>
173+
struct make_Ax<Eigen::VectorXd> {
174+
using T = Eigen::VectorXd;
175+
176+
struct Ax {
177+
Ax() : A_(3, 3) { A_ << 1, 2, 3, 2, 5, 8, 3, 8, 15; }
178+
void operator()(const T& x, T& result) const { result = A_ * x; }
179+
Eigen::MatrixXd A_;
180+
};
181+
Ax operator()() const { return Ax{}; }
182+
};
183+
184+
template <>
185+
struct make_b<Eigen::VectorXd> {
186+
using T = Eigen::VectorXd;
187+
188+
T operator()() const {
189+
T result(3);
190+
result << 1, 4, 0;
191+
return result;
192+
}
193+
};
194+
195+
template <>
196+
struct make_pc<Eigen::VectorXd> {
197+
using T = Eigen::VectorXd;
198+
199+
T operator()() const { return T::Ones(3); }
200+
};
201+
202+
template <>
203+
struct validate<Eigen::VectorXd> {
204+
using T = Eigen::VectorXd;
205+
206+
bool operator()(const T& x) const {
207+
T ref(3);
208+
ref << -6.5, 9., -3.5;
209+
return (x - ref).norm() < 1e-11;
210+
}
211+
};
212+
170213
BOOST_AUTO_TEST_SUITE(solvers)
171214

172215
BOOST_AUTO_TEST_CASE_TEMPLATE(conjugate_gradient, Array, array_types) {
@@ -178,4 +221,14 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(conjugate_gradient, Array, array_types) {
178221
BOOST_CHECK(validate<Array>{}(x));
179222
}
180223

224+
BOOST_AUTO_TEST_CASE(conjugate_gradient_eigen) {
225+
using T = Eigen::VectorXd;
226+
auto Ax = make_Ax<T>{}();
227+
auto b = make_b<T>{}();
228+
auto pc = make_pc<T>{}();
229+
T x;
230+
ConjugateGradientSolver<T, decltype(Ax)>{}(Ax, b, x, pc, 1e-11);
231+
BOOST_CHECK(validate<T>{}(x));
232+
}
233+
181234
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)