Skip to content

Commit 72d59ea

Browse files
committed
refactor: Merge all transpose functions.
1 parent fdd6fd4 commit 72d59ea

File tree

4 files changed

+61
-164
lines changed

4 files changed

+61
-164
lines changed

src/backend.f90

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,7 @@ module m_base_backend
2929
procedure(transeq_ders), deferred :: transeq_y
3030
procedure(transeq_ders), deferred :: transeq_z
3131
procedure(tds_solve), deferred :: tds_solve
32-
procedure(trans_d2d), deferred :: trans_x2y
33-
procedure(trans_d2d), deferred :: trans_x2z
34-
procedure(trans_d2d), deferred :: trans_y2z
35-
procedure(trans_d2d), deferred :: trans_z2y
36-
procedure(trans_d2d), deferred :: trans_y2x
32+
procedure(trans_d2d), deferred :: trans_d2d
3733
procedure(sum9into3), deferred :: sum_yzintox
3834
procedure(vecadd), deferred :: vecadd
3935
procedure(get_fields), deferred :: get_fields
@@ -82,7 +78,7 @@ end subroutine tds_solve
8278
end interface
8379

8480
abstract interface
85-
subroutine trans_d2d(self, u_, u)
81+
subroutine trans_d2d(self, u_, u, direction)
8682
!! transposer subroutines are straightforward, they rearrange
8783
!! data into our specialist data structure so that regardless
8884
!! of the direction tridiagonal systems are solved efficiently
@@ -94,6 +90,7 @@ subroutine trans_d2d(self, u_, u)
9490
class(base_backend_t) :: self
9591
class(field_t), intent(inout) :: u_
9692
class(field_t), intent(in) :: u
93+
integer, intent(in) :: direction
9794
end subroutine trans_d2d
9895
end interface
9996

src/cuda/backend.f90

Lines changed: 38 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,7 @@ module m_cuda_backend
3434
procedure :: transeq_y => transeq_y_cuda
3535
procedure :: transeq_z => transeq_z_cuda
3636
procedure :: tds_solve => tds_solve_cuda
37-
procedure :: trans_x2y => trans_x2y_cuda
38-
procedure :: trans_x2z => trans_x2z_cuda
39-
procedure :: trans_y2z => trans_y2z_cuda
40-
procedure :: trans_z2y => trans_z2y_cuda
41-
procedure :: trans_y2x => trans_y2x_cuda
37+
procedure :: trans_d2d => trans_d2d_cuda
4238
procedure :: sum_yzintox => sum_yzintox_cuda
4339
procedure :: vecadd => vecadd_cuda
4440
procedure :: set_fields => set_fields_cuda
@@ -421,107 +417,50 @@ subroutine tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
421417

422418
end subroutine tds_solve_dist
423419

424-
subroutine trans_x2y_cuda(self, u_y, u)
420+
subroutine trans_d2d_cuda(self, u_o, u_i, direction)
425421
implicit none
426422

427423
class(cuda_backend_t) :: self
428-
class(field_t), intent(inout) :: u_y
429-
class(field_t), intent(in) :: u
430-
431-
real(dp), device, pointer, dimension(:, :, :) :: u_d, u_y_d
432-
type(dim3) :: blocks, threads
433-
434-
select type(u); type is (cuda_field_t); u_d => u%data_d; end select
435-
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
436-
437-
blocks = dim3(self%nx_loc/SZ, self%nz_loc, self%ny_loc/SZ)
438-
threads = dim3(SZ, SZ, 1)
439-
440-
call trans_x2y_k<<<blocks, threads>>>(u_y_d, u_d, self%nz_loc)
441-
442-
end subroutine trans_x2y_cuda
443-
444-
subroutine trans_x2z_cuda(self, u_z, u)
445-
implicit none
446-
447-
class(cuda_backend_t) :: self
448-
class(field_t), intent(inout) :: u_z
449-
class(field_t), intent(in) :: u
424+
class(field_t), intent(inout) :: u_o
425+
class(field_t), intent(in) :: u_i
426+
integer, intent(in) :: direction
450427

