Skip to content

Commit 88ef981

Browse files
committed
Fixes in asm1 hipBLAS implementation
1 parent 238d8ac commit 88ef981

File tree

1 file changed

+11
-37
lines changed

1 file changed

+11
-37
lines changed

src/elliptic/box/crs_box_gpu.cpp

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ static void setup_inverse(T *A_inv, const struct csr *A) {
4646
}
4747

4848
static hipblasHandle_t handle = NULL;
49-
void *h_r, *h_x;
5049
void *d_r, *d_x;
5150

5251
template <typename T>
@@ -61,8 +60,6 @@ void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box) {
6160
hipMemcpyHostToDevice));
6261
free(A_inv);
6362

64-
h_r = calloc(A->nr, sizeof(T));
65-
h_x = calloc(A->nr, sizeof(T));
6663
check_hip_runtime(hipMalloc(&d_r, A->nr * sizeof(T)));
6764
check_hip_runtime(hipMalloc(&d_x, A->nr * sizeof(T)));
6865

@@ -75,53 +72,30 @@ void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box) {
7572

7673
template <typename T>
7774
void box_hipblas(T *x, struct box *box, const T *r) {
78-
T *h_r_ = (T *)h_r;
79-
for (uint i = 0; i < nr; i++)
80-
h_r_[i] = 0;
81-
for (uint i = 0; i < box->sn; i++) {
82-
if (box->u2c[i] >= 0)
83-
h_r_[box->u2c[i]] += r[i];
84-
}
85-
86-
check_hip_runtime(
87-
hipMemcpy(d_r, h_r_, nr * sizeof(T), hipMemcpyHostToDevice));
88-
89-
// FIXME: hibblasSgemv, one_f32
90-
hipblasSgemv(handle, HIPBLAS_OP_T, nr, nr, &one_f32, d_A_inv_f32, nr,
91-
(T *)d_r, 1, &zero_f32, (T *)d_x, 1);
92-
93-
check_hip_runtime(
94-
hipMemcpy(h_x, d_x, nr * sizeof(T), hipMemcpyDeviceToHost));
95-
96-
T *h_x_ = (T *)h_x;
97-
for (uint i = 0; i < box->sn; i++) {
98-
if (box->u2c[i] >= 0)
99-
x[i] = h_x_[box->u2c[i]];
100-
else
101-
x[i] = 0;
75+
if (sizeof(T) == sizeof(float)) {
76+
hipblasSgemv(handle, HIPBLAS_OP_T, nr, nr, &one_f32, (float *)d_A_inv, nr,
77+
(float *)d_r, 1, &zero_f32, (float *)d_x, 1);
78+
} else if (sizeof(T) == sizeof(double)) {
79+
hipblasDgemv(handle, HIPBLAS_OP_T, nr, nr, &one, (double *)d_A_inv, nr,
80+
(double *)d_r, 1, &zero, (double *)d_x, 1);
10281
}
10382
}
10483

10584
void asm1_gpu_solve(occa::memory &o_x, struct box *box, occa::memory &o_r) {
10685
if (!initialized) MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
10786

108-
if (box->opts.dom == gs_double) {
109-
fprintf(stderr, "GPU double BLAS not supported.\n");
110-
exit(EXIT_FAILURE);
111-
} else {
112-
box_hipblas<float>((float *)o_x.ptr(), nr, (float *)d_A_inv, (float *)o_r.ptr());
113-
}
87+
if (box->opts.dom == gs_double)
88+
box_hipblas<double>((double *)o_x.ptr(), box, (double *)o_r.ptr());
89+
else
90+
box_hipblas<float>((float *)o_x.ptr(), box, (float *)o_r.ptr());
11491
}
11592

11693
void asm1_gpu_free(struct box *box) {
11794
hipblasDestroy(handle);
11895
check_hip_runtime(hipFree(d_A_inv));
11996
check_hip_runtime(hipFree(d_r));
12097
check_hip_runtime(hipFree(d_x));
121-
free(h_r), h_r = NULL;
122-
free(h_x), h_x = NULL;
123-
nr = 0;
124-
initialized = 0;
98+
nr = 0, initialized = 0;
12599
}
126100

127101
#elif defined(ENABLE_BOX_ONEMKL)

0 commit comments

Comments
 (0)