diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 69725eb7a..0020bf10a 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -127,7 +127,7 @@ function Base.setindex!(A::DArray{T,N}, value, idx::NTuple{N,Int}) where {T,N} # Set the value part = A.chunks[part_idx...] space = memory_space(part) - scope = Dagger.scope(worker=root_worker_id(space)) + scope = UnionScope(map(ExactScope, collect(processors(space)))) return fetch(Dagger.@spawn scope=scope setindex!(part, value, offset_idx...)) end Base.setindex!(A::DArray, value, idx::Integer...) = diff --git a/src/scopes.jl b/src/scopes.jl index 834993c9f..0545c573e 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -82,8 +82,8 @@ ProcessScope() = ProcessScope(myid()) struct ProcessorTypeTaint{T} <: AbstractScopeTaint end "Scoped to any processor with a given supertype." -ProcessorTypeScope(T) = - TaintScope(AnyScope(), +ProcessorTypeScope(T, inner_scope=AnyScope()) = + TaintScope(inner_scope, Set{AbstractScopeTaint}([ProcessorTypeTaint{T}()])) "Scoped to a specific processor." @@ -318,47 +318,86 @@ function to_scope(sc::NamedTuple) return to_scope(Val(max_prec_key), sc) end + all_workers = false workers = if haskey(sc, :worker) Int[sc.worker] elseif haskey(sc, :workers) - Int[sc.workers...] + if sc.workers == Colon() + all_workers = true + nothing + else + Int[sc.workers...] + end else + all_workers = true nothing end + + all_threads = false + want_threads = false threads = if haskey(sc, :thread) + want_threads = true Int[sc.thread] elseif haskey(sc, :threads) - Int[sc.threads...] + want_threads = true + if sc.threads == Colon() + all_threads = true + nothing + else + Int[sc.threads...] + end else + all_threads = true nothing end # Simple cases + if workers !== nothing && isempty(workers) + throw(ArgumentError("Cannot construct scope with workers=[]")) + end + if threads !== nothing && isempty(threads) + throw(ArgumentError("Cannot construct scope with threads=[]")) + end if workers !== nothing && threads !== nothing subscopes = AbstractScope[] for w in workers, t in threads push!(subscopes, ExactScope(ThreadProc(w, t))) end return simplified_union_scope(subscopes) - elseif workers !== nothing && threads === nothing - subscopes = AbstractScope[ProcessScope(w) for w in workers] - return simplified_union_scope(subscopes) + end + if workers !== nothing && threads === nothing + subscopes = simplified_union_scope(AbstractScope[ProcessScope(w) for w in workers]) + if all_threads + return constrain(subscopes, ProcessorTypeScope(ThreadProc)) + else + return subscopes + end + end + if all_threads && want_threads + if all_workers + return ProcessorTypeScope(ThreadProc) + end + return UnionScope([ProcessorTypeScope(ThreadProc, ProcessScope(w)) for w in workers]) end # More complex cases that require querying the cluster # FIXME: Use per-field scope taint if workers === nothing - workers = procs() + workers = map(p->p.pid, filter(p->p isa OSProc, procs(Dagger.Sch.eager_context()))) end subscopes = AbstractScope[] for w in workers - if threads === nothing - threads = map(c->c.tid, - filter(c->c isa ThreadProc, - collect(children(OSProc(w))))) - end - for t in threads - push!(subscopes, ExactScope(ThreadProc(w, t))) + if want_threads + if threads === nothing + threads = map(c->c.tid, + filter(c->c isa ThreadProc, + collect(children(OSProc(w))))) + end + for t in threads + push!(subscopes, ExactScope(ThreadProc(w, t))) + end + else + push!(subscopes, ProcessScope(w)) end end return simplified_union_scope(subscopes) diff --git a/test/scopes.jl b/test/scopes.jl index 065e5158f..ecade7ab4 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -169,17 +169,16 @@ @test Dagger.scope(()) == UnionScope() @test Dagger.scope(worker=wid1) == - Dagger.scope(workers=[wid1]) == - ProcessScope(wid1) + Dagger.scope(workers=[wid1]) @test Dagger.scope(workers=[wid1,wid2]) == UnionScope([ProcessScope(wid1), ProcessScope(wid2)]) - @test Dagger.scope(workers=[]) == UnionScope() + @test_throws ArgumentError Dagger.scope(workers=[]) @test Dagger.scope(thread=1) == Dagger.scope(threads=[1]) == UnionScope([ExactScope(Dagger.ThreadProc(w,1)) for w in procs()]) @test Dagger.scope(threads=[1,2]) == UnionScope([ExactScope(Dagger.ThreadProc(w,t)) for t in [1,2] for w in procs()]) - @test Dagger.scope(threads=[]) == UnionScope() + @test_throws ArgumentError Dagger.scope(threads=[]) @test Dagger.scope(worker=wid1,thread=1) == Dagger.scope(thread=1,worker=wid1) == @@ -202,6 +201,12 @@ @test_throws ArgumentError Dagger.scope((;blah=1)) @test_throws ArgumentError Dagger.scope((thread=1, blah=1)) + @test Dagger.scope(workers=:) == UnionScope([Dagger.ProcessScope(w) for w in procs()]) + @test Dagger.scope(threads=:) == Dagger.ProcessorTypeScope(Dagger.ThreadProc) + @test Dagger.scope(worker=1, threads=:) == Dagger.ProcessorTypeScope(Dagger.ThreadProc, Dagger.ProcessScope(1)) + @test Dagger.scope(workers=:, thread=1) == UnionScope([ExactScope(Dagger.ThreadProc(w, 1)) for w in procs()]...) + @test Dagger.scope(workers=:, threads=:) == Dagger.ProcessorTypeScope(Dagger.ThreadProc) + @testset "custom handler" begin @eval begin Dagger.scope_key_precedence(::Val{:gpu}) = 1 @@ -240,12 +245,6 @@ end @testset "compatible_processors" begin - scope = Dagger.scope(workers=[]) - comp_procs = Dagger.compatible_processors(scope) - @test Dagger.num_processors(scope) == length(comp_procs) - @test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid1))) - @test !any(proc->proc in comp_procs, Dagger.get_processors(OSProc(wid2))) - scope = Dagger.scope(worker=wid1) comp_procs = Dagger.compatible_processors(scope) @test Dagger.num_processors(scope) == length(comp_procs)