@@ -132,7 +132,8 @@ Assumes the SLURM variables `SLURM_STEP_GPUS` and `GPU_DEVICE_ORDINAL` are
 defined on the workers.
 """
 function assign_GPU_workers()
-    topo = @eval Main pmap(workers()) do i
+    @everywhere @eval using CUDA, Distributed
+    topo = @eval Main pmap(workers()) do _
         hostname = gethostname()
         virtgpus = parse.(Int,split(ENV["GPU_DEVICE_ORDINAL"],","))
         if "SLURM_STEP_GPUS" in keys(ENV)
@@ -143,7 +144,11 @@ function assign_GPU_workers()
             # will work if you requested a full node's worth of GPUs at least
             physgpus = virtgpus
         end
-        (i=i, hostname=hostname, virtgpus=virtgpus, physgpus=physgpus)
+        if Set(virtgpus) != Set(deviceid.(devices()))
+            @warn "Virtual GPUs not same as CUDA.devices(), using latter"
+            virtgpus = deviceid.(devices())
+        end
+        (i=myid(), hostname=hostname, virtgpus=virtgpus, physgpus=physgpus)
     end
     claimed = Set()
     assignments = Dict(map(topo) do (i,hostname,virtgpus,physgpus)
@@ -154,6 +159,6 @@ function assign_GPU_workers()
             end
         end
     end)
-    @everywhere @eval using CUDA, Distributed
     @everywhere workers() device!($assignments[myid()])
+    return topo, claimed, assignments
 end
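
For reference, a minimal sketch of how `assign_GPU_workers` might be driven from a SLURM job. The worker-launch step is an assumption (shown with `ClusterManagers.addprocs_slurm`; a plain `addprocs` would also do on a single node), and the final `pmap` is just a sanity check echoing the same pattern the function itself uses:

```julia
using Distributed, ClusterManagers

# Launch one Julia worker per SLURM task (illustrative; the task count
# comes from whatever your sbatch/srun allocation provides).
addprocs_slurm(parse(Int, ENV["SLURM_NTASKS"]))

@everywhere using CUDA

# Pin each worker to its own GPU based on the SLURM-provided
# GPU_DEVICE_ORDINAL / SLURM_STEP_GPUS environment variables.
assign_GPU_workers()

# Sanity check: every worker should now report a distinct device.
pmap(workers()) do _
    (worker=myid(), host=gethostname(), gpu=deviceid(device()))
end
```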