Skip to content

Commit 80be019

Browse files
committed
Shard: Define iteration over chunks
1 parent 09f70b0 commit 80be019

File tree

2 files changed

+19
-27
lines changed

2 files changed

+19
-27
lines changed

src/chunks.jl

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ Keyword arguments:
152152
- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s.
153153
- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker.
154154
"""
155-
function shard(f; procs=nothing, workers=nothing, per_thread=false)
155+
function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false)
156156
if procs === nothing
157157
if workers !== nothing
158158
procs = [OSProc(w) for w in workers]
@@ -176,13 +176,15 @@ function shard(f; procs=nothing, workers=nothing, per_thread=false)
176176
end
177177
end
178178
isempty(procs) && throw(ArgumentError("Cannot create empty Shard"))
179-
scopes = [p isa OSProc ? ProcessScope(p) : ExactScope(p) for p in procs]
180-
thunks = [proc=>Dagger.@spawn single=Dagger.get_parent(proc).pid _shard_inner(f, proc, scope) for (proc,scope) in zip(procs,scopes)]
181-
shard = Shard(Dict{Processor,Chunk}(thunk[1]=>fetch(thunk[2])[] for thunk in thunks))
182-
scope = UnionScope(scopes)
183-
Dagger.tochunk(shard, OSProc(), scope)
179+
shard_dict = Dict{Processor,Chunk}()
180+
for proc in procs
181+
scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc)
182+
thunk = Dagger.@spawn scope=scope _shard_inner(f, proc, scope)
183+
shard_dict[proc] = fetch(thunk)[]
184+
end
185+
return Shard(shard_dict)
184186
end
185-
function _shard_inner(f, proc, scope)
187+
function _shard_inner(@nospecialize(f), proc, scope)
186188
Ref(Dagger.@mutable proc scope f())
187189
end
188190

@@ -191,7 +193,7 @@ macro shard(exs...)
191193
opts = esc.(exs[1:end-1])
192194
ex = exs[end]
193195
quote
194-
let f = ()->$(esc(ex))
196+
let f = @noinline ()->$(esc(ex))
195197
$shard(f; $(opts...))
196198
end
197199
end
@@ -202,28 +204,22 @@ function move(from_proc::Processor, to_proc::Processor, shard::Shard)
202204
# N.B. This behavior may bypass the piece's scope restriction
203205
proc = to_proc
204206
if haskey(shard.chunks, proc)
205-
return shard.chunks[proc]
207+
return move(from_proc, to_proc, shard.chunks[proc])
206208
end
207209
parent = Dagger.get_parent(proc)
208210
while parent != proc
209211
proc = parent
210212
parent = Dagger.get_parent(proc)
211213
if haskey(shard.chunks, proc)
212-
return shard.chunks[proc]
214+
return move(from_proc, to_proc, shard.chunks[proc])
213215
end
214216
end
215217

216218
throw(KeyError(to_proc))
217219
end
218-
function move(from_proc::Processor, to_proc::Processor, x::Chunk{Shard})
219-
piece = remotecall_fetch(x.handle.owner, x.handle, from_proc, to_proc) do ref, from_proc, to_proc
220-
shard = MemPool.poolget(ref)
221-
move(from_proc, to_proc, shard)
222-
end::Chunk
223-
move(from_proc, to_proc, piece)
224-
end
225-
Base.map(f, cs::Chunk{Shard}) = map(f, fetch(cs; raw=true))
226-
Base.map(f, s::Shard) = [Dagger.spawn(f, c) for c in values(s.chunks)]
220+
Base.iterate(s::Shard) = iterate(values(s.chunks))
221+
Base.iterate(s::Shard, state) = iterate(values(s.chunks), state)
222+
Base.length(s::Shard) = length(s.chunks)
227223

228224
### Core Stuff
229225

test/mutation.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ end
5050
end # @testset "@mutable"
5151

5252
@testset "Shard" begin
53-
cs = Dagger.@shard Threads.Atomic{Int}(0)
54-
s = fetch(cs; raw=true)
53+
s = Dagger.@shard Threads.Atomic{Int}(0)
5554
ctxprocs = Dagger.Sch.eager_context().procs
5655
for p in keys(s.chunks)
5756
@test p isa OSProc
@@ -67,8 +66,7 @@ end # @testset "@mutable"
6766

6867
@testset "procs kwarg" begin
6968
procs = [OSProc(first(workers()))]
70-
cs = Dagger.@shard procs=procs Threads.Atomic{Int}(0)
71-
s = fetch(cs; raw=true)
69+
s = Dagger.@shard procs=procs Threads.Atomic{Int}(0)
7270
@test length(keys(s.chunks)) == 1
7371
p = first(keys(s.chunks))
7472
@test p isa Dagger.OSProc
@@ -81,8 +79,7 @@ end # @testset "@mutable"
8179
end
8280

8381
@testset "workers kwarg" begin
84-
cs = Dagger.@shard workers=[first(workers())] Threads.Atomic{Int}(0)
85-
s = fetch(cs; raw=true)
82+
s = Dagger.@shard workers=[first(workers())] Threads.Atomic{Int}(0)
8683
@test length(keys(s.chunks)) == 1
8784
p = first(keys(s.chunks))
8885
@test p isa Dagger.OSProc
@@ -95,8 +92,7 @@ end # @testset "@mutable"
9592
end
9693

9794
@testset "per_thread kwarg" begin
98-
cs = Dagger.@shard per_thread=true Threads.Atomic{Int}(0)
99-
s = fetch(cs; raw=true)
95+
s = Dagger.@shard per_thread=true Threads.Atomic{Int}(0)
10096
for p in keys(s.chunks)
10197
@test p isa Dagger.ThreadProc
10298
gp = Dagger.get_parent(p)

0 commit comments

Comments
 (0)