451-
real(dp), device, pointer, dimension(:, :, :) :: u_d, u_z_d
428+
real(dp), device, pointer, dimension(:, :, :) :: u_o_d, u_i_d
452429
type(dim3) :: blocks, threads
453430

454-
select type(u); type is (cuda_field_t); u_d => u%data_d; end select
455-
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select
456-
457-
blocks = dim3(self%nx_loc, self%ny_loc/SZ, 1)
458-
threads = dim3(SZ, 1, 1)
459-
460-
call trans_x2z_k<<<blocks, threads>>>(u_z_d, u_d, self%nz_loc)
461-
462-
end subroutine trans_x2z_cuda
463-
464-
subroutine trans_y2z_cuda(self, u_z, u_y)
465-
implicit none
466-
467-
class(cuda_backend_t) :: self
468-
class(field_t), intent(inout) :: u_z
469-
class(field_t), intent(in) :: u_y
470-
471-
real(dp), device, pointer, dimension(:, :, :) :: u_z_d, u_y_d
472-
type(dim3) :: blocks, threads
473-
474-
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select
475-
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
476-
477-
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
478-
threads = dim3(SZ, SZ, 1)
479-
480-
call trans_y2z_k<<<blocks, threads>>>(u_z_d, u_y_d, &
481-
self%nx_loc, self%nz_loc)
482-
483-
end subroutine trans_y2z_cuda
484-
485-
subroutine trans_z2y_cuda(self, u_y, u_z)
486-
implicit none
487-
488-
class(cuda_backend_t) :: self
489-
class(field_t), intent(inout) :: u_y
490-
class(field_t), intent(in) :: u_z
491-
492-
real(dp), device, pointer, dimension(:, :, :) :: u_y_d, u_z_d
493-
type(dim3) :: blocks, threads
494-
495-
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
496-
select type(u_z); type is (cuda_field_t); u_z_d => u_z%data_d; end select
497-
498-
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
499-
threads = dim3(SZ, SZ, 1)
500-
501-
call trans_z2y_k<<<blocks, threads>>>(u_y_d, u_z_d, &
502-
self%nx_loc, self%nz_loc)
503-
504-
end subroutine trans_z2y_cuda
505-
506-
subroutine trans_y2x_cuda(self, u_x, u_y)
507-
implicit none
508-
509-
class(cuda_backend_t) :: self
510-
class(field_t), intent(inout) :: u_x
511-
class(field_t), intent(in) :: u_y
512-
513-
real(dp), device, pointer, dimension(:, :, :) :: u_x_d, u_y_d
514-
type(dim3) :: blocks, threads
515-
516-
select type(u_x); type is (cuda_field_t); u_x_d => u_x%data_d; end select
517-
select type(u_y); type is (cuda_field_t); u_y_d => u_y%data_d; end select
518-
519-
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
520-
threads = dim3(SZ, SZ, 1)
521-
522-
call trans_y2x_k<<<blocks, threads>>>(u_x_d, u_y_d, self%nz_loc)
431+
select type(u_o); type is (cuda_field_t); u_o_d => u_o%data_d; end select
432+
select type(u_i); type is (cuda_field_t); u_i_d => u_i%data_d; end select
433+
434+
select case (direction)
435+
case (12) ! x2y
436+
blocks = dim3(self%nx_loc/SZ, self%nz_loc, self%ny_loc/SZ)
437+
threads = dim3(SZ, SZ, 1)
438+
call trans_x2y_k<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
439+
case (13) ! x2z
440+
blocks = dim3(self%nx_loc, self%ny_loc/SZ, 1)
441+
threads = dim3(SZ, 1, 1)
442+
call trans_x2z_k<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
443+
case (21) ! y2x
444+
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
445+
threads = dim3(SZ, SZ, 1)
446+
call trans_y2x_k<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
447+
case (23) ! y2z
448+
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
449+
threads = dim3(SZ, SZ, 1)
450+
call trans_y2z_k<<<blocks, threads>>>(u_o_d, u_i_d, &
451+
self%nx_loc, self%nz_loc)
452+
case (32) ! z2y
453+
blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
454+
threads = dim3(SZ, SZ, 1)
455+
456+
call trans_z2y_k<<<blocks, threads>>>(u_o_d, u_i_d, &
457+
self%nx_loc, self%nz_loc)
458+
case default
459+
print *, 'Transpose direction is undefined.'
460+
stop
461+
end select
523462

