Skip to content

Commit 49eea98

Browse files
committed
scopes: Add Dagger.scope helper
1 parent 5cd574c commit 49eea98

File tree

2 files changed

+212
-1
lines changed

2 files changed

+212
-1
lines changed

src/scopes.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,137 @@ constrain(x::ProcessScope, y::ExactScope) =
205205
x == y.parent ? y : InvalidScope(x, y)
206206
constrain(x::NodeScope, y::ExactScope) =
207207
x == y.parent.parent ? y : InvalidScope(x, y)
208+
209+
### Scopes helper
210+
211+
"""
212+
scope(scs...) -> AbstractScope
213+
scope(;scs...) -> AbstractScope
214+
215+
Constructs an `AbstractScope` from a set of scope specifiers. Each element in
216+
`scs` is a separate specifier; if `scs` is empty, an empty `UnionScope()` is
217+
produced; if `scs` has one element, then exactly one specifier is constructed;
218+
if `scs` has more than one element, a `UnionScope` of the scopes specified by
219+
`scs` is constructed. A variety of specifiers can be passed to construct a
220+
scope:
221+
- `:any` - Constructs an `AnyScope()`
222+
- `:default` - Constructs a `DefaultScope()`
223+
- `(scs...,)` - Constructs a `UnionScope` of scopes, each specified by `scs`
224+
- `thread=tid` or `threads=[tids...]` - Constructs an `ExactScope` or `UnionScope` containing all `Dagger.ThreadProc`s with thread ID `tid`/`tids` across all workers.
225+
- `worker=wid` or `workers=[wids...]` - Constructs a `ProcessScope` or `UnionScope` containing all `Dagger.ThreadProc`s with worker ID `wid`/`wids` across all threads.
226+
- `thread=tid`/`threads=tids` and `worker=wid`/`workers=wids` - Constructs an `ExactScope`, `ProcessScope`, or `UnionScope` containing all `Dagger.ThreadProc`s with worker ID `wid`/`wids` and threads `tid`/`tids`.
227+
228+
Aside from the worker and thread specifiers, it's possible to add custom
229+
specifiers for scoping to other kinds of processors (like GPUs) or providing
230+
different ways to specify a scope. Specifier selection is determined by a
231+
precedence ordering: by default, all specifiers have precedence `0`, which can
232+
be changed by defining `scope_key_precedence(::Val{spec}) = precedence` (where
233+
`spec` is the specifier as a `Symbol)`. The specifier with the highest
234+
precedence in a set of specifiers is used to determine the scope by calling
235+
`to_scope(::Val{spec}, sc::NamedTuple)` (where `sc` is the full set of
236+
specifiers), which should be overriden for each custom specifier, and which
237+
returns an `AbstractScope`. For example:
238+
239+
```julia
240+
# Setup a GPU specifier
241+
Dagger.scope_key_precedence(::Val{:gpu}) = 1
242+
Dagger.to_scope(::Val{:gpu}, sc::NamedTuple) = ExactScope(MyGPUDevice(sc.worker, sc.gpu))
243+
244+
# Generate an `ExactScope` for `MyGPUDevice` on worker 2, device 3
245+
Dagger.scope(gpu=3, worker=2)
246+
```
247+
"""
248+
scope(scs...) = simplified_union_scope(map(to_scope, scs))
249+
scope(; kwargs...) = to_scope((;kwargs...))
250+
251+
function simplified_union_scope(scopes)
252+
if length(scopes) == 1
253+
return only(scopes)
254+
else
255+
return UnionScope(scopes)
256+
end
257+
end
258+
to_scope(scope::AbstractScope) = scope
259+
function to_scope(sc::Symbol)
260+
if sc == :any
261+
return AnyScope()
262+
elseif sc == :default
263+
return DefaultScope()
264+
else
265+
throw(ArgumentError("Cannot construct scope from: $(repr(sc))"))
266+
end
267+
end
268+
function to_scope(sc::NamedTuple)
269+
if isempty(sc)
270+
return UnionScope()
271+
end
272+
273+
# FIXME: node and nodes
274+
known_keys = (:worker, :workers, :thread, :threads)
275+
unknown_keys = filter(key->!in(key, known_keys), keys(sc))
276+
if length(unknown_keys) > 0
277+
# Hand off construction if unknown members encountered
278+
precs = map(scope_key_precedence, map(Val, unknown_keys))
279+
max_prec = maximum(precs)
280+
if length(findall(prec->prec==max_prec, precs)) > 1
281+
throw(ArgumentError("Incompatible scope specifiers detected: $unknown_keys"))
282+
end
283+
max_prec_key = unknown_keys[argmax(precs)]
284+
return to_scope(Val(max_prec_key), sc)
285+
end
286+
287+
workers = if haskey(sc, :worker)
288+
Int[sc.worker]
289+
elseif haskey(sc, :workers)
290+
Int[sc.workers...]
291+
else
292+
nothing
293+
end
294+
threads = if haskey(sc, :thread)
295+
Int[sc.thread]
296+
elseif haskey(sc, :threads)
297+
Int[sc.threads...]
298+
else
299+
nothing
300+
end
301+
302+
# Simple cases
303+
if workers !== nothing && threads !== nothing
304+
subscopes = AbstractScope[]
305+
for w in workers, t in threads
306+
push!(subscopes, ExactScope(ThreadProc(w, t)))
307+
end
308+
return simplified_union_scope(subscopes)
309+
elseif workers !== nothing && threads === nothing
310+
subscopes = AbstractScope[ProcessScope(w) for w in workers]
311+
return simplified_union_scope(subscopes)
312+
end
313+
314+
# More complex cases that require querying the cluster
315+
# FIXME: Use per-field scope taint
316+
if workers === nothing
317+
workers = procs()
318+
end
319+
subscopes = AbstractScope[]
320+
for w in workers
321+
if threads === nothing
322+
threads = map(c->c.tid,
323+
filter(c->c isa ThreadProc,
324+
collect(children(OSProc(w)))))
325+
end
326+
for t in threads
327+
push!(subscopes, ExactScope(ThreadProc(w, t)))
328+
end
329+
end
330+
return simplified_union_scope(subscopes)
331+
end
332+
to_scope(scs::Tuple) =
333+
simplified_union_scope(map(to_scope, scs))
334+
to_scope(sc) =
335+
throw(ArgumentError("Cannot construct scope from: $sc"))
336+
337+
to_scope(::Val{key}, sc::NamedTuple) where key =
338+
throw(ArgumentError("Scope construction not implemented for key: $key"))
339+
340+
# Base case for all Dagger-owned keys
341+
scope_key_precedence(::Val) = 0

