@@ -17,13 +17,12 @@ using micro_kernels::householder::mut_W_accessor;
/// `Func<4, 8>` on the first block row of its arguments, then
/// `Func<4, 4>` on the second block row, `Func<4, 2>`
/// and finally `Func<4, 1>` for the bottom row.
20- template <template <auto , class > class Func , Config Conf, class UpDown ,
21- index_t M, index_t ... Ms>
20+ template <class T , template <auto , class , class > class Func , Config Conf,
21+ class UpDown , index_t M, index_t ... Ms>
2222inline void tile_tail (index_t rowsA, index_t colsA0, index_t colsA,
23- mut_W_accessor<> W, real_t *L, index_t ldL,
24- const real_t *B, index_t ldB, real_t *A, index_t ldA,
25- UpDown updown) noexcept {
26- constexpr auto simd_M = micro_kernels::native_simd_size;
23+ mut_W_accessor<T> W, T *L, index_t ldL, const T *B,
24+ index_t ldB, T *A, index_t ldA, UpDown updown) noexcept {
25+ constexpr auto simd_M = micro_kernels::native_simd_size<T>;
2726 // If the block size is larger than the config allows, skip it.
2827 constexpr bool skip_large_M = M > Conf.block_size_s ;
2928 // If the block size is not efficiently vectorizable, and is not yet a
@@ -34,30 +33,30 @@ inline void tile_tail(index_t rowsA, index_t colsA0, index_t colsA,
3433 constexpr bool skip_suboptimal_M = M > simd_M && (M % simd_M) != 0 ;
3534 if constexpr (skip_large_M || skip_suboptimal_M) {
3635 if constexpr (sizeof ...(Ms) > 0 )
37- tile_tail<Func, Conf, UpDown, Ms...>(rowsA, colsA0, colsA, W, L,
38- ldL, B, ldB, A, ldA, updown);
36+ tile_tail<T, Func, Conf, UpDown, Ms...>(
37+ rowsA, colsA0, colsA, W, L, ldL, B, ldB, A, ldA, updown);
3938 return ;
4039 }
4140 while (rowsA >= M) {
42- constexpr Config NewConf {.block_size_r = Conf.block_size_r ,
43- .block_size_s = M};
44- Func<NewConf, UpDown> {}(colsA0, colsA, W, L, ldL, B, ldB, A, ldA,
45- updown);
41+ constexpr Config NewConf{.block_size_r = Conf.block_size_r ,
42+ .block_size_s = M};
43+ Func<NewConf, T, UpDown>{}(colsA0, colsA, W, L, ldL, B, ldB, A, ldA,
44+ updown);
4645 L += M;
4746 A += M;
4847 rowsA -= M;
4948 }
5049 if constexpr (sizeof ...(Ms) > 0 )
5150 if (rowsA > 0 )
52- tile_tail<Func, Conf, UpDown, Ms...>(rowsA, colsA0, colsA, W, L,
53- ldL, B, ldB, A, ldA, updown);
51+ tile_tail<T, Func, Conf, UpDown, Ms...>(
52+ rowsA, colsA0, colsA, W, L, ldL, B, ldB, A, ldA, updown);
5453}
5554
56- template <Config Conf, class UpDown >
55+ template <Config Conf, class T , class UpDown >
5756struct updowndate_tail_func {
5857 template <class ... Args>
5958 decltype (auto ) operator ()(Args &&...args) const {
60- return micro_kernels::householder::updowndate_tail<Conf, UpDown>(
59+ return micro_kernels::householder::updowndate_tail<Conf, T, UpDown>(
6160 std::forward<Args>(args)...);
6261 }
6362};
@@ -67,44 +66,45 @@ struct updowndate_tail_func {
/// @see @ref detail::tile_tail
/// The sizes specified here should be instantiated in the code generated by
/// CMake.
70- template <micro_kernels::householder::Config Conf, class UpDown >
69+ template <micro_kernels::householder::Config Conf, class T , class UpDown >
7170inline void updowndate_tile_tail (index_t rowsA, index_t colsA0, index_t colsA,
72- detail::mut_W_accessor<> W,
73- detail::mut_matrix_accessor L,
74- detail::matrix_accessor B,
75- detail::mut_matrix_accessor A, UpDown signs) {
76- detail::tile_tail<detail::updowndate_tail_func, Conf, UpDown, //
71+ detail::mut_W_accessor<T> W,
72+ detail::mut_matrix_accessor<T> L,
73+ detail::matrix_accessor<T> B,
74+ detail::mut_matrix_accessor<T> A,
75+ UpDown signs) {
76+ detail::tile_tail<T, detail::updowndate_tail_func, Conf, UpDown, //
7777 32 , 24 , 16 , 12 , 8 , 4 , 2 , 1 >(
7878 rowsA, colsA0, colsA, W, L.data , L.outer_stride , B.data , B.outer_stride ,
7979 A.data , A.outer_stride , signs);
8080}
8181
82- template <micro_kernels::householder::Config Conf, class UpDown >
82+ template <micro_kernels::householder::Config Conf, class T , class UpDown >
8383inline void
84- updowndate_tail (index_t colsA0, index_t colsA, detail::mut_W_accessor<> W,
85- detail::mut_matrix_accessor L, detail::matrix_accessor B,
86- detail::mut_matrix_accessor A, UpDown signs) {
84+ updowndate_tail (index_t colsA0, index_t colsA, detail::mut_W_accessor<T > W,
85+ detail::mut_matrix_accessor<T> L, detail::matrix_accessor<T> B,
86+ detail::mut_matrix_accessor<T> A, UpDown signs) {
8787 using micro_kernels::householder::updowndate_tail;
88- updowndate_tail<Conf, UpDown>(colsA0, colsA, W, L.data , L.outer_stride ,
89- B.data , B.outer_stride , A.data ,
90- A.outer_stride , signs);
88+ updowndate_tail<Conf, T, UpDown>(colsA0, colsA, W, L.data , L.outer_stride ,
89+ B.data , B.outer_stride , A.data ,
90+ A.outer_stride , signs);
9191}
9292
93- template <index_t R, class UpDown >
94- inline void updowndate_diag (index_t colsA, detail::mut_W_accessor<> W,
95- detail::mut_matrix_accessor L,
96- detail::mut_matrix_accessor A, UpDown signs) {
93+ template <index_t R, class T , class UpDown >
94+ inline void updowndate_diag (index_t colsA, detail::mut_W_accessor<T > W,
95+ detail::mut_matrix_accessor<T> L,
96+ detail::mut_matrix_accessor<T> A, UpDown signs) {
9797 using micro_kernels::householder::updowndate_diag;
98- updowndate_diag<R, UpDown>(colsA, W, L.data , L.outer_stride , A.data ,
99- A.outer_stride , signs);
98+ updowndate_diag<R, T, UpDown>(colsA, W, L.data , L.outer_stride , A.data ,
99+ A.outer_stride , signs);
100100}
101101
102- template <index_t R, class UpDown >
103- inline void updowndate_full (index_t colsA, detail::mut_matrix_accessor L,
104- detail::mut_matrix_accessor A, UpDown signs) {
102+ template <index_t R, class T , class UpDown >
103+ inline void updowndate_full (index_t colsA, detail::mut_matrix_accessor<T> L,
104+ detail::mut_matrix_accessor<T> A, UpDown signs) {
105105 using micro_kernels::householder::updowndate_full;
106- updowndate_full<R, UpDown>(colsA, L.data , L.outer_stride , A.data ,
107- A.outer_stride , signs);
106+ updowndate_full<R, T, UpDown>(colsA, L.data , L.outer_stride , A.data ,
107+ A.outer_stride , signs);
108108}
109109
110110} // namespace hyhound
0 commit comments