Skip to content

Commit 457efe3

Browse files
committed
Support multiple precisions
1 parent 11b54fc commit 457efe3

26 files changed

+428 -378 lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@ include(CTest)
99

1010
# Options
1111
include(CMakeDependentOption)
12-
set(HYHOUND_DENSE_REAL_TYPE "double" CACHE STRING
13-
"The main floating point type for representing real numbers")
12+
set(HYHOUND_DENSE_REAL_TYPE "double" "float" CACHE STRING
13+
"The floating point types that the functions are instantiated for")
1414
set(HYHOUND_DENSE_INDEX_TYPE "long long" CACHE STRING
1515
"The main integer type for indices and sizes")
1616
# Target options

benchmarks/hyh.cpp

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111

1212
#include <algorithm>
1313
#include <cstdlib>
14+
#include <limits>
1415
#include <map>
1516
#include <mutex>
1617
#include <random>
@@ -28,15 +29,15 @@ using std::pow;
2829
#endif
2930

3031
struct ProblemMatrices {
31-
Eigen::MatrixXd K̃, K, L, A;
32+
Eigen::MatrixX<real_t> K̃, K, L, A;
3233
};
3334
using cache_t = std::map<std::pair<index_t, index_t>, ProblemMatrices>;
3435
std::mutex cache_mtx;
3536
cache_t cache;
3637

3738
struct CholeskyFixture : benchmark::Fixture {
3839
index_t m, n;
39-
Eigen::MatrixXd L̃;
40+
Eigen::MatrixX<real_t> L̃;
4041
cache_t::const_iterator matrices;
4142

4243
static cache_t::const_iterator generate_problem(index_t m, index_t n) {
@@ -53,7 +54,7 @@ struct CholeskyFixture : benchmark::Fixture {
5354
#endif
5455

5556
std::mt19937 rng{12345};
56-
std::uniform_real_distribution<> dist(0.0, 1.0);
57+
std::uniform_real_distribution<real_t> dist(-1, 1);
5758
mat.K̃.resize(n, n), mat.K.resize(n, n), mat.L.resize(n, n);
5859
mat.A.resize(n, m);
5960
std::ranges::generate(mat.K.reshaped(), [&] { return dist(rng); });
@@ -96,10 +97,11 @@ struct CholeskyFixture : benchmark::Fixture {
9697
}
9798

9899
void TearDown(benchmark::State &state) final {
99-
Eigen::MatrixXd E = matrices->second.K̃;
100-
const auto n = static_cast<index_t>(L̃.rows()),
101-
ldL̃ = static_cast<index_t>(L̃.outerStride()),
102-
ldE = static_cast<index_t>(E.outerStride());
100+
using std::pow;
101+
Eigen::MatrixX<real_t> E = matrices->second.K̃;
102+
const auto n = static_cast<index_t>(L̃.rows()),
103+
ldL̃ = static_cast<index_t>(L̃.outerStride()),
104+
ldE = static_cast<index_t>(E.outerStride());
103105
#if GUANAQO_WITH_OPENMP
104106
int old_num_threads = omp_get_max_threads();
105107
omp_set_num_threads(std::thread::hardware_concurrency() / 2);
@@ -113,7 +115,8 @@ struct CholeskyFixture : benchmark::Fixture {
113115
E.triangularView<Eigen::StrictlyUpper>().setZero();
114116
real_t r = E.lpNorm<Eigen::Infinity>();
115117
std::string label = "resid=" + guanaqo::float_to_str(r, 6);
116-
if (!(r < 1e-9))
118+
const auto ε = pow(std::numeric_limits<real_t>::epsilon(), real_t(0.5));
119+
if (!(r < ε))
117120
label = "\x1b[0;31m" + label + "\x1b[0m";
118121
state.SetLabel(label);
119122
compute_flops(state);
@@ -129,7 +132,7 @@ struct CholeskyFixture : benchmark::Fixture {
129132

130133
template <auto Func>
131134
void runUpdateBenchmark(benchmark::State &state) {
132-
Eigen::MatrixXd Ã(m, n);
135+
Eigen::MatrixX<real_t> Ã(m, n);
133136
for (auto _ : state) {
134137
state.PauseTiming();
135138
à = matrices->second.A;
@@ -191,7 +194,7 @@ std::vector<::benchmark::internal::Benchmark *> benchmarks;
191194
BENCHMARK_TEMPLATE_DEFINE_F( \
192195
BlockedFixture, BM_BLK_IMPL_NAME(name, __VA_ARGS__), __VA_ARGS__) \
193196
(benchmark::State & state) { \
194-
this->runUpdateBenchmark<func<{__VA_ARGS__}, updown>>(state); \
197+
this->runUpdateBenchmark<func<real_t, {__VA_ARGS__}, updown>>(state); \
195198
} \
196199
BM_BLK_REGISTER_F(BlockedFixture, BM_BLK_IMPL_NAME(name, __VA_ARGS__)) \
197200
->Name(BM_BLK_NAME(name, __VA_ARGS__))
@@ -256,24 +259,17 @@ BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 4);
256259
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 8);
257260
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 12);
258261
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 16);
259-
#if __AVX512F__
260262
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 24);
261263
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 32);
262-
#endif
263-
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 12, 2);
264-
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 4, 12, 4);
265264

