Skip to content

Commit bb3f5a4

Browse files
committed
feat(cuda): Extend tdsops_t for CUDA backend.
1 parent 17c1c1b commit bb3f5a4

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ set(CUDASRC
1515
cuda/common.f90
1616
cuda/cuda_allocator.f90
1717
cuda/kernels_dist.f90
18+
cuda/tdsops.f90
1819
)
1920

2021
if(${CMAKE_Fortran_COMPILER_ID} STREQUAL "PGI")

src/cuda/tdsops.f90

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
module m_cuda_tdsops
  !! CUDA Fortran extension of the tridiagonal solver operators type.
  !!
  !! Defines cuda_tdsops_t, which extends tdsops_t with arrays carrying the
  !! `device` attribute, and a constructor that fills those arrays from the
  !! host arrays of the parent type.
  use iso_fortran_env, only: stderr => error_unit

  use m_common, only: dp
  use m_tdsops, only: tdsops_t, tdsops_init

  implicit none

  type, extends(tdsops_t) :: cuda_tdsops_t
    !! CUDA extension of the Tridiagonal Solver Operators class.
    !!
    !! The regular tdsops_t class is initialised and its coefficient arrays
    !! are copied into device arrays so that CUDA kernels can use them.
    !! Each `*_dev` component mirrors the tdsops_t component of the same
    !! name without the `_dev` suffix.
    real(dp), device, allocatable :: dist_fw_dev(:), dist_bw_dev(:), &
                                     dist_sa_dev(:), dist_sc_dev(:), &
                                     dist_af_dev(:)
    real(dp), device, allocatable :: coeffs_dev(:), &
                                     coeffs_s_dev(:, :), coeffs_e_dev(:, :)
  end type cuda_tdsops_t

  ! Generic interface so that `cuda_tdsops_t(...)` invokes the constructor.
  interface cuda_tdsops_t
    module procedure cuda_tdsops_init
  end interface cuda_tdsops_t

contains

  function cuda_tdsops_init(n, delta, operation, scheme, n_halo, from_to, &
                            bc_start, bc_end, sym, c_nu, nu0_nu) &
    result(tdsops)
    !! Constructor function for the cuda_tdsops_t class.
    !!
    !! All arguments are forwarded unchanged to tdsops_init, which builds
    !! the parent tdsops_t part of the result; see tdsops_t for their
    !! meaning. The device arrays are then allocated and filled from the
    !! corresponding host arrays of the parent part.

    type(cuda_tdsops_t) :: tdsops !! return value of the function

    integer, intent(in) :: n
    real(dp), intent(in) :: delta
    character(*), intent(in) :: operation, scheme
    integer, optional, intent(in) :: n_halo
    character(*), optional, intent(in) :: from_to, bc_start, bc_end
    logical, optional, intent(in) :: sym
    real(dp), optional, intent(in) :: c_nu, nu0_nu

    ! Optional dummies that are absent here are passed straight on as
    ! absent optional actual arguments.
    tdsops%tdsops_t = tdsops_init(n, delta, operation, scheme, n_halo, &
                                  from_to, bc_start, bc_end, sym, &
                                  c_nu, nu0_nu)

    ! Every device array takes its extents from the host array it mirrors
    ! rather than from a re-derived formula, so the host-to-device
    ! assignments below are shape-conformant however tdsops_init sizes
    ! its arrays.
    allocate (tdsops%dist_fw_dev(size(tdsops%dist_fw)))
    allocate (tdsops%dist_bw_dev(size(tdsops%dist_bw)))
    allocate (tdsops%dist_sa_dev(size(tdsops%dist_sa)))
    allocate (tdsops%dist_sc_dev(size(tdsops%dist_sc)))
    allocate (tdsops%dist_af_dev(size(tdsops%dist_af)))

    allocate (tdsops%coeffs_dev(size(tdsops%coeffs)))
    allocate (tdsops%coeffs_s_dev(size(tdsops%coeffs_s, 1), &
                                  size(tdsops%coeffs_s, 2)))
    allocate (tdsops%coeffs_e_dev(size(tdsops%coeffs_e, 1), &
                                  size(tdsops%coeffs_e, 2)))

    ! Assignments from host arrays to arrays with the device attribute.
    tdsops%dist_fw_dev(:) = tdsops%dist_fw(:)
    tdsops%dist_bw_dev(:) = tdsops%dist_bw(:)
    tdsops%dist_sa_dev(:) = tdsops%dist_sa(:)
    tdsops%dist_sc_dev(:) = tdsops%dist_sc(:)
    tdsops%dist_af_dev(:) = tdsops%dist_af(:)

    tdsops%coeffs_dev(:) = tdsops%coeffs(:)
    tdsops%coeffs_s_dev(:, :) = tdsops%coeffs_s(:, :)
    tdsops%coeffs_e_dev(:, :) = tdsops%coeffs_e(:, :)

  end function cuda_tdsops_init

end module m_cuda_tdsops
73+

0 commit comments

Comments
 (0)