44#pragma once
55
66#include " gprat/detail/config.hpp"
7+ #include " gprat/tile_data.hpp"
78
89#include < hpx/future.hpp>
910#include < vector>
1011
1112GPRAT_NS_BEGIN
1213
13- using vector_future = hpx::shared_future<std::vector<double >>;
14-
1514// Constants whose numeric values are compatible with CBLAS (111/112 match CblasNoTrans/CblasTrans)
1615typedef enum BLAS_TRANSPOSE { Blas_no_trans = 111 , Blas_trans = 112 } BLAS_TRANSPOSE;
1716
@@ -29,108 +28,117 @@ typedef enum BLAS_ALPHA { Blas_add = 1, Blas_substract = -1 } BLAS_ALPHA;
2928
3029/* *
3130 * @brief FP64 In-place Cholesky decomposition of A
32- * @param f_A matrix to be factorized
31+ * @param A matrix to be factorized
3332 * @param N matrix dimension
3433 * @return factorized, lower triangular matrix L
3534 */
36- vector_future potrf (vector_future f_A, const int N);
35+ mutable_tile_data< double > potrf (const mutable_tile_data< double > &A, int N);
3736
3837/* *
3938 * @brief FP64 In-place solve L(^T) * X = A or X * L(^T) = A where L lower triangular
40- * @param f_L Cholesky factor matrix
41- * @param f_A right hand side matrix
39+ * @param L Cholesky factor matrix
40+ * @param A right hand side matrix
4241 * @param N first dimension
4342 * @param M second dimension
 * @param transpose_L transpose Cholesky factor
 * @param side_L presumably selects L(^T) * X = A vs. X * L(^T) = A (TODO confirm against BLAS_SIDE)
4443 * @return solution matrix X
4544 */
46- vector_future trsm (vector_future f_L,
47- vector_future f_A,
48- const int N,
49- const int M,
50- const BLAS_TRANSPOSE transpose_L,
51- const BLAS_SIDE side_L);
45+ mutable_tile_data<double >
46+ trsm (const const_tile_data<double > &L,
47+ const mutable_tile_data<double > &A,
48+ int N,
49+ int M,
50+ BLAS_TRANSPOSE transpose_L,
51+ BLAS_SIDE side_L);
5252
5353/* *
5454 * @brief FP64 Symmetric rank-k update: A = A - B * B^T
55- * @param f_A Base matrix
56- * @param f_B Symmetric update matrix
55+ * @param A Base matrix
56+ * @param B Symmetric update matrix
5757 * @param N matrix dimension
5858 * @return updated matrix A
5959 */
60- vector_future syrk (vector_future f_A, vector_future f_B , const int N);
60+ mutable_tile_data< double > syrk (const mutable_tile_data< double > &A , const const_tile_data< double > &B, int N);
6161
6262/* *
6363 * @brief FP64 General matrix-matrix multiplication: C = C - A(^T) * B(^T)
64- * @param f_C Base matrix
65- * @param f_B Right update matrix
66- * @param f_A Left update matrix
64+ * @param C Base matrix
65+ * @param B Right update matrix
66+ * @param A Left update matrix
6767 * @param N first matrix dimension
6868 * @param M second matrix dimension
6969 * @param K third matrix dimension
7070 * @param transpose_A transpose left matrix
7171 * @param transpose_B transpose right matrix
7272 * @return updated matrix C
7373 */
74- vector_future
75- gemm (vector_future f_A ,
76- vector_future f_B ,
77- vector_future f_C ,
78- const int N,
79- const int M,
80- const int K,
81- const BLAS_TRANSPOSE transpose_A,
82- const BLAS_TRANSPOSE transpose_B);
74+ mutable_tile_data< double >
75+ gemm (const const_tile_data< double > &A ,
76+ const const_tile_data< double > &B ,
77+ const mutable_tile_data< double > &C ,
78+ int N,
79+ int M,
80+ int K,
81+ BLAS_TRANSPOSE transpose_A,
82+ BLAS_TRANSPOSE transpose_B);
8383
8484// BLAS level 2 operations
8585
8686/* *
8787 * @brief FP64 In-place solve L(^T) * x = a where L lower triangular
88- * @param f_L Cholesky factor matrix
89- * @param f_a right hand side vector
88+ * @param L Cholesky factor matrix
89+ * @param a right hand side vector
9090 * @param N matrix dimension
9191 * @param transpose_L transpose Cholesky factor
9292 * @return solution vector x
9393 */
94- vector_future trsv (vector_future f_L, vector_future f_a, const int N, const BLAS_TRANSPOSE transpose_L);
94+ mutable_tile_data<double >
95+ trsv (const const_tile_data<double > &L, const mutable_tile_data<double > &a, int N, BLAS_TRANSPOSE transpose_L);
9596
9697/* *
9798 * @brief FP64 General matrix-vector multiplication: b = b - A(^T) * a
98- * @param f_A update matrix
99- * @param f_a update vector
100- * @param f_b base vector
99+ * @param A update matrix
100+ * @param a update vector
101+ * @param b base vector
101102 * @param N matrix dimension
 * @param M presumably the second matrix dimension (TODO confirm)
102103 * @param alpha add or subtract update to base vector
103104 * @param transpose_A transpose update matrix
104105 * @return updated vector b
105106 */
106- vector_future gemv (vector_future f_A,
107- vector_future f_a,
108- vector_future f_b,
109- const int N,
110- const int M,
111- const BLAS_ALPHA alpha,
112- const BLAS_TRANSPOSE transpose_A);
107+ mutable_tile_data<double >
108+ gemv (const const_tile_data<double > &A,
109+ const const_tile_data<double > &a,
110+ const mutable_tile_data<double > &b,
111+ int N,
112+ int M,
113+ BLAS_ALPHA alpha,
114+ BLAS_TRANSPOSE transpose_A);
113115
114116/* *
115117 * @brief FP64 Vector update with diagonal SYRK: r = r + diag(A^T * A)
116- * @param f_A update matrix
117- * @param f_r base vector
118+ * @param A update matrix
119+ * @param r base vector
118120 * @param N first matrix dimension
119121 * @param M second matrix dimension
120122 * @return updated vector r
121123 */
122- vector_future dot_diag_syrk (vector_future f_A, vector_future f_r, const int N, const int M);
124+ mutable_tile_data<double >
125+ dot_diag_syrk (const const_tile_data<double > &A, const mutable_tile_data<double > &r, int N, int M);
123126
124127/* *
125128 * @brief FP64 Vector update with diagonal GEMM: r = r + diag(A * B)
126- * @param f_A first update matrix
127- * @param f_B second update matrix
128- * @param f_r base vector
129+ * @param A first update matrix
130+ * @param B second update matrix
131+ * @param r base vector
129132 * @param N first matrix dimension
130133 * @param M second matrix dimension
131134 * @return updated vector r
132135 */
133- vector_future dot_diag_gemm (vector_future f_A, vector_future f_B, vector_future f_r, const int N, const int M);
136+ mutable_tile_data<double >
137+ dot_diag_gemm (const const_tile_data<double > &A,
138+ const const_tile_data<double > &B,
139+ const mutable_tile_data<double > &r,
140+ int N,
141+ int M);
134142
135143// BLAS level 1 operations
136144
@@ -141,7 +149,7 @@ vector_future dot_diag_gemm(vector_future f_A, vector_future f_B, vector_future
141149 * @param N vector length
142150 * @return y - x
143151 */
144- vector_future axpy (vector_future f_y, vector_future f_x , const int N);
152+ mutable_tile_data< double > axpy (const mutable_tile_data< double > &y , const const_tile_data< double > &x, int N);
145153
146154/* *
147155 * @brief FP64 Dot product: a * b
@@ -150,7 +158,7 @@ vector_future axpy(vector_future f_y, vector_future f_x, const int N);
150158 * @param N vector length
151159 * @return scalar dot product a * b
152160 */
153- double dot (std::vector< double > a, std::vector< double > b, const int N);
161+ double dot (std::span< const double > a, std::span< const double > b, int N);
154162
155163GPRAT_NS_END
156164
0 commit comments