@@ -34,11 +34,7 @@ module m_cuda_backend
3434 procedure :: transeq_y = > transeq_y_cuda
3535 procedure :: transeq_z = > transeq_z_cuda
3636 procedure :: tds_solve = > tds_solve_cuda
37- procedure :: trans_x2y = > trans_x2y_cuda
38- procedure :: trans_x2z = > trans_x2z_cuda
39- procedure :: trans_y2z = > trans_y2z_cuda
40- procedure :: trans_z2y = > trans_z2y_cuda
41- procedure :: trans_y2x = > trans_y2x_cuda
37+ procedure :: trans_d2d = > trans_d2d_cuda
4238 procedure :: sum_yzintox = > sum_yzintox_cuda
4339 procedure :: vecadd = > vecadd_cuda
4440 procedure :: set_fields = > set_fields_cuda
@@ -421,107 +417,50 @@ subroutine tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
421417
422418 end subroutine tds_solve_dist
423419
subroutine trans_d2d_cuda(self, u_o, u_i, direction)
  !! Launch the CUDA transpose kernel selected by `direction`, reading
  !! from `u_i` and writing into `u_o`.
  !!
  !! Direction codes handled here and the kernel each one launches:
  !!   12 -> trans_x2y_k    13 -> trans_x2z_k    21 -> trans_y2x_k
  !!   23 -> trans_y2z_k    32 -> trans_z2y_k
  !!
  !! Any other code, or a field argument that is not a `cuda_field_t`,
  !! terminates the program through `error stop` so that the failure is
  !! reported with a non-zero exit status.
  implicit none

  ! No intent is given for `self`: it has to keep matching the binding's
  ! declared interface, which is not visible from this routine.
  class(cuda_backend_t) :: self
  class(field_t), intent(inout) :: u_o   !! destination field
  class(field_t), intent(in) :: u_i      !! source field
  integer, intent(in) :: direction       !! two-digit transpose code

  ! Named aliases for the accepted direction codes.
  integer, parameter :: dir_x2y = 12, dir_x2z = 13, dir_y2x = 21, &
                        dir_y2z = 23, dir_z2y = 32

  real(dp), device, pointer, dimension(:, :, :) :: u_o_d, u_i_d
  type(dim3) :: blocks, threads

  ! Resolve the device arrays behind the polymorphic field arguments.
  ! Without the `class default` branches a foreign field type would leave
  ! the pointers with an undefined association status before the launch.
  select type (u_o)
  type is (cuda_field_t)
    u_o_d => u_o%data_d
  class default
    error stop 'trans_d2d_cuda: u_o is not a cuda_field_t.'
  end select

  select type (u_i)
  type is (cuda_field_t)
    u_i_d => u_i%data_d
  class default
    error stop 'trans_d2d_cuda: u_i is not a cuda_field_t.'
  end select

  ! NOTE(review): nx_loc/SZ and ny_loc/SZ are integer divisions, so the
  ! launch grids presumably rely on the local sizes being multiples of
  ! SZ - confirm against the kernels.
  select case (direction)
  case (dir_x2y)
    blocks = dim3(self%nx_loc/SZ, self%nz_loc, self%ny_loc/SZ)
    threads = dim3(SZ, SZ, 1)
    call trans_x2y_k<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
  case (dir_x2z)
    blocks = dim3(self%nx_loc, self%ny_loc/SZ, 1)
    threads = dim3(SZ, 1, 1)
    call trans_x2z_k<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
  case (dir_y2x)
    blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
    threads = dim3(SZ, SZ, 1)
    call trans_y2x_k<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
  case (dir_y2z)
    blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
    threads = dim3(SZ, SZ, 1)
    call trans_y2z_k<<<blocks, threads>>>(u_o_d, u_i_d, &
                                          self%nx_loc, self%nz_loc)
  case (dir_z2y)
    blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
    threads = dim3(SZ, SZ, 1)
    call trans_z2y_k<<<blocks, threads>>>(u_o_d, u_i_d, &
                                          self%nx_loc, self%nz_loc)
  case default
    ! A plain `stop` would signal normal termination; an unknown
    ! direction is a fatal error.
    error stop 'Transpose direction is undefined.'
  end select

end subroutine trans_d2d_cuda
525464
526465 subroutine sum_yzintox_cuda (self , du , dv , dw , &
527466 du_y , dv_y , dw_y , du_z , dv_z , dw_z )
0 commit comments