@@ -11,46 +11,24 @@ static const float one_f32 = 1.0f, zero_f32 = 0.0f;
1111static int initialized = 0 ;
1212static gs_dom dom;
1313static int nr = 0 ;
14- static void *h_r = NULL , *h_x = NULL ;
15-
1614static double *d_A_inv = NULL ;
17- static float *d_A_inv_f32 = NULL ;
18- static void *d_r = NULL , *d_x = NULL ;
19-
20- static void setup_core (uint nr_) {
21- nr = nr_;
22- h_r = calloc (nr, sizeof (double ));
23- h_x = calloc (nr, sizeof (double ));
24- }
2515
26- static void finalize_core (void ) {
27- free (h_r), h_r = NULL ;
28- free (h_x), h_x = NULL ;
29- nr = 0 ;
30- initialized = 0 ;
31- }
32-
33- static void setup_inverse (double **A_inv, float **A_inv_f32, const struct csr *A) {
16+ template <typename T>
17+ static void setup_inverse (T *A_inv, const struct csr *A) {
3418 assert (sizeof (dfloat) == sizeof (double ));
3519
36- int N = A->nr ;
20+ const int N = A->nr ;
3721 std::vector<dfloat> B (N * N);
3822 for (uint i = 0 ; i < A->nr ; i++) {
39- for (uint j = A->offs [i]; j < A->offs [i + 1 ]; j++) {
40- // B[i * A->nr + A->cols[j] - A->base] = A->vals[j];
23+ for (uint j = A->offs [i]; j < A->offs [i + 1 ]; j++)
4124 B[(A->cols [j] - A->base ) * A->nr + i] = A->vals [j];
42- }
4325 }
4426
4527 auto invA = platform->linAlg ->matrixInverse (N, B);
4628
47- *A_inv = tcalloc (double , A->nr * A->nr );
48- *A_inv_f32 = tcalloc (float , A->nr * A->nr );
4929 for (int i = 0 ; i < N; i++) {
50- for (int j = 0 ; j < N; j++) {
51- (*A_inv)[i * N + j] = (double )invA[j * N + i];
52- (*A_inv_f32)[i * N + j] = (float )invA[j * N + i];
53- }
30+ for (int j = 0 ; j < N; j++)
31+ A_inv[i * N + j] = (T)invA[j * N + i];
5432 }
5533}
5634
@@ -185,102 +163,43 @@ void asm1_gpu_free(struct box *box) {
185163#elif defined(ENABLE_ONEMKL)
186164#include " crs_box_gpu_onemkl.hpp"
187165
166+ template <typename T>
188167void asm1_gpu_setup (struct csr *A, unsigned null_space, struct box *box) {
189- assert (null_space == 0 );
190-
191168 if (initialized) return ;
192169
193- double *A_inv = 0 ;
194- float *A_inv_f32 = 0 ;
195- setup_inverse (&A_inv, &A_inv_f32, A);
196- setup_core (A->nr );
197-
198- const size_t size = nr * nr;
199- d_A_inv = box_onemkl_device_malloc<double >(size);
200- box_onemkl_device_copyto<double >(d_A_inv, A_inv, size);
201-
202- d_A_inv_f32 = box_onemkl_device_malloc<float >(size);
203- box_onemkl_device_copyto<float >(d_A_inv_f32, A_inv_f32, size);
170+ assert (null_space == 0 );
204171
205- free (A_inv), free (A_inv_f32);
172+ const size_t size = A->nr * A->nr ;
173+ T *A_inv = tcalloc (T, size);
174+ setup_inverse (A_inv, A);
206175
207- d_r = box_onemkl_device_malloc<double >(nr);
208- d_x = box_onemkl_device_malloc<double >(nr);
176+ d_A_inv = box_onemkl_device_malloc<T>(size);
177+ box_onemkl_device_copyto<T>(d_A_inv, A_inv, size);
178+ free (A_inv);
209179
210180 initialized = 1 ;
211181 dom = box->opts .dom ;
212- }
213-
214- template <typename T>
215- static void asm1_gpu_solve_aux (T *x, struct box *box, const T *d_A, const T *r) {
216- T *h_r_T = static_cast <T *>(h_r);
217- T *d_r_T = static_cast <T *>(d_r);
218- T *h_x_T = static_cast <T *>(h_x);
219- T *d_x_T = static_cast <T *>(d_x);
220-
221- for (uint i = 0 ; i < nr; i++)
222- h_r_T[i] = 0 ;
223- for (uint i = 0 ; i < box->sn ; i++) {
224- if (box->u2c [i] >= 0 )
225- h_r_T[box->u2c [i]] += r[i];
226- }
227-
228- box_onemkl_device_copyto<T>(d_r_T, h_r_T, nr);
229- box_onemkl_device_gemv (d_x_T, nr, d_A, d_r_T);
230- box_onemkl_device_copyfrom<T>(d_x_T, h_x_T, nr);
231-
232- for (uint i = 0 ; i < box->sn ; i++) {
233- if (box->u2c [i] >= 0 )
234- x[i] = h_x_T[box->u2c [i]];
235- else
236- x[i] = 0 ;
237- }
238- }
239-
240- void asm1_gpu_solve (void *x, struct box *box, const void *r) {
241- if (!initialized) {
242- fprintf (stderr, " oneMKL is not initialized.\n " );
243- MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
244- }
245-
246- switch (dom) {
247- case gs_double:
248- asm1_gpu_solve_aux<double >((double *)x, box, d_A_inv, (double *)r);
249- break ;
250- case gs_float:
251- asm1_gpu_solve_aux<float >((float *)x, box, d_A_inv_f32, (float *)r);
252- break ;
253- default :
254- break ;
255- }
182+ nr = A->nr ;
256183}
257184
258185void asm1_gpu_solve (occa::memory &o_x, struct box *box, occa::memory &o_r) {
259- if (!initialized) {
260- fprintf (stderr, " oneMKL is not initialized.\n " );
261- MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
262- }
186+ if (!initialized) MPI_Abort (MPI_COMM_WORLD, EXIT_FAILURE);
263187
264- switch (dom) {
265- case gs_double:
266- box_onemkl_device_gemv<double >((double *)o_x.ptr (), nr, d_A_inv, (double *)o_r.ptr ());
267- break ;
268- case gs_float:
269- box_onemkl_device_gemv<float >((float *)o_x.ptr (), nr, d_A_inv_f32, (float *)o_r.ptr ());
270- break ;
271- default :
272- break ;
273- }
188+ if (box->opts .dom == gs_double)
189+ box_onemkl_device_gemv<double >((double *)o_x.ptr (), nr, (double *)d_A_inv, (double *)o_r.ptr ());
190+ else
191+ box_onemkl_device_gemv<float >((float *)o_x.ptr (), nr, (float *)d_A_inv, (float *)o_r.ptr ());
274192}
275193
276194void asm1_gpu_free (struct box *box) {
277195 box_onemkl_free (static_cast <void *>(d_A_inv));
278- box_onemkl_free (static_cast <void *>(d_A_inv_f32));
279- box_onemkl_free (static_cast <void *>(d_r));
280- box_onemkl_free (static_cast <void *>(d_r));
281- finalize_core ();
282196}
197+
198+ template void asm1_gpu_setup<float >(struct csr *A, unsigned null_space, struct box *box);
199+ template void asm1_gpu_setup<double >(struct csr *A, unsigned null_space, struct box *box);
200+
283201#else
202+
284203void asm1_gpu_setup (struct csr *A, unsigned null_space, struct box *box) {
285204 fprintf (stderr, " GPU BLAS not enabled.\n " );
286205 exit (EXIT_FAILURE);
0 commit comments