
Commit 2c89b9d

Fix missing imports in FluxMPIExt (#2589)
* missing import of cpu_device
* test AMD GPU availability with functional(AMDGPUDevice)
* add missing import for NCCL
* test reduction for MPI and NCCL
1 parent 9147e84 commit 2c89b9d

File tree

3 files changed: +33 −2 lines changed

ext/FluxMPIExt/FluxMPIExt.jl

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@ module FluxMPIExt
 using Flux: MPIBackend, NCCLBackend, DistributedUtils,
     MPI_CUDA_AWARE, MPI_ROCM_AWARE
 using MPI: MPI
-using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, functional, set_device!
+using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, functional, set_device!, cpu_device


 function DistributedUtils.__initialize(
@@ -24,7 +24,7 @@ function DistributedUtils.__initialize(
         error(lazy"CUDA devices are not functional and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
     end

-    if amdgpu_devices !== missing && AMDGPU.functional()
+    if amdgpu_devices !== missing && functional(AMDGPUDevice)
         if amdgpu_devices === nothing
             set_device!(AMDGPUDevice, nothing, local_rank + 1)
         else
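
The substance of this change: the extension referenced cpu_device without importing it, and checked GPU availability via AMDGPU.functional() even though AMDGPU is not in scope in this extension; functional(AMDGPUDevice) from MLDataDevices performs the same check through the device abstraction. A minimal sketch of the corrected pattern, assuming only what the diff above imports (the rank value is illustrative):

using MLDataDevices: AMDGPUDevice, cpu_device, functional, set_device!

local_rank = 0  # illustrative; the extension obtains this from the MPI communicator

# functional(AMDGPUDevice) reports whether AMDGPU.jl is loaded and usable,
# so no direct dependency on AMDGPU is needed here.
if functional(AMDGPUDevice)
    set_device!(AMDGPUDevice, nothing, local_rank + 1)  # device indices are 1-based
else
    device = cpu_device()  # fall back to the CPU device
end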
(new file; path not shown on this page)

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+using Flux
+using Test
+
+backend_string = ARGS[1]
+
+if backend_string == "mpi"
+    import MPI
+    backend_type = MPIBackend
+elseif backend_string == "nccl"
+    import MPI, NCCL, CUDA
+    backend_type = NCCLBackend
+else
+    error("unsupported backend: $backend_string")
+end
+
+DistributedUtils.initialize(backend_type)
+backend = DistributedUtils.get_distributed_backend(backend_type)
+
+rank = DistributedUtils.local_rank(backend)
+total_workers = DistributedUtils.total_workers(backend)
+
+sendrecvbuf = fill(rank + 1, 4)
+
+DistributedUtils.reduce!(backend, sendrecvbuf, +)
+
+if rank == 0
+    @test all(sendrecvbuf .== sum(1:total_workers))
+end
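
For rank r the buffer holds r + 1, so the sum-reduction over N workers yields sum(1:N) = N(N+1)/2 in every slot, which is exactly what rank 0 asserts. A hypothetical launch sketch using MPI.jl's mpiexec wrapper; reduce_test.jl stands in for the new test file, whose path is not shown on this page:

using MPI

# Run the reduction test on 4 ranks, passing "mpi" as the backend argument.
# reduce_test.jl is a placeholder name; substitute the real test file path.
MPI.mpiexec() do exe
    run(`$exe -n 4 $(Base.julia_cmd()) --project reduce_test.jl mpi`)
end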

test/ext_distributed/runtests.jl

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 # Distributed Tests
 using MPI, Pkg, Test
+if get(ENV, "FLUX_TEST_DISTRIBUTED_NCCL", "false") == "true"
+    import CUDA
+end

 nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "")
 nprocs = nprocs_str == "" ? clamp(Sys.CPU_THREADS, 2, 4) : parse(Int, nprocs_str)
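
Putting the two environment variables together, a sketch of driving the suite, assuming runtests.jl can be included directly from the repository root (the project's actual CI entry point may differ):

# Opt into the NCCL test path (requires a CUDA-capable setup) and pin the
# MPI worker count; when unset, nprocs defaults to clamp(Sys.CPU_THREADS, 2, 4).
ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
ENV["JULIA_MPI_TEST_NPROCS"] = "2"
include("test/ext_distributed/runtests.jl")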
