@@ -14,10 +14,14 @@ module mtl_mod
1414 integer , parameter :: COMM_SEND = 1
1515 integer , parameter :: COMM_RECV = - 1
1616 integer , parameter :: COMM_NONE = 0
17+ integer , parameter :: COMM_FIELD = 1
18+ integer , parameter :: COMM_V = 2
19+ integer , parameter :: COMM_BOTH = 3
1720
1821 type, public :: communicator_t
1922 integer :: field_index = - 1 , v_index = - 1
2023 integer :: comm_task = COMM_NONE
24+ integer :: comm_type = COMM_NONE
2125 integer :: delta_rank = 0
2226 end type
2327 type, public :: comm_t
@@ -180,7 +184,7 @@ function mtl_unshielded(lpul, cpul, rpul, gpul, &
180184 call res% initStepSizeAndFieldSegments(step_size, segments, layer_indices)
181185 call res% initCommunicators(alloc_z)
182186 res% layer_indices = layer_indices
183- res% bundle_in_layer = bundle_in_layer
187+ res% bundle_in_layer = bundle_in_layer
184188 else
185189 res% step_size = step_size
186190 allocate (res% layer_indices(0 ,0 ))
@@ -350,6 +354,7 @@ subroutine initStepSizeAndFieldSegments(this, step_size, segments, layer_indices
350354 if (j /= size (layer_indices,1 )) then
351355 this% step_size(n + layer_indices(j,2 ) - layer_indices(j,1 ) + 1 ) = this% step_size(n + layer_indices(j,2 ) - layer_indices(j,1 ))
352356 this% segments(n + layer_indices(j,2 ) - layer_indices(j,1 ) + 1 ) = this% segments(n + layer_indices(j,2 ) - layer_indices(j,1 ))
357+ this% segments(n + layer_indices(j,2 ) - layer_indices(j,1 ) + 1 )% orientation = - 1
353358 n = n + 1
354359 end if
355360 n = n + layer_indices(j,2 ) - layer_indices(j,1 ) + 1
@@ -361,7 +366,7 @@ subroutine initStepSizeAndFieldSegments(this, step_size, segments, layer_indices
361366 subroutine initCommunicators (this , alloc_z )
362367 class(mtl_t) :: this
363368 integer (kind = 4 ), dimension (2 ) :: alloc_z
364- integer :: j, n
369+ integer :: j, n, z
365370 integer :: rank, ierr
366371 integer (kind = 4 ) :: z_init, z_end
367372 type (communicator_t), dimension (:), allocatable :: aux_comm
@@ -374,7 +379,33 @@ subroutine initCommunicators(this, alloc_z)
374379 z_end = alloc_z(2 )
375380
376381 do j = 1 , size (this% segments)
382+ if (this% segments(j)% orientation == - 1 ) cycle
377383
384+ z = this% segments(j)% z
385+ if (.not. isSegmentZOriented(j) .and. ((z == z_end) .or. (z == z_init + 1 ))) then
386+
387+ n = size (this% mpi_comm% comms)
388+ deallocate (aux_comm)
389+ allocate (aux_comm(n+1 ))
390+ aux_comm(1 :n) = this% mpi_comm% comms
391+
392+ aux_comm(n+1 )% field_index = j
393+ aux_comm(n+1 )% comm_type = COMM_FIELD
394+ aux_comm(n+1 )% v_index = - 1
395+ if (z == z_end) then
396+ aux_comm(n+1 )% delta_rank = 1
397+ aux_comm(n+1 )% comm_task = COMM_RECV
398+ else if (z == z_init + 1 ) then
399+ aux_comm(n+1 )% delta_rank = - 1
400+ aux_comm(n+1 )% comm_task = COMM_SEND
401+ end if
402+
403+ deallocate (this% mpi_comm% comms)
404+ allocate (this% mpi_comm% comms(n+1 ))
405+ this% mpi_comm% comms = aux_comm
406+
407+
408+ end if
378409 if (isSegmentZOriented(j) .and. &
379410 (isSegmentNextToLayerEnd(j,z_end) .or. isSegmentNextToLayerInit(j,z_init))) then
380411
@@ -383,6 +414,7 @@ subroutine initCommunicators(this, alloc_z)
383414 allocate (aux_comm(n+1 ))
384415 aux_comm(1 :n) = this% mpi_comm% comms
385416 aux_comm(n+1 )% field_index = j
417+ aux_comm(n+1 )% comm_type = COMM_BOTH
386418
387419 if (isSegmentNextToLayerEnd(j,z_end)) then
388420 aux_comm(n+1 )% delta_rank = 1
@@ -438,7 +470,12 @@ logical function isSegmentZOriented(j)
438470
439471 logical function isSegmentZPositive (j )
440472 integer , intent (in ) :: j
441- isSegmentZPositive = (this% segments(j)% orientation > 0 )
473+ isSegmentZPositive = (this% segments(j)% orientation == 3 )
474+ end function
475+
476+ logical function isSegmentZNegative (j )
477+ integer , intent (in ) :: j
478+ isSegmentZNegative = (this% segments(j)% orientation == - 3 )
442479 end function
443480
444481 logical function isSegmentBeforeLayerEnd (j , z_end )
@@ -473,14 +510,14 @@ logical function isSegmentNextToLayerEnd(j, z_end)
473510 integer , intent (in ) :: j, z_end
474511 integer :: z
475512 z = this% segments(j)% z
476- isSegmentNextToLayerEnd = (abs (z - z_end)< = 1 )
513+ isSegmentNextToLayerEnd = (z == z_end) .or. (z == z_end - 1 )
477514 end function
478515
479516 logical function isSegmentNextToLayerInit (j , z_init )
480517 integer , intent (in ) :: j, z_init
481518 integer :: z
482519 z = this% segments(j)% z
483- isSegmentNextToLayerInit = (abs (z - z_init-1 ) < = 1 )
520+ isSegmentNextToLayerInit = (z == z_init) .or. (z == z_init + 1 )
484521 end function
485522
486523
0 commit comments