Skip to content

Commit ef3f14c

Browse files
authored
Merge pull request #620 from JuliaParallel/jps/setindex-no-gpu
Fix DArray setindex! scope, and support threads=: in Dagger.scope()
2 parents 348b2a5 + 041d571 commit ef3f14c

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

src/array/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ function Base.setindex!(A::DArray{T,N}, value, idx::NTuple{N,Int}) where {T,N}
127127
# Set the value
128128
part = A.chunks[part_idx...]
129129
space = memory_space(part)
130-
scope = Dagger.scope(worker=root_worker_id(space))
130+
scope = UnionScope(map(ExactScope, collect(processors(space))))
131131
return fetch(Dagger.@spawn scope=scope setindex!(part, value, offset_idx...))
132132
end
133133
Base.setindex!(A::DArray, value, idx::Integer...) =

src/scopes.jl

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ ProcessScope() = ProcessScope(myid())
8282
struct ProcessorTypeTaint{T} <: AbstractScopeTaint end
8383

8484
"Scoped to any processor with a given supertype."
85-
ProcessorTypeScope(T) =
86-
TaintScope(AnyScope(),
85+
ProcessorTypeScope(T, inner_scope=AnyScope()) =
86+
TaintScope(inner_scope,
8787
Set{AbstractScopeTaint}([ProcessorTypeTaint{T}()]))
8888

8989
"Scoped to a specific processor."
@@ -318,47 +318,86 @@ function to_scope(sc::NamedTuple)
318318
return to_scope(Val(max_prec_key), sc)
319319
end
320320

321+
all_workers = false
321322
workers = if haskey(sc, :worker)
322323
Int[sc.worker]
323324
elseif haskey(sc, :workers)
324-
Int[sc.workers...]
325+
if sc.workers == Colon()
326+
all_workers = true
327+
nothing
328+
else
329+
Int[sc.workers...]
330+
end
325331
else
332+
all_workers = true
326333
nothing
327334
end
335+
336+
all_threads = false
337+
want_threads = false
328338
threads = if haskey(sc, :thread)
339+
want_threads = true
329340
Int[sc.thread]
330341
elseif haskey(sc, :threads)
331-
Int[sc.threads...]
342+
want_threads = true
343+
if sc.threads == Colon()
344+
all_threads = true
345+
nothing
346+
else
347+
Int[sc.threads...]
348+
end
332349
else
350+
all_threads = true
333351
nothing
334352
end
335353

336354
# Simple cases
355+
if workers !== nothing && isempty(workers)
356+
throw(ArgumentError("Cannot construct scope with workers=[]"))
357+
end
358+
if threads !== nothing && isempty(threads)
359+
throw(ArgumentError("Cannot construct scope with threads=[]"))
360+
end
337361
if workers !== nothing && threads !== nothing
338362
subscopes = AbstractScope[]
339363
for w in workers, t in threads
340364
push!(subscopes, ExactScope(ThreadProc(w, t)))
341365
end
342366
return simplified_union_scope(subscopes)
343-
elseif workers !== nothing && threads === nothing
344-
subscopes = AbstractScope[ProcessScope(w) for w in workers]
345-
return simplified_union_scope(subscopes)
367+
end
368+
if workers !== nothing && threads === nothing
369+
subscopes = simplified_union_scope(AbstractScope[ProcessScope(w) for w in workers])
370+
if all_threads
371+
return constrain(subscopes, ProcessorTypeScope(ThreadProc))
372+
else
373+
return subscopes
374+
end
375+
end
376+
if all_threads && want_threads
377+
if all_workers
378+
return ProcessorTypeScope(ThreadProc)
379+
end
380+
return UnionScope([ProcessorTypeScope(ThreadProc, ProcessScope(w)) for w in workers])
346381
end
347382

348383
# More complex cases that require querying the cluster
349384
# FIXME: Use per-field scope taint
350385
if workers === nothing
351-
workers = procs()
386+
workers = map(p->p.pid, filter(p->p isa OSProc, procs(Dagger.Sch.eager_context())))
352387
end
353388
subscopes = AbstractScope[]
354389
for w in workers
355-
if threads === nothing
356-
threads = map(c->c.tid,
357-
filter(c->c isa ThreadProc,
358-
collect(children(OSProc(w)))))
359-
end
360-
for t in threads
361-
push!(subscopes, ExactScope(ThreadProc(w, t)))
390+
if want_threads
391+
if threads === nothing
392+
threads = map(c->c.tid,
393+
filter(c->c isa ThreadProc,
394+
collect(children(OSProc(w)))))
395+
end
396+
for t in threads
397+
push!(subscopes, ExactScope(ThreadProc(w, t)))
398+
end
399+
else
400+
push!(subscopes, ProcessScope(w))
362401
end
363402
end
364403
return simplified_union_scope(subscopes)

test/scopes.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,16 @@
169169
@test Dagger.scope(()) == UnionScope()
170170

171171
@test Dagger.scope(worker=wid1) ==
172-
Dagger.scope(workers=[wid1]) ==
173-
ProcessScope(wid1)
172+
Dagger.scope(workers=[wid1])
174173
@test Dagger.scope(workers=[wid1,wid2]) == UnionScope([ProcessScope(wid1),
175174
ProcessScope(wid2)])
176-
@test Dagger.scope(workers=[]) == UnionScope()
175+
@test_throws ArgumentError Dagger.scope(workers=[])
177176

178177
@test Dagger.scope(thread=1) ==
179178
Dagger.scope(threads=[1]) ==
180179
UnionScope([ExactScope(Dagger.ThreadProc(w,1)) for w in procs()])
181180
@test Dagger.scope(threads=[1,2]) == UnionScope([ExactScope(Dagger.ThreadProc(w,t)) for t in [1,2] for w in procs()])
182-
@test Dagger.scope(threads=[]) == UnionScope()
181+
@test_throws ArgumentError Dagger.scope(threads=[])
183182

184183
@test Dagger.scope(worker=wid1,thread=1) ==
185184
Dagger.scope(thread=1,worker=wid1) ==
@@ -202,6 +201,12 @@
202201
@test_throws ArgumentError Dagger.scope((;blah=1))
203202
@test_throws ArgumentError Dagger.scope((thread=1, blah=1))
204203

204+
@test Dagger.scope(workers=:) == UnionScope([Dagger.ProcessScope(w) for w in procs()])
205+
@test Dagger.scope(threads=:) == Dagger.ProcessorTypeScope(Dagger.ThreadProc)
206+
@test Dagger.scope(worker=1, threads=:) == Dagger.ProcessorTypeScope(Dagger.ThreadProc, Dagger.ProcessScope(1))
207+
@test Dagger.scope(workers=:, thread=1) == UnionScope([ExactScope(Dagger.ThreadProc(w, 1)) for w in procs()]...)
208+
@test Dagger.scope(workers=:, threads=:) == Dagger.ProcessorTypeScope(Dagger.ThreadProc)
209+
205210
@testset "custom handler" begin
206211
@eval begin
207212
Dagger.scope_key_precedence(::Val{:gpu}) = 1
@@ -240,12 +245,6 @@
240245
end
241246

242247
@testset "compatible_processors" begin
243-
scope = Dagger.scope(workers=[])
244-
comp_procs = Dagger.compatible_processors(scope)
245-
@test Dagger.num_processors(scope) == length(comp_procs)
246-
@test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid1)))
247-
@test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid2)))
248-
249248
scope = Dagger.scope(worker=wid1)
250249
comp_procs = Dagger.compatible_processors(scope)
251250
@test Dagger.num_processors(scope) == length(comp_procs)

0 commit comments

Comments
 (0)