524-
end subroutine trans_y2x_cuda
463+
end subroutine trans_d2d_cuda
525464

526465
subroutine sum_yzintox_cuda(self, du, dv, dw, &
527466
du_y, dv_y, dw_y, du_z, dv_z, dw_z)

src/omp/backend.f90

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@ module m_omp_backend
2424
procedure :: transeq_y => transeq_y_omp
2525
procedure :: transeq_z => transeq_z_omp
2626
procedure :: tds_solve => tds_solve_omp
27-
procedure :: trans_x2y => trans_x2y_omp
28-
procedure :: trans_x2z => trans_x2z_omp
29-
procedure :: trans_y2z => trans_y2z_omp
30-
procedure :: trans_z2y => trans_z2y_omp
31-
procedure :: trans_y2x => trans_y2x_omp
27+
procedure :: trans_d2d => trans_d2d_omp
3228
procedure :: sum_yzintox => sum_yzintox_omp
3329
procedure :: vecadd => vecadd_omp
3430
procedure :: set_fields => set_fields_omp
@@ -164,50 +160,15 @@ subroutine tds_solve_omp(self, du, u, dirps, tdsops)
164160

165161
end subroutine tds_solve_omp
166162

167-
subroutine trans_x2y_omp(self, u_, u)
163+
subroutine trans_d2d_omp(self, u_, u, direction)
168164
implicit none
169165

170166
class(omp_backend_t) :: self
171167
class(field_t), intent(inout) :: u_
172168
class(field_t), intent(in) :: u
169+
integer, intent(in) :: direction
173170

174-
end subroutine trans_x2y_omp
175-
176-
subroutine trans_x2z_omp(self, u_, u)
177-
implicit none
178-
179-
class(omp_backend_t) :: self
180-
class(field_t), intent(inout) :: u_
181-
class(field_t), intent(in) :: u
182-
183-
end subroutine trans_x2z_omp
184-
185-
subroutine trans_y2z_omp(self, u_, u)
186-
implicit none
187-
188-
class(omp_backend_t) :: self
189-
class(field_t), intent(inout) :: u_
190-
class(field_t), intent(in) :: u
191-
192-
end subroutine trans_y2z_omp
193-
194-
subroutine trans_z2y_omp(self, u_, u)
195-
implicit none
196-
197-
class(omp_backend_t) :: self
198-
class(field_t), intent(inout) :: u_
199-
class(field_t), intent(in) :: u
200-
201-
end subroutine trans_z2y_omp
202-
203-
subroutine trans_y2x_omp(self, u_, u)
204-
implicit none
205-
206-
class(omp_backend_t) :: self
207-
class(field_t), intent(inout) :: u_
208-
class(field_t), intent(in) :: u
209-
210-
end subroutine trans_y2x_omp
171+
end subroutine trans_d2d_omp
211172

212173
subroutine sum_yzintox_omp(self, du, dv, dw, &
213174
du_y, dv_y, dw_y, du_z, dv_z, dw_z)

src/solver.f90

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ subroutine transeq(self, du, dv, dw, u, v, w)
185185
dw_y => self%backend%allocator%get_block()
186186

187187
! reorder data from x orientation to y orientation
188-
call self%backend%trans_x2y(u_y, u)
189-
call self%backend%trans_x2y(v_y, v)
190-
call self%backend%trans_x2y(w_y, w)
188+
call self%backend%trans_d2d(u_y, u, 12)
189+
call self%backend%trans_d2d(v_y, v, 12)
190+
call self%backend%trans_d2d(w_y, w, 12)
191191

