Skip to content

Commit 5c0acfe

Browse files
committed
Simplify asm1_gpu_setup implementation
1 parent 9485b74 commit 5c0acfe

File tree

3 files changed

+27
-108
lines changed

3 files changed

+27
-108
lines changed

src/elliptic/box/crs_box.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ static void asm1_setup(struct box *box, double tol, const struct comm *comm) {
223223
asm1_cholmod_setup(A, null_space, box);
224224
break;
225225
case BOX_GPU:
226-
asm1_gpu_setup(A, null_space, box);
226+
asm1_gpu_setup<T>(A, null_space, box);
227227
break;
228228
}
229229

src/elliptic/box/crs_box_gpu.cpp

Lines changed: 25 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -11,46 +11,24 @@ static const float one_f32 = 1.0f, zero_f32 = 0.0f;
1111
static int initialized = 0;
1212
static gs_dom dom;
1313
static int nr = 0;
14-
static void *h_r = NULL, *h_x = NULL;
15-
1614
static double *d_A_inv = NULL;
17-
static float *d_A_inv_f32 = NULL;
18-
static void *d_r = NULL, *d_x = NULL;
19-
20-
static void setup_core(uint nr_) {
21-
nr = nr_;
22-
h_r = calloc(nr, sizeof(double));
23-
h_x = calloc(nr, sizeof(double));
24-
}
2515

