Skip to content

Commit 08c1eed

Browse files
feat!(core): Introduce const_tile_data + mutable_tile_data
in lieu of std::vector<T> for tiles of type T. The advantage of this is: - tiles are easily HPX-serializable and we can put them into HPX components - we can perhaps later add support for automatic GPU upload
1 parent 42412be commit 08c1eed

23 files changed

+764
-698
lines changed

core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_compile_definitions(GPRAT_WITH_CUDA=$<BOOL:${GPRAT_WITH_CUDA}>)
1111
set(SOURCE_FILES
1212
src/gprat.cpp
1313
src/utils.cpp
14+
src/performance_counters.cpp
1415
src/target.cpp
1516
src/kernels.cpp
1617
src/hyperparameters.cpp

core/include/gprat/cpu/adapter_cblas_fp32.hpp

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
#pragma once
55

66
#include "gprat/detail/config.hpp"
7+
#include "gprat/tile_data.hpp"
78

89
#include <hpx/future.hpp>
910
#include <vector>
1011

1112
GPRAT_NS_BEGIN
1213

13-
using vector_future = hpx::shared_future<std::vector<float>>;
14-
1514
// Constants that are compatible with CBLAS
1615
typedef enum BLAS_TRANSPOSE { Blas_no_trans = 111, Blas_trans = 112 } BLAS_TRANSPOSE;
1716

@@ -29,69 +28,71 @@ typedef enum BLAS_ALPHA { Blas_add = 1, Blas_substract = -1 } BLAS_ALPHA;
2928

3029
/**
3130
* @brief FP32 In-place Cholesky decomposition of A
32-
* @param f_A matrix to be factorized
31+
* @param A matrix to be factorized
3332
* @param N matrix dimension
3433
* @return factorized, lower triangular matrix f_L
3534
*/
36-
vector_future potrf(vector_future f_A, const int N);
35+
mutable_tile_data<float> potrf(const mutable_tile_data<float> &A, int N);
3736

3837
/**
3938
* @brief FP32 In-place solve L(^T) * X = A or X * L(^T) = A where L lower triangular
40-
* @param f_L Cholesky factor matrix
41-
* @param f_A right hand side matrix
39+
* @param L Cholesky factor matrix
40+
* @param A right hand side matrix
4241
* @param N first dimension
4342
* @param M second dimension
4443
* @return solution matrix f_X
4544
*/
46-
vector_future trsm(vector_future f_L,
47-
vector_future f_A,
48-
const int N,
49-
const int M,
50-
const BLAS_TRANSPOSE transpose_L,
51-
const BLAS_SIDE side_L);
45+
mutable_tile_data<float>
46+
trsm(const const_tile_data<float> &L,
47+
const mutable_tile_data<float> &A,
48+
int N,
49+
int M,
50+
BLAS_TRANSPOSE transpose_L,
51+
BLAS_SIDE side_L);
5252

5353
/**
5454
* @brief FP32 Symmetric rank-k update: A = A - B * B^T
55-
* @param f_A Base matrix
56-
* @param f_B Symmetric update matrix
55+
* @param A Base matrix
56+
* @param B Symmetric update matrix
5757
* @param N matrix dimension
5858
* @return updated matrix f_A
5959
*/
60-
vector_future syrk(vector_future f_A, vector_future f_B, const int N);
60+
mutable_tile_data<float> syrk(const mutable_tile_data<float> &A, const const_tile_data<float> &B, int N);
6161

6262
/**
6363
* @brief FP32 General matrix-matrix multiplication: C = C - A(^T) * B(^T)
64-
* @param f_C Base matrix
65-
* @param f_B Right update matrix
66-
* @param f_A Left update matrix
64+
* @param C Base matrix
65+
* @param B Right update matrix
66+
* @param A Left update matrix
6767
* @param N first matrix dimension
6868
* @param M second matrix dimension
6969
* @param K third matrix dimension
7070
* @param transpose_A transpose left matrix
7171
* @param transpose_B transpose right matrix
7272
* @return updated matrix f_X
7373
*/
74-
vector_future
75-
gemm(vector_future f_A,
76-
vector_future f_B,
77-
vector_future f_C,
78-
const int N,
79-
const int M,
80-
const int K,
81-
const BLAS_TRANSPOSE transpose_A,
82-
const BLAS_TRANSPOSE transpose_B);
74+
mutable_tile_data<float>
75+
gemm(const const_tile_data<float> &A,
76+
const const_tile_data<float> &B,
77+
const mutable_tile_data<float> &C,
78+
int N,
79+
int M,
80+
int K,
81+
BLAS_TRANSPOSE transpose_A,
82+
BLAS_TRANSPOSE transpose_B);
8383

8484
// BLAS level 2 operations
8585

8686
/**
8787
* @brief FP32 In-place solve L(^T) * x = a where L lower triangular
88-
* @param f_L Cholesky factor matrix
89-
* @param f_a right hand side vector
88+
* @param L Cholesky factor matrix
89+
* @param a right hand side vector
9090
* @param N matrix dimension
9191
* @param transpose_L transpose Cholesky factor
9292
* @return solution vector f_x
9393
*/
94-
vector_future trsv(vector_future f_L, vector_future f_a, const int N, const BLAS_TRANSPOSE transpose_L);
94+
mutable_tile_data<float>
95+
trsv(const const_tile_data<float> &L, const mutable_tile_data<float> &a, int N, BLAS_TRANSPOSE transpose_L);
9596

9697
/**
9798
* @brief FP32 General matrix-vector multiplication: b = b - A(^T) * a
@@ -103,34 +104,37 @@ vector_future trsv(vector_future f_L, vector_future f_a, const int N, const BLAS
103104
* @param transpose_A transpose update matrix
104105
* @return updated vector f_b
105106
*/
106-
vector_future gemv(vector_future f_A,
107-
vector_future f_a,
108-
vector_future f_b,
109-
const int N,
110-
const int M,
111-
const BLAS_ALPHA alpha,
112-
const BLAS_TRANSPOSE transpose_A);
107+
mutable_tile_data<float>
108+
gemv(const const_tile_data<float> &A,
109+
const const_tile_data<float> &a,
110+
const mutable_tile_data<float> &b,
111+
int N,
112+
int M,
113+
BLAS_ALPHA alpha,
114+
BLAS_TRANSPOSE transpose_A);
113115

114116
/**
115117
* @brief FP32 Vector update with diagonal SYRK: r = r + diag(A^T * A)
116-
* @param f_A update matrix
117-
* @param f_r base vector
118+
* @param A update matrix
119+
* @param r base vector
118120
* @param N first matrix dimension
119121
* @param M second matrix dimension
120122
* @return updated vector f_r
121123
*/
122-
vector_future dot_diag_syrk(vector_future f_A, vector_future f_r, const int N, const int M);
124+
mutable_tile_data<float>
125+
dot_diag_syrk(const const_tile_data<float> &A, const mutable_tile_data<float> &r, int N, int M);
123126

124127
/**
125128
* @brief FP32 Vector update with diagonal GEMM: r = r + diag(A * B)
126-
* @param f_A first update matrix
127-
* @param f_B second update matrix
128-
* @param f_r base vector
129+
* @param A first update matrix
130+
* @param B second update matrix
131+
* @param r base vector
129132
* @param N first matrix dimension
130133
* @param M second matrix dimension
131134
* @return updated vector f_r
132135
*/
133-
vector_future dot_diag_gemm(vector_future f_A, vector_future f_B, vector_future f_r, const int N, const int M);
136+
mutable_tile_data<float> dot_diag_gemm(
137+
const const_tile_data<float> &A, const const_tile_data<float> &B, const mutable_tile_data<float> &r, int N, int M);
134138

135139
// BLAS level 1 operations
136140

@@ -141,7 +145,7 @@ vector_future dot_diag_gemm(vector_future f_A, vector_future f_B, vector_future
141145
* @param N vector length
142146
* @return y - x
143147
*/
144-
vector_future axpy(vector_future f_y, vector_future f_x, const int N);
148+
mutable_tile_data<float> axpy(const mutable_tile_data<float> &y, const const_tile_data<float> &x, int N);
145149

146150
/**
147151
* @brief FP32 Dot product: a * b
@@ -150,7 +154,7 @@ vector_future axpy(vector_future f_y, vector_future f_x, const int N);
150154
* @param N vector length
151155
* @return f_a * f_b
152156
*/
153-
float dot(std::vector<float> a, std::vector<float> b, const int N);
157+
float dot(std::span<const float> a, std::span<const float> b, int N);
154158

155159
GPRAT_NS_END
156160

core/include/gprat/cpu/adapter_cblas_fp64.hpp

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
#pragma once
55

66
#include "gprat/detail/config.hpp"
7+
#include "gprat/tile_data.hpp"
78

89
#include <hpx/future.hpp>
910
#include <vector>
1011

1112
GPRAT_NS_BEGIN
1213

13-
using vector_future = hpx::shared_future<std::vector<double>>;
14-
1514
// Constants that are compatible with CBLAS
1615
typedef enum BLAS_TRANSPOSE { Blas_no_trans = 111, Blas_trans = 112 } BLAS_TRANSPOSE;
1716

@@ -29,108 +28,117 @@ typedef enum BLAS_ALPHA { Blas_add = 1, Blas_substract = -1 } BLAS_ALPHA;
2928

3029
/**
3130
* @brief FP64 In-place Cholesky decomposition of A
32-
* @param f_A matrix to be factorized
31+
* @param A matrix to be factorized
3332
* @param N matrix dimension
3433
* @return factorized, lower triangular matrix f_L
3534
*/
36-
vector_future potrf(vector_future f_A, const int N);
35+
mutable_tile_data<double> potrf(const mutable_tile_data<double> &A, int N);
3736

3837
/**
3938
* @brief FP64 In-place solve L(^T) * X = A or X * L(^T) = A where L lower triangular
40-
* @param f_L Cholesky factor matrix
41-
* @param f_A right hand side matrix
39+
* @param L Cholesky factor matrix
40+
* @param A right hand side matrix
4241
* @param N first dimension
4342
* @param M second dimension
4443
* @return solution matrix f_X
4544
*/
46-
vector_future trsm(vector_future f_L,
47-
vector_future f_A,
48-
const int N,
49-
const int M,
50-
const BLAS_TRANSPOSE transpose_L,
51-
const BLAS_SIDE side_L);
45+
mutable_tile_data<double>
46+
trsm(const const_tile_data<double> &L,
47+
const mutable_tile_data<double> &A,
48+
int N,
49+
int M,
50+
BLAS_TRANSPOSE transpose_L,
51+
BLAS_SIDE side_L);
5252

5353
/**
5454
* @brief FP64 Symmetric rank-k update: A = A - B * B^T
55-
* @param f_A Base matrix
56-
* @param f_B Symmetric update matrix
55+
* @param A Base matrix
56+
* @param B Symmetric update matrix
5757
* @param N matrix dimension
5858
* @return updated matrix f_A
5959
*/
60-
vector_future syrk(vector_future f_A, vector_future f_B, const int N);
60+
mutable_tile_data<double> syrk(const mutable_tile_data<double> &A, const const_tile_data<double> &B, int N);
6161

6262
/**
6363
* @brief FP64 General matrix-matrix multiplication: C = C - A(^T) * B(^T)
64-
* @param f_C Base matrix
65-
* @param f_B Right update matrix
66-
* @param f_A Left update matrix
64+
* @param C Base matrix
65+
* @param B Right update matrix
66+
* @param A Left update matrix
6767
* @param N first matrix dimension
6868
* @param M second matrix dimension
6969
* @param K third matrix dimension
7070
* @param transpose_A transpose left matrix
7171
* @param transpose_B transpose right matrix
7272
* @return updated matrix f_X
7373
*/
74-
vector_future
75-
gemm(vector_future f_A,
76-
vector_future f_B,
77-
vector_future f_C,
78-
const int N,
79-
const int M,
80-
const int K,
81-
const BLAS_TRANSPOSE transpose_A,
82-
const BLAS_TRANSPOSE transpose_B);
74+
mutable_tile_data<double>
75+
gemm(const const_tile_data<double> &A,
76+
const const_tile_data<double> &B,
77+
const mutable_tile_data<double> &C,
78+
int N,
79+
int M,
80+
int K,
81+
BLAS_TRANSPOSE transpose_A,
82+
BLAS_TRANSPOSE transpose_B);
8383

8484
// BLAS level 2 operations
8585

8686
/**
8787
* @brief FP64 In-place solve L(^T) * x = a where L lower triangular
88-
* @param f_L Cholesky factor matrix
89-
* @param f_a right hand side vector
88+
* @param L Cholesky factor matrix
89+
* @param a right hand side vector
9090
* @param N matrix dimension
9191
* @param transpose_L transpose Cholesky factor
9292
* @return solution vector f_x
9393
*/
94-
vector_future trsv(vector_future f_L, vector_future f_a, const int N, const BLAS_TRANSPOSE transpose_L);
94+
mutable_tile_data<double>
95+
trsv(const const_tile_data<double> &L, const mutable_tile_data<double> &a, int N, BLAS_TRANSPOSE transpose_L);
9596

9697
/**
9798
* @brief FP64 General matrix-vector multiplication: b = b - A(^T) * a
98-
* @param f_A update matrix
99-
* @param f_a update vector
100-
* @param f_b base vector
99+
* @param A update matrix
100+
* @param a update vector
101+
* @param b base vector
101102
* @param N matrix dimension
102103
* @param alpha add or substract update to base vector
103104
* @param transpose_A transpose update matrix
104105
* @return updated vector f_b
105106
*/
106-
vector_future gemv(vector_future f_A,
107-
vector_future f_a,
108-
vector_future f_b,
109-
const int N,
110-
const int M,
111-
const BLAS_ALPHA alpha,
112-
const BLAS_TRANSPOSE transpose_A);
107+
mutable_tile_data<double>
108+
gemv(const const_tile_data<double> &A,
109+
const const_tile_data<double> &a,
110+
const mutable_tile_data<double> &b,
111+
int N,
112+
int M,
113+
BLAS_ALPHA alpha,
114+
BLAS_TRANSPOSE transpose_A);
113115

114116
/**
115117
* @brief FP64 Vector update with diagonal SYRK: r = r + diag(A^T * A)
116-
* @param f_A update matrix
117-
* @param f_r base vector
118+
* @param A update matrix
119+
* @param r base vector
118120
* @param N first matrix dimension
119121
* @param M second matrix dimension
120122
* @return updated vector f_r
121123
*/
122-
vector_future dot_diag_syrk(vector_future f_A, vector_future f_r, const int N, const int M);
124+
mutable_tile_data<double>
125+
dot_diag_syrk(const const_tile_data<double> &A, const mutable_tile_data<double> &r, int N, int M);
123126

124127
/**
125128
* @brief FP64 Vector update with diagonal GEMM: r = r + diag(A * B)
126-
* @param f_A first update matrix
127-
* @param f_B second update matrix
128-
* @param f_r base vector
129+
* @param A first update matrix
130+
* @param B second update matrix
131+
* @param r base vector
129132
* @param N first matrix dimension
130133
* @param M second matrix dimension
131134
* @return updated vector f_r
132135
*/
133-
vector_future dot_diag_gemm(vector_future f_A, vector_future f_B, vector_future f_r, const int N, const int M);
136+
mutable_tile_data<double>
137+
dot_diag_gemm(const const_tile_data<double> &A,
138+
const const_tile_data<double> &B,
139+
const mutable_tile_data<double> &r,
140+
int N,
141+
int M);
134142

135143
// BLAS level 1 operations
136144

@@ -141,7 +149,7 @@ vector_future dot_diag_gemm(vector_future f_A, vector_future f_B, vector_future
141149
* @param N vector length
142150
* @return y - x
143151
*/
144-
vector_future axpy(vector_future f_y, vector_future f_x, const int N);
152+
mutable_tile_data<double> axpy(const mutable_tile_data<double> &y, const const_tile_data<double> &x, int N);
145153

146154
/**
147155
* @brief FP64 Dot product: a * b
@@ -150,7 +158,7 @@ vector_future axpy(vector_future f_y, vector_future f_x, const int N);
150158
* @param N vector length
151159
* @return a * b
152160
*/
153-
double dot(std::vector<double> a, std::vector<double> b, const int N);
161+
double dot(std::span<const double> a, std::span<const double> b, int N);
154162

155163
GPRAT_NS_END
156164

0 commit comments

Comments
 (0)