266265
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 8);
267266
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 12);
268267
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 16);
269-
#if __AVX512F__
270268
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 24);
271269
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 32);
270+
#if __AVX512F__
272271
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 16, 8);
273-
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 16, 12);
274272
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 16, 16);
275-
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 24, 2);
276-
BENCHMARK_BLOCKED(hyh_update, update_cholesky, Downdate, 8, 24, 4);
277273
#endif
278274
// clang-format on
279275

benchmarks/ocp.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ void bm_solve_riccati(benchmark::State &state) {
3838
}
3939

4040
void bm_update_riccati(benchmark::State &state) {
41+
using std::exp2;
4142
std::mt19937 rng{54321};
4243
std::normal_distribution<real_t> nrml{0, 10};
4344
std::bernoulli_distribution bern{0.25};
@@ -74,6 +75,7 @@ void bm_solve_schur(benchmark::State &state) {
7475
}
7576

7677
void bm_update_schur(benchmark::State &state) {
78+
using std::exp2;
7779
std::mt19937 rng{54321};
7880
std::normal_distribution<real_t> nrml{0, 10};
7981
std::bernoulli_distribution bern{0.25};

conanfile.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,12 @@ class HyhoundRecipe(ConanFile):
2828
options = {
2929
"shared": [True, False],
3030
"fPIC": [True, False],
31-
"real_type": ["double", "float"],
31+
"real_type": ["double;float", "float;double", "double", "float"],
3232
} | {k: [True, False] for k in bool_hyhound_options}
3333
default_options = {
3434
"shared": False,
3535
"fPIC": True,
36-
"real_type": "double",
36+
"real_type": "double;float",
3737
} | bool_hyhound_options
3838

3939
# Sources are located in the same place as this recipe, copy them to the recipe

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ find_package(guanaqo REQUIRED)
44
# Configuration options
55
# ------------------------------------------------------------------------------
66
add_library(config INTERFACE)
7+
list(GET HYHOUND_DENSE_REAL_TYPE 0 HYHOUND_DENSE_REAL_TYPE_0)
78
configure_file("config.hpp.in"
89
"${CMAKE_CURRENT_BINARY_DIR}/config/include/hyhound/config.hpp" @ONLY)
910
target_sources(config INTERFACE FILE_SET headers TYPE HEADERS
@@ -25,7 +26,6 @@ target_sources(util INTERFACE FILE_SET headers TYPE HEADERS
2526
"util/include/hyhound/cneg.hpp"
2627
"util/include/hyhound/loop.hpp"
2728
"util/include/hyhound/lut.hpp"
28-
"util/include/hyhound/matrix-view.hpp"
2929
"util/include/hyhound/unroll.h"
3030
)
3131
target_link_libraries(util INTERFACE hyhound::config guanaqo::guanaqo)

src/config.hpp.in

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ namespace hyhound {
1010
#ifdef __clangd__
1111
using real_t = double; // clangd does not support std::float128_t etc.
1212
#else
13-
using real_t = @HYHOUND_DENSE_REAL_TYPE@;
13+
using real_t = @HYHOUND_DENSE_REAL_TYPE_0@;
1414
#endif
1515
using index_t = @HYHOUND_DENSE_INDEX_TYPE@;
1616

src/hyhound/include/hyhound/householder-updowndate-micro-kernels.tpp

Lines changed: 40 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,12 @@ using micro_kernels::householder::mut_W_accessor;
1717
/// `Func<4, 8>` on the first block row of its arguments, then
1818
/// `Func<4, 4>` on the second block row, `Func<4, 2>`
1919
/// and finally `downdate_tail<4, 1>` for the bottom row.
20-
template <template <auto, class> class Func, Config Conf, class UpDown,
21-
index_t M, index_t... Ms>
20+
template <class T, template <auto, class, class> class Func, Config Conf,
21+
class UpDown, index_t M, index_t... Ms>
2222
inline void tile_tail(index_t rowsA, index_t colsA0, index_t colsA,
23-
mut_W_accessor<> W, real_t *L, index_t ldL,
24-
const real_t *B, index_t ldB, real_t *A, index_t ldA,
25-
UpDown updown) noexcept {
26-
constexpr auto simd_M = micro_kernels::native_simd_size;
23+
mut_W_accessor<T> W, T *L, index_t ldL, const T *B,
24+
index_t ldB, T *A, index_t ldA, UpDown updown) noexcept {
25+
constexpr auto simd_M = micro_kernels::native_simd_size<T>;
2726
// If the block size is larger than the config allows, skip it.
2827
constexpr bool skip_large_M = M > Conf.block_size_s;
2928
// If the block size is not efficiently vectorizable, and is not yet a
@@ -34,30 +33,30 @@ inline void tile_tail(index_t rowsA, index_t colsA0, index_t colsA,
3433
constexpr bool skip_suboptimal_M = M > simd_M && (M % simd_M) != 0;
3534
if constexpr (skip_large_M || skip_suboptimal_M) {
3635
if constexpr (sizeof...(Ms) > 0)
37-
tile_tail<Func, Conf, UpDown, Ms...>(rowsA, colsA0, colsA, W, L,
38-
ldL, B, ldB, A, ldA, updown);
36+
tile_tail<T, Func, Conf, UpDown, Ms...>(
37+
rowsA, colsA0, colsA, W, L, ldL, B, ldB, A, ldA, updown);
3938
return;
4039
}
4140
while (rowsA >= M) {
42-
constexpr Config NewConf {.block_size_r = Conf.block_size_r,
43-
.block_size_s = M};
44-
Func<NewConf, UpDown> {}(colsA0, colsA, W, L, ldL, B, ldB, A, ldA,
45-
updown);
41+
constexpr Config NewConf{.block_size_r = Conf.block_size_r,
42+
.block_size_s = M};
43+
Func<NewConf, T, UpDown>{}(colsA0, colsA, W, L, ldL, B, ldB, A, ldA,
44+
updown);
4645
L += M;
4746
A += M;
4847
rowsA -= M;
4948
}
5049
if constexpr (sizeof...(Ms) > 0)
5150
if (rowsA > 0)
52-
tile_tail<Func, Conf, UpDown, Ms...>(rowsA, colsA0, colsA, W, L,
53-
ldL, B, ldB, A, ldA, updown);
51+
tile_tail<T, Func, Conf, UpDown, Ms...>(
52+
rowsA, colsA0, colsA, W, L, ldL, B, ldB, A, ldA, updown);
5453
}
5554

56-
template <Config Conf, class UpDown>
55+
template <Config Conf, class T, class UpDown>
5756
struct updowndate_tail_func {
5857
template <class... Args>
5958
decltype(auto) operator()(Args &&...args) const {
60-
return micro_kernels::householder::updowndate_tail<Conf, UpDown>(
59+
return micro_kernels::householder::updowndate_tail<Conf, T, UpDown>(
6160
std::forward<Args>(args)...);
6261
}
6362
};
@@ -67,44 +66,45 @@ struct updowndate_tail_func {
6766
/// @see @ref detail::tile_tail
6867
/// The sizes specified here should be instantiated in the code generated by
6968
/// CMake.
70-
template <micro_kernels::householder::Config Conf, class UpDown>
69+
template <micro_kernels::householder::Config Conf, class T, class UpDown>
7170
inline void updowndate_tile_tail(index_t rowsA, index_t colsA0, index_t colsA,
72-
detail::mut_W_accessor<> W,
73-
detail::mut_matrix_accessor L,
74-
detail::matrix_accessor B,
75-
detail::mut_matrix_accessor A, UpDown signs) {
76-
detail::tile_tail<detail::updowndate_tail_func, Conf, UpDown, //
71+
detail::mut_W_accessor<T> W,
72+
detail::mut_matrix_accessor<T> L,
73+
detail::matrix_accessor<T> B,
74+
detail::mut_matrix_accessor<T> A,
75+
UpDown signs) {
76+
detail::tile_tail<T, detail::updowndate_tail_func, Conf, UpDown, //
7777
32, 24, 16, 12, 8, 4, 2, 1>(
7878
rowsA, colsA0, colsA, W, L.data, L.outer_stride, B.data, B.outer_stride,
7979
A.data, A.outer_stride, signs);
8080
}
8181

82-
template <micro_kernels::householder::Config Conf, class UpDown>
82+
template <micro_kernels::householder::Config Conf, class T, class UpDown>
8383
inline void
84-
updowndate_tail(index_t colsA0, index_t colsA, detail::mut_W_accessor<> W,
85-
detail::mut_matrix_accessor L, detail::matrix_accessor B,
86-
detail::mut_matrix_accessor A, UpDown signs) {
84+
updowndate_tail(index_t colsA0, index_t colsA, detail::mut_W_accessor<T> W,
85+
detail::mut_matrix_accessor<T> L, detail::matrix_accessor<T> B,
86+
detail::mut_matrix_accessor<T> A, UpDown signs) {
8787
using micro_kernels::householder::updowndate_tail;
88-
updowndate_tail<Conf, UpDown>(colsA0, colsA, W, L.data, L.outer_stride,
89-
B.data, B.outer_stride, A.data,
90-
A.outer_stride, signs);
88+
updowndate_tail<Conf, T, UpDown>(colsA0, colsA, W, L.data, L.outer_stride,
89+
B.data, B.outer_stride, A.data,
90+
A.outer_stride, signs);
9191
}
9292

93-
template <index_t R, class UpDown>
94-
inline void updowndate_diag(index_t colsA, detail::mut_W_accessor<> W,
95-
detail::mut_matrix_accessor L,
96-
detail::mut_matrix_accessor A, UpDown signs) {
93+
template <index_t R, class T, class UpDown>
94+
inline void updowndate_diag(index_t colsA, detail::mut_W_accessor<T> W,
95+
detail::mut_matrix_accessor<T> L,
96+
detail::mut_matrix_accessor<T> A, UpDown signs) {
9797
using micro_kernels::householder::updowndate_diag;
98-
updowndate_diag<R, UpDown>(colsA, W, L.data, L.outer_stride, A.data,
99-
A.outer_stride, signs);
98+
updowndate_diag<R, T, UpDown>(colsA, W, L.data, L.outer_stride, A.data,
99+
A.outer_stride, signs);
100100
}
101101

102-
template <index_t R, class UpDown>
103-
inline void updowndate_full(index_t colsA, detail::mut_matrix_accessor L,
104-
detail::mut_matrix_accessor A, UpDown signs) {
102+
template <index_t R, class T, class UpDown>
103+
inline void updowndate_full(index_t colsA, detail::mut_matrix_accessor<T> L,
104+
detail::mut_matrix_accessor<T> A, UpDown signs) {
105105
using micro_kernels::householder::updowndate_full;
106-
updowndate_full<R, UpDown>(colsA, L.data, L.outer_stride, A.data,
107-
A.outer_stride, signs);
106+
updowndate_full<R, T, UpDown>(colsA, L.data, L.outer_stride, A.data,
107+
A.outer_stride, signs);
108108
}
109109

110110
} // namespace hyhound

0 commit comments

Comments
 (0)