@@ -46,7 +46,6 @@ static void setup_inverse(T *A_inv, const struct csr *A) {
4646 }
4747
4848static hipblasHandle_t handle = NULL ;
49- void *h_r, *h_x;
5049void *d_r, *d_x;
5150
5251template <typename T>
@@ -61,8 +60,6 @@ void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box) {
6160 hipMemcpyHostToDevice));
6261 free (A_inv);
6362
64- h_r = calloc (A->nr , sizeof (T));
65- h_x = calloc (A->nr , sizeof (T));
6663 check_hip_runtime (hipMalloc (&d_r, A->nr * sizeof (T)));
6764 check_hip_runtime (hipMalloc (&d_x, A->nr * sizeof (T)));
6865
@@ -75,53 +72,30 @@ void asm1_gpu_setup(struct csr *A, unsigned null_space, struct box *box) {
7572
7673template <typename T>
7774void box_hipblas (T *x, struct box *box, const T *r) {
78- T *h_r_ = (T *)h_r;
79- for (uint i = 0 ; i < nr; i++)
80- h_r_[i] = 0 ;
81- for (uint i = 0 ; i < box->sn ; i++) {
82- if (box->u2c [i] >= 0 )
83- h_r_[box->u2c [i]] += r[i];
84- }
85-
86- check_hip_runtime (
87- hipMemcpy (d_r, h_r_, nr * sizeof (T), hipMemcpyHostToDevice));
88-
89- // FIXME: hibblasSgemv, one_f32
90- hipblasSgemv (handle, HIPBLAS_OP_T, nr, nr, &one_f32, d_A_inv_f32, nr,
91- (T *)d_r, 1 , &zero_f32, (T *)d_x, 1 );
92-
93- check_hip_runtime (
94- hipMemcpy (h_x, d_x, nr * sizeof (T), hipMemcpyDeviceToHost));
95-
96- T *h_x_ = (T *)h_x;
97- for (uint i = 0 ; i < box->sn ; i++) {
98- if (box->u2c [i] >= 0 )
99- x[i] = h_x_[box->u2c [i]];
100- else
101- x[i] = 0 ;
75+ if (sizeof (T) == sizeof (float )) {
76+ hipblasSgemv (handle, HIPBLAS_OP_T, nr, nr, &one_f32, (float *)d_A_inv, nr,
77+ (float *)d_r, 1 , &zero_f32, (float *)d_x, 1 );
78+ } else if (sizeof (T) == sizeof (double )) {
79+ hipblasDgemv (handle, HIPBLAS_OP_T, nr, nr, &one, (double *)d_A_inv, nr,
80+ (double *)d_r, 1 , &zero, (double *)d_x, 1 );
10281 }
10382}
10483
10584void asm1_gpu_solve (occa::memory &o_x, struct box *box, occa::memory &o_r) {
10685 if (!initialized) MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
10786
108- if (box->opts .dom == gs_double) {
109- fprintf (stderr, " GPU double BLAS not supported.\n " );
110- exit (EXIT_FAILURE);
111- } else {
112- box_hipblas<float >((float *)o_x.ptr (), nr, (float *)d_A_inv, (float *)o_r.ptr ());
113- }
87+ if (box->opts .dom == gs_double)
88+ box_hipblas<double >((double *)o_x.ptr (), box, (double *)o_r.ptr ());
89+ else
90+ box_hipblas<float >((float *)o_x.ptr (), box, (float *)o_r.ptr ());
11491}
11592
11693void asm1_gpu_free (struct box *box) {
11794 hipblasDestroy (handle);
11895 check_hip_runtime (hipFree (d_A_inv));
11996 check_hip_runtime (hipFree (d_r));
12097 check_hip_runtime (hipFree (d_x));
121- free (h_r), h_r = NULL ;
122- free (h_x), h_x = NULL ;
123- nr = 0 ;
124- initialized = 0 ;
98+ nr = 0 , initialized = 0 ;
12599}
126100
127101#elif defined(ENABLE_BOX_ONEMKL)
0 commit comments