Skip to content

Commit 475b34c

Browse files
committed
fix assign_GPU_workers
1 parent ed452da commit 475b34c

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/gpu.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ Assumes the SLURM variables `SLURM_STEP_GPUS` and `GPU_DEVICE_ORDINAL` are
132132
defined on the workers.
133133
"""
134134
function assign_GPU_workers()
135-
topo = @eval Main pmap(workers()) do i
135+
@everywhere @eval using CUDA, Distributed
136+
topo = @eval Main pmap(workers()) do _
136137
hostname = gethostname()
137138
virtgpus = parse.(Int,split(ENV["GPU_DEVICE_ORDINAL"],","))
138139
if "SLURM_STEP_GPUS" in keys(ENV)
@@ -143,7 +144,11 @@ function assign_GPU_workers()
143144
# will work if you requested at least a full node's worth of GPUs
144145
physgpus = virtgpus
145146
end
146-
(i=i, hostname=hostname, virtgpus=virtgpus, physgpus=physgpus)
147+
if Set(virtgpus)!=Set(deviceid.(devices()))
148+
@warn "Virtual GPUs not same as CUDA.devices(), using latter"
149+
virtgpus = deviceid.(devices())
150+
end
151+
(i=myid(), hostname=hostname, virtgpus=virtgpus, physgpus=physgpus)
147152
end
148153
claimed = Set()
149154
assignments = Dict(map(topo) do (i,hostname,virtgpus,physgpus)
@@ -154,6 +159,6 @@ function assign_GPU_workers()
154159
end
155160
end
156161
end)
157-
@everywhere @eval using CUDA, Distributed
158162
@everywhere workers() device!($assignments[myid()])
163+
return topo, claimed, assignments
159164
end

0 commit comments

Comments
 (0)