192192
! similar to the x direction, obtain derivatives in y.
193193
call self%backend%transeq_y(du_y, dv_y, dw_y, u_y, v_y, w_y, self%ydirps)
@@ -209,9 +209,9 @@ subroutine transeq(self, du, dv, dw, u, v, w)
209209
dw_z => self%backend%allocator%get_block()
210210

211211
! reorder from x to z
212-
call self%backend%trans_x2z(u_z, u)
213-
call self%backend%trans_x2z(v_z, v)
214-
call self%backend%trans_x2z(w_z, w)
212+
call self%backend%trans_d2d(u_z, u, 13)
213+
call self%backend%trans_d2d(v_z, v, 13)
214+
call self%backend%trans_d2d(w_z, w, 13)
215215

216216
! get the derivatives in z
217217
call self%backend%transeq_z(du_z, dv_z, dw_z, u_z, v_z, w_z, self%zdirps)
@@ -267,9 +267,9 @@ subroutine divergence(self, div_u, u, v, w)
267267
w_y => self%backend%allocator%get_block()
268268

269269
! reorder data from x orientation to y orientation
270-
call self%backend%trans_x2y(u_y, du_x)
271-
call self%backend%trans_x2y(v_y, dv_x)
272-
call self%backend%trans_x2y(w_y, dw_x)
270+
call self%backend%trans_d2d(u_y, du_x, 12)
271+
call self%backend%trans_d2d(v_y, dv_x, 12)
272+
call self%backend%trans_d2d(w_y, dw_x, 12)
273273

274274
call self%backend%allocator%release_block(du_x)
275275
call self%backend%allocator%release_block(dv_x)
@@ -303,8 +303,8 @@ subroutine divergence(self, div_u, u, v, w)
303303
call self%backend%vecadd(1._dp, dw_y, 1._dp, dv_y)
304304

305305
! reorder from y to z
306-
call self%backend%trans_y2z(u_z, du_y)
307-
call self%backend%trans_y2z(w_z, dw_y)
306+
call self%backend%trans_d2d(u_z, du_y, 23)
307+
call self%backend%trans_d2d(w_z, dw_y, 23)
308308

309309
! release all the unnecessary blocks.
310310
call self%backend%allocator%release_block(du_y)
@@ -358,8 +358,8 @@ subroutine gradient(self, dpdx, dpdy, dpdz, pressure)
358358
dpdz_sxy_y => self%backend%allocator%get_block()
359359

360360
! reorder data from z orientation to y orientation
361-
call self%backend%trans_z2y(p_sxy_y, p_sxy_z)
362-
call self%backend%trans_z2y(dpdz_sxy_y, dpdz_sxy_z)
361+
call self%backend%trans_d2d(p_sxy_y, p_sxy_z, 32)
362+
call self%backend%trans_d2d(dpdz_sxy_y, dpdz_sxy_z, 32)
363363

364364
call self%backend%allocator%release_block(p_sxy_z)
365365
call self%backend%allocator%release_block(dpdz_sxy_z)
@@ -386,9 +386,9 @@ subroutine gradient(self, dpdx, dpdy, dpdz, pressure)
386386
dpdz_sx_x => self%backend%allocator%get_block()
387387

388388
! reorder from y to x
389-
call self%backend%trans_y2x(p_sx_x, p_sx_y)
390-
call self%backend%trans_y2x(dpdy_sx_x, dpdy_sx_y)
391-
call self%backend%trans_y2x(dpdz_sx_x, dpdz_sx_y)
389+
call self%backend%trans_d2d(p_sx_x, p_sx_y, 21)
390+
call self%backend%trans_d2d(dpdy_sx_x, dpdy_sx_y, 21)
391+
call self%backend%trans_d2d(dpdz_sx_x, dpdz_sx_y, 21)
392392

393393
! release all the y directional fields.
394394
call self%backend%allocator%release_block(p_sx_y)

0 commit comments

Comments
 (0)