26-
static void finalize_core(void) {
27-
free(h_r), h_r = NULL;
28-
free(h_x), h_x = NULL;
29-
nr = 0;
30-
initialized = 0;
31-
}
32-
33-
static void setup_inverse(double **A_inv, float **A_inv_f32, const struct csr *A) {
16+
template <typename T>
17+
static void setup_inverse(T *A_inv, const struct csr *A) {
3418
assert(sizeof(dfloat) == sizeof(double));
3519

36-
int N = A->nr;
20+
const int N = A->nr;
3721
std::vector<dfloat> B(N * N);
3822
for (uint i = 0; i < A->nr; i++) {
39-
for (uint j = A->offs[i]; j < A->offs[i + 1]; j++) {
40-
// B[i * A->nr + A->cols[j] - A->base] = A->vals[j];
23+
for (uint j = A->offs[i]; j < A->offs[i + 1]; j++)
4124
B[(A->cols[j] - A->base) * A->nr + i] = A->vals[j];
42-
}
4325
}
4426

4527
auto invA = platform->linAlg->matrixInverse(N, B);
4628

47-
*A_inv = tcalloc(double, A->nr * A->nr);
48-
*A_inv_f32 = tcalloc(float, A->nr * A->nr);
4929
for (int i = 0; i < N; i++) {
50-
for (int j = 0; j < N; j++) {
51-
(*A_inv)[i * N + j] = (double)invA[j * N + i];
52-
(*A_inv_f32)[i * N + j] = (float)invA[j * N + i];
53-
}
30+
for (int j = 0; j < N; j++)
31+
A_inv[i * N + j] = (T)invA[j * N + i];
5432
}
5533
}
5634

@@ -185,102 +163,43 @@ void asm1_gpu_free(struct box *box) {
185163
#elif defined(ENABLE_ONEMKL)
186164
#include "crs_box_gpu_onemkl.hpp"
187165

166+
template <typename T>
188167
void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box) {
189-
assert(null_space == 0);
190-
191168
if (initialized) return;
192169

193-
double *A_inv = 0;
194-
float *A_inv_f32 = 0;
195-
setup_inverse(&A_inv, &A_inv_f32, A);
196-
setup_core(A->nr);
197-
198-
const size_t size = nr * nr;
199-
d_A_inv = box_onemkl_device_malloc<double>(size);
200-
box_onemkl_device_copyto<double>(d_A_inv, A_inv, size);
201-
202-
d_A_inv_f32 = box_onemkl_device_malloc<float>(size);
203-
box_onemkl_device_copyto<float>(d_A_inv_f32, A_inv_f32, size);
170+
assert(null_space == 0);
204171

205-
free(A_inv), free(A_inv_f32);
172+
const size_t size = A->nr * A->nr;
173+
T *A_inv = tcalloc(T, size);
174+
setup_inverse(A_inv, A);
206175

207-
d_r = box_onemkl_device_malloc<double>(nr);
208-
d_x = box_onemkl_device_malloc<double>(nr);
176+
d_A_inv = box_onemkl_device_malloc<T>(size);
177+
box_onemkl_device_copyto<T>(d_A_inv, A_inv, size);
178+
free(A_inv);
209179

210180
initialized = 1;
211181
dom = box->opts.dom;
212-
}
213-
214-
template <typename T>
215-
static void asm1_gpu_solve_aux(T *x, struct box *box, const T *d_A, const T *r) {
216-
T *h_r_T = static_cast<T *>(h_r);
217-
T *d_r_T = static_cast<T *>(d_r);
218-
T *h_x_T = static_cast<T *>(h_x);
219-
T *d_x_T = static_cast<T *>(d_x);
220-
221-
for (uint i = 0; i < nr; i++)
222-
h_r_T[i] = 0;
223-
for (uint i = 0; i < box->sn; i++) {
224-
if (box->u2c[i] >= 0)
225-
h_r_T[box->u2c[i]] += r[i];
226-
}
227-
228-
box_onemkl_device_copyto<T>(d_r_T, h_r_T, nr);
229-
box_onemkl_device_gemv(d_x_T, nr, d_A, d_r_T);
230-
box_onemkl_device_copyfrom<T>(d_x_T, h_x_T, nr);
231-
232-
for (uint i = 0; i < box->sn; i++) {
233-
if (box->u2c[i] >= 0)
234-
x[i] = h_x_T[box->u2c[i]];
235-
else
236-
x[i] = 0;
237-
}
238-
}
239-
240-
void asm1_gpu_solve(void *x, struct box *box, const void *r) {
241-
if (!initialized) {
242-
fprintf(stderr, "oneMKL is not initialized.\n");
243-
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
244-
}
245-
246-
switch(dom) {
247-
case gs_double:
248-
asm1_gpu_solve_aux<double>((double *)x, box, d_A_inv, (double *)r);
249-
break;
250-
case gs_float:
251-
asm1_gpu_solve_aux<float>((float *)x, box, d_A_inv_f32, (float *)r);
252-
break;
253-
default:
254-
break;
255-
}
182+
nr = A->nr;
256183
}
257184

258185
void asm1_gpu_solve(occa::memory &o_x, struct box *box, occa::memory &o_r) {
259-
if (!initialized) {
260-
fprintf(stderr, "oneMKL is not initialized.\n");
261-
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
262-
}
186+
if (!initialized) MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
263187

264-
switch(dom) {
265-
case gs_double:
266-
box_onemkl_device_gemv<double>((double *)o_x.ptr(), nr, d_A_inv, (double *)o_r.ptr());
267-
break;
268-
case gs_float:
269-
box_onemkl_device_gemv<float>((float *)o_x.ptr(), nr, d_A_inv_f32, (float *)o_r.ptr());
270-
break;
271-
default:
272-
break;
273-
}
188+
if (box->opts.dom == gs_double)
189+
box_onemkl_device_gemv<double>((double *)o_x.ptr(), nr, (double *)d_A_inv, (double *)o_r.ptr());
190+
else
191+
box_onemkl_device_gemv<float>((float *)o_x.ptr(), nr, (float *)d_A_inv, (float *)o_r.ptr());
274192
}
275193

276194
void asm1_gpu_free(struct box *box) {
277195
box_onemkl_free(static_cast<void *>(d_A_inv));
278-
box_onemkl_free(static_cast<void *>(d_A_inv_f32));
279-
box_onemkl_free(static_cast<void *>(d_r));
280-
box_onemkl_free(static_cast<void *>(d_r));
281-
finalize_core();
282196
}
197+
198+
template void asm1_gpu_setup<float>(struct csr *A, unsigned null_space, struct box *box);
199+
template void asm1_gpu_setup<double>(struct csr *A, unsigned null_space, struct box *box);
200+
283201
#else
202+
284203
void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box) {
285204
fprintf(stderr, "GPU BLAS not enabled.\n");
286205
exit(EXIT_FAILURE);

src/elliptic/box/crs_box_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ void asm1_cholmod_solve(void *x, struct box *box, const void *r);
5555
void asm1_cholmod_free(struct box *box);
5656

5757
// ASM1: GPU BLAS interface.
58+
template <typename T>
5859
void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box);
59-
void asm1_gpu_solve(void *x, struct box *box, const void *r);
6060
void asm1_gpu_solve(occa::memory &o_x, struct box *box, occa::memory &o_r);
6161
void asm1_gpu_free(struct box *box);
6262

0 commit comments

Comments
 (0)