Skip to content

Commit fdd6fd4

Browse files
committed
feat(cuda): Enable CUDA backend to call transpose kernels.
1 parent 5614c06 commit fdd6fd4

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

src/backend.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module m_base_backend
2121
!! architecture.
2222

2323
real(dp) :: nu
24+
integer :: nx_loc, ny_loc, nz_loc
2425
class(allocator_t), pointer :: allocator
2526
class(dirps_t), pointer :: xdirps, ydirps, zdirps
2627
contains

src/cuda/backend.f90

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ module m_cuda_backend
1212
use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
1313
use m_cuda_tdsops, only: cuda_tdsops_t
1414
use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
15+
use m_cuda_kernels_trans, only: trans_x2y_k, trans_x2z_k, trans_y2x_k, &
16+
trans_y2z_k, trans_z2y_k
1517

1618
implicit none
1719

@@ -74,6 +76,10 @@ function init(globs, allocator) result(backend)
7476
backend%zthreads = dim3(SZ, 1, 1)
7577
backend%zblocks = dim3(globs%n_groups_z, 1, 1)
7678

79+
backend%nx_loc = globs%nx_loc
80+
backend%ny_loc = globs%ny_loc
81+
backend%nz_loc = globs%nz_loc
82+
7783
n_halo = 4
7884
n_block = globs%n_groups_x
7985

@@ -422,6 +428,17 @@ subroutine trans_x2y_cuda(self, u_y, u)
422428
class(field_t), intent(inout) :: u_y
423429
class(field_t), intent(in) :: u
424430

431+
real(dp), device, pointer, dimension(:, :, :) :: u_d, u_y_d
432+
type(dim3) :: blocks, threads
433+
434+
select type(u); type is (cuda_field_t); u_d => u%data_d; end select
435+
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
436+
437+
blocks = dim3(self%nx_loc/SZ, self%nz_loc, self%ny_loc/SZ)
438+
threads = dim3(SZ, SZ, 1)
439+
440+
call trans_x2y_k<<<blocks, threads>>>(u_y_d, u_d, self%nz_loc)
441+
425442
end subroutine trans_x2y_cuda
426443

427444
subroutine trans_x2z_cuda(self, u_z, u)
@@ -431,6 +448,17 @@ subroutine trans_x2z_cuda(self, u_z, u)
431448
class(field_t), intent(inout) :: u_z
432449
class(field_t), intent(in) :: u
433450

451+
real(dp), device, pointer, dimension(:, :, :) :: u_d, u_z_d
452+
type(dim3) :: blocks, threads
453+
454+
select type(u); type is (cuda_field_t); u_d => u%data_d; end select
455+
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select
456+
457+
blocks = dim3(self%nx_loc, self%ny_loc/SZ, 1)
458+
threads = dim3(SZ, 1, 1)
459+
460+
call trans_x2z_k<<<blocks, threads>>>(u_z_d, u_d, self%nz_loc)
461+
434462
end subroutine trans_x2z_cuda
435463

436464
subroutine trans_y2z_cuda(self, u_z, u_y)
@@ -440,6 +468,18 @@ subroutine trans_y2z_cuda(self, u_z, u_y)
440468
class(field_t), intent(inout) :: u_z
441469
class(field_t), intent(in) :: u_y
442470

471+
real(dp), device, pointer, dimension(:, :, :) :: u_z_d, u_y_d
472+
type(dim3) :: blocks, threads
473+
474+
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select
475+
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
476+
477+
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
478+
threads = dim3(SZ, SZ, 1)
479+
480+
call trans_y2z_k<<<blocks, threads>>>(u_z_d, u_y_d, &
481+
self%nx_loc, self%nz_loc)
482+
443483
end subroutine trans_y2z_cuda
444484

445485
subroutine trans_z2y_cuda(self, u_y, u_z)
@@ -449,6 +489,18 @@ subroutine trans_z2y_cuda(self, u_y, u_z)
449489
class(field_t), intent(inout) :: u_y
450490
class(field_t), intent(in) :: u_z
451491

492+
real(dp), device, pointer, dimension(:, :, :) :: u_y_d, u_z_d
493+
type(dim3) :: blocks, threads
494+
495+
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
496+
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select
497+
498+
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
499+
threads = dim3(SZ, SZ, 1)
500+
501+
call trans_z2y_k<<<blocks, threads>>>(u_y_d, u_z_d, &
502+
self%nx_loc, self%nz_loc)
503+
452504
end subroutine trans_z2y_cuda
453505

454506
subroutine trans_y2x_cuda(self, u_x, u_y)
@@ -458,6 +510,17 @@ subroutine trans_y2x_cuda(self, u_x, u_y)
458510
class(field_t), intent(inout) :: u_x
459511
class(field_t), intent(in) :: u_y
460512

513+
real(dp), device, pointer, dimension(:, :, :) :: u_x_d, u_y_d
514+
type(dim3) :: blocks, threads
515+
516+
select type(u_x); type is (cuda_field_t); u_x_d => u_x%data_d; end select
517+
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
518+
519+
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
520+
threads = dim3(SZ, SZ, 1)
521+
522+
call trans_y2x_k<<<blocks, threads>>>(u_x_d, u_y_d, self%nz_loc)
523+
461524
end subroutine trans_y2x_cuda
462525

463526
subroutine sum_yzintox_cuda(self, du, dv, dw, &

0 commit comments

Comments
 (0)