
Commit ddbca8b

Fix ranks not seeing the correct device (#4827)
* fix device!
* couple of fixes
* add a test
* Bump patch release version 0.99.3
* Update test_distributed_architectures.jl
* Modify ROCGPU device function for AMD compatibility: update device function to use device_id! for AMD devices.
* Update architecture check for CUDAGPU
* remove stray spaces
* fix
* test the node rank
* add communicator = MPI.COMM_WORLD
* Update test_distributed_architectures.jl

---------

Co-authored-by: Navid C. Constantinou <[email protected]>
1 parent 31f6436 commit ddbca8b

File tree: 9 files changed (+33, -8 lines)

ext/OceananigansAMDGPUExt.jl

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ const ROCGPU = AC.GPU{ROCBackend}
 ROCGPU() = AC.GPU(AMDGPU.ROCBackend())

 Base.summary(::ROCGPU) = "ROCGPU"
+AC.device!(::ROCGPU, i) = AMDGPU.device_id!(i+1) # AMD devices are numbered 1..ndevices

 AC.architecture(::ROCArray) = ROCGPU()
 AC.architecture(::Type{ROCArray}) = ROCGPU()
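The new method takes the zero-based device index used elsewhere in Oceananigans, while AMDGPU.jl numbers devices from 1, hence the `i+1`. A minimal sketch of the intended mapping, assuming AMDGPU.jl's `devices()`/`device_id()` API and at least one ROCm device:

    # Sketch only: illustrates the zero-based convention device!(::ROCGPU, i) assumes.
    using AMDGPU
    using Oceananigans.Architectures: GPU, device!

    arch = GPU(AMDGPU.ROCBackend())              # i.e. a ROCGPU architecture
    for i in 0:(length(AMDGPU.devices()) - 1)
        device!(arch, i)                         # forwards to AMDGPU.device_id!(i + 1)
        @assert AMDGPU.device_id() == i + 1      # AMD devices are numbered 1..ndevices
    end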

ext/OceananigansCUDAExt.jl

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ function UT.versioninfo_with_gpu(::CUDAGPU)
 end

 Base.summary(::CUDAGPU) = "CUDAGPU"
+AC.device!(::CUDAGPU, i) = CUDA.device!(i)

 AC.architecture(::CuArray) = CUDAGPU()
 AC.architecture(::Type{CuArray}) = CUDAGPU()
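Together with the AMDGPU method above, this is what lets each MPI rank select its own GPU. A hedged sketch of how a distributed run might pin ranks to node-local devices; the exact call site inside Oceananigans may differ, and `ndevices` comes from Oceananigans.Architectures:

    # Sketch: pin each MPI rank to a distinct GPU on its node via device!.
    using MPI, CUDA
    using Oceananigans.Architectures: GPU, device!, ndevices

    MPI.Init()
    comm = MPI.COMM_WORLD

    # Split by node so ranks sharing a node get consecutive node-local ranks
    local_comm = MPI.Comm_split_type(comm, MPI.COMM_TYPE_SHARED, MPI.Comm_rank(comm))
    local_rank = MPI.Comm_rank(local_comm)

    arch = GPU(CUDA.CUDABackend())               # i.e. a CUDAGPU architecture
    device!(arch, local_rank % ndevices(arch))   # forwards to CUDA.device!(i), zero-based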

src/Architectures.jl

Lines changed: 1 addition & 2 deletions

@@ -58,9 +58,8 @@ struct ReactantState <: AbstractSerialArchitecture end

 device(a::CPU) = KA.CPU()
 device(a::GPU) = a.device
-device!(::CPU, i) = KA.device!(CPU(), i+1)
+device!(::CPU, i) = nothing
 device!(::CPU) = nothing
-device!(a::GPU, i) = KA.device!(a.device, i+1)
 ndevices(a::CPU) = KA.ndevices(KA.CPU())
 ndevices(a::AbstractArchitecture) = KA.ndevices(a.device)
 synchronize(a::CPU) = KA.synchronize(KA.CPU())
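With this change, selecting a device on a CPU architecture becomes a harmless no-op, and the GPU method now lives in the backend extensions above instead of dispatching through KernelAbstractions. A minimal check, assuming a version of Oceananigans that includes this commit:

    using Oceananigans.Architectures: CPU, device!

    device!(CPU(), 3)   # returns nothing: there is no device to select on the CPU
    device!(CPU())      # also a no-op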

src/DistributedComputations/distributed_architectures.jl

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 using Oceananigans.Architectures
 using Oceananigans.Grids: topology, validate_tupled_argument

-import Oceananigans.Architectures: device, cpu_architecture, on_architecture, array_type, child_architecture, convert_to_device
+import Oceananigans.Architectures: device, device!, cpu_architecture, on_architecture, array_type, child_architecture, convert_to_device
 import Oceananigans.Grids: zeros
 import Oceananigans.Utils: sync_device!, tupleit

src/DistributedComputations/distributed_grids.jl

Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ function LatitudeLongitudeGrid(arch::Distributed,
                                z,
                                topology = nothing,
                                radius = R_Earth,
-                               halo = (1, 1, 1))
+                               halo = nothing)

     topology, global_sz, halo, latitude, longitude, z, precompute_metrics =
         validate_lat_lon_grid_args(topology, size, halo, FT, latitude, longitude, z, precompute_metrics)
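Defaulting `halo` to `nothing` lets `validate_lat_lon_grid_args` pick the same default halo as the serial constructor instead of forcing `(1, 1, 1)` on distributed grids. A hedged usage sketch of a distributed CPU run; keyword names follow the standard LatitudeLongitudeGrid API:

    using MPI, Oceananigans
    using Oceananigans.DistributedComputations

    MPI.Init()
    arch = Distributed(CPU())

    # No halo keyword: the default is now chosen by the shared validation logic
    grid = LatitudeLongitudeGrid(arch;
                                 size = (180, 90, 4),
                                 longitude = (-180, 180),
                                 latitude = (-80, 80),
                                 z = (-1000, 0))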

src/Models/HydrostaticFreeSurfaceModels/compute_hydrostatic_free_surface_buffers.jl

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ function compute_buffer_tendency_contributions!(grid::DistributedActiveInteriorI

         # If the map == nothing, we don't need to compute the buffer because
         # the buffer is not adjacent to a processor boundary
-        !isnothing(map) && compute_hydrostatic_free_surface_tendency_contributions!(model, :xyz; active_cells_map)
+        !isnothing(active_cells_map) && compute_hydrostatic_free_surface_tendency_contributions!(model, :xyz; active_cells_map)
     end

     return nothing

src/Utils/kernel_launching.jl

Lines changed: 1 addition & 1 deletion

@@ -351,7 +351,7 @@ end

 # launching with an empty tuple has no effect
 @inline function launch!(arch, grid, workspec_tuple::Tuple{}, kernel, args...; kwargs...)
-    @warn "trying to launch kernel $kernel! with workspec == (). The kernel will not be launched."
+    @warn "trying to launch kernel $kernel with workspec == (). The kernel will not be launched."
     return nothing
 end
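The fix drops a stray `!` from the interpolated kernel name in the warning. For illustration, a hedged sketch of hitting this no-op branch; `dummy_kernel!` is a placeholder that is never executed because the workspec is empty:

    using Oceananigans
    using Oceananigans.Utils: launch!

    dummy_kernel!(args...) = nothing   # placeholder; never called for workspec == ()

    grid = RectilinearGrid(CPU(), size = (4, 4, 4), extent = (1, 1, 1))
    launch!(CPU(), grid, (), dummy_kernel!)   # warns and returns nothing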

test/runtests.jl

Lines changed: 1 addition & 1 deletion

@@ -169,6 +169,7 @@ CUDA.allowscalar() do
         MPI.Initialized() || MPI.Init()
         # In case CUDA is not found, we reset CUDA and restart the julia session
         reset_cuda_if_necessary()
+        include("test_distributed_architectures.jl")
         include("test_distributed_models.jl")
     end

@@ -178,7 +179,6 @@ CUDA.allowscalar() do
         reset_cuda_if_necessary()
         include("test_distributed_transpose.jl")
         include("test_distributed_poisson_solvers.jl")
-        include("test_distributed_macros.jl")
     end

     if group == :distributed_hydrostatic_model || group == :all

test/test_distributed_macros.jl renamed to test/test_distributed_architectures.jl

Lines changed: 25 additions & 1 deletion

@@ -1,7 +1,10 @@
+include("dependencies_for_runtests.jl")
+
 using MPI
 using Oceananigans.DistributedComputations
+using CUDA

-@testset begin
+@testset "Distributed macros" begin
     rank = MPI.Comm_rank(MPI.COMM_WORLD)

     @onrank 0 begin

@@ -57,3 +60,24 @@ using Oceananigans.DistributedComputations
     @onrank split_comm 0 @test a == [1, 3, 5, 7, 9]
     @onrank split_comm 1 @test a == [2, 4, 6, 8, 10]
 end
+
+#=
+@testset "Distributed architectures" begin
+    for arch in test_architectures()
+        child_arch = child_architecture(arch)
+
+        communicator = MPI.COMM_WORLD
+
+        if child_arch isa Oceananigans.Architectures.GPU
+            # Check that no device is the same!
+            local_comm = MPI.Comm_split_type(communicator, MPI.COMM_TYPE_SHARED, arch.local_rank)
+            node_rank = MPI.Comm_rank(local_comm)
+            device_number = CUDA.device().handle
+            # We are testing on the same node, therefore we can
+            # assume the GPU number changes with the rank
+            @test node_rank == device_number
+        end
+    end
+end
+=#
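The commented-out testset asserts that, on a single node, each node-local rank ends up on the CUDA device with the matching handle. A standalone hedged version of the same check, assuming one node with at least as many GPUs as ranks (the script name is hypothetical):

    # Run under MPI, e.g. `mpiexec -n 2 julia check_devices.jl`
    using MPI, CUDA

    MPI.Init()
    comm = MPI.COMM_WORLD

    local_comm = MPI.Comm_split_type(comm, MPI.COMM_TYPE_SHARED, MPI.Comm_rank(comm))
    node_rank = MPI.Comm_rank(local_comm)

    CUDA.device!(node_rank)                       # what device!(::CUDAGPU, i) forwards to
    @assert CUDA.device().handle == node_rank     # every rank sees a distinct device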