test/scopes.jl

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@
124124
@test fetch(Dagger.@spawn exact_scope_test(us_es1_multi_ch)) == es1.processor
125125

126126
# No inner scopes
127-
@test_throws ArgumentError UnionScope()
127+
@test UnionScope() isa UnionScope
128128

129129
# Same inner scope
130130
@test fetch(Dagger.@spawn exact_scope_test(us_es1_ch, us_es1_ch)) == es1.processor
@@ -161,5 +161,82 @@
161161
end
162162
# TODO: Test scope propagation
163163

164+
@testset "scope helper" begin
165+
@test Dagger.scope(:any) isa AnyScope
166+
@test Dagger.scope(:default) == DefaultScope()
167+
@test_throws ArgumentError Dagger.scope(:blah)
168+
@test Dagger.scope(()) == UnionScope()
169+
170+
@test Dagger.scope(worker=wid1) ==
171+
Dagger.scope(workers=[wid1]) ==
172+
ProcessScope(wid1)
173+
@test Dagger.scope(workers=[wid1,wid2]) == UnionScope([ProcessScope(wid1),
174+
ProcessScope(wid2)])
175+
@test Dagger.scope(workers=[]) == UnionScope()
176+
177+
@test Dagger.scope(thread=1) ==
178+
Dagger.scope(threads=[1]) ==
179+
UnionScope([ExactScope(Dagger.ThreadProc(w,1)) for w in procs()])
180+
@test Dagger.scope(threads=[1,2]) == UnionScope([ExactScope(Dagger.ThreadProc(w,t)) for t in [1,2] for w in procs()])
181+
@test Dagger.scope(threads=[]) == UnionScope()
182+
183+
@test Dagger.scope(worker=wid1,thread=1) ==
184+
Dagger.scope(thread=1,worker=wid1) ==
185+
Dagger.scope(workers=[wid1],thread=1) ==
186+
Dagger.scope(worker=wid1,threads=[1]) ==
187+
Dagger.scope(workers=[wid1],threads=[1]) ==
188+
ExactScope(Dagger.ThreadProc(wid1,1))
189+
190+
@test_throws ArgumentError Dagger.scope(blah=1)
191+
@test_throws ArgumentError Dagger.scope(thread=1, blah=1)
192+
193+
@test Dagger.scope(worker=1,thread=1) ==
194+
Dagger.scope((worker=1,thread=1)) ==
195+
Dagger.scope(((worker=1,thread=1),))
196+
@test Dagger.scope((worker=1,thread=1),(worker=wid1,thread=2)) ==
197+
Dagger.scope(((worker=1,thread=1),(worker=wid1,thread=2),)) ==
198+
Dagger.scope(((worker=1,thread=1),), ((worker=wid1,thread=2),)) ==
199+
UnionScope([ExactScope(Dagger.ThreadProc(1, 1)),
200+
ExactScope(Dagger.ThreadProc(wid1, 2))])
201+
@test_throws ArgumentError Dagger.scope((;blah=1))
202+
@test_throws ArgumentError Dagger.scope((thread=1, blah=1))
203+
204+
@testset "custom handler" begin
205+
@eval begin
206+
Dagger.scope_key_precedence(::Val{:gpu}) = 1
207+
Dagger.scope_key_precedence(::Val{:rocm}) = 2
208+
Dagger.scope_key_precedence(::Val{:cuda}) = 2
209+
210+
# Some fake scopes to use as sentinels
211+
Dagger.to_scope(::Val{:gpu}, sc::NamedTuple) = ExactScope(Dagger.ThreadProc(1, sc.device))
212+
Dagger.to_scope(::Val{:rocm}, sc::NamedTuple) = ExactScope(Dagger.ThreadProc($wid1, sc.gpu))
213+
Dagger.to_scope(::Val{:cuda}, sc::NamedTuple) = ExactScope(Dagger.ThreadProc($wid2, sc.gpu))
214+
end
215+
216+
@test Dagger.scope(gpu=1,device=2) ==
217+
Dagger.scope(device=2,gpu=1) ==
218+
Dagger.scope(gpu=1,device=2,blah=3) ==
219+
Dagger.scope((gpu=1,device=2,blah=3)) ==
220+
ExactScope(Dagger.ThreadProc(1, 2))
221+
@test Dagger.scope((gpu=1,device=2),(worker=1,thread=1)) ==
222+
Dagger.scope((worker=1,thread=1),(device=2,gpu=1)) ==
223+
UnionScope([ExactScope(Dagger.ThreadProc(1, 2)),
224+
ExactScope(Dagger.ThreadProc(1, 1))])
225+
@test Dagger.scope((gpu=1,device=2),(device=3,gpu=1)) ==
226+
UnionScope([ExactScope(Dagger.ThreadProc(1, 2)),
227+
ExactScope(Dagger.ThreadProc(1, 3))])
228+
229+
@test Dagger.scope(rocm=1,gpu=2) ==
230+
Dagger.scope(gpu=2,rocm=1) ==
231+
ExactScope(Dagger.ThreadProc(wid1, 2))
232+
@test Dagger.scope(cuda=1,gpu=2) ==
233+
Dagger.scope(gpu=2,cuda=1) ==
234+
ExactScope(Dagger.ThreadProc(wid2, 2))
235+
@test_throws ArgumentError Dagger.scope(rocm=1,cuda=1,gpu=2)
236+
@test_throws ArgumentError Dagger.scope(gpu=2,rocm=1,cuda=1)
237+
@test_throws ArgumentError Dagger.scope((rocm=1,cuda=1,gpu=2))
238+
end
239+
end
240+
164241
rmprocs([wid1, wid2])
165242
end

0 commit comments

Comments
 (0)