Skip to content

Commit ac56504

Browse files
authored
Merge pull request #443 from JuliaParallel/jps/submission-remote-lock
Sch: Fix various locking and WeakChunk issues
2 parents 46cfb48 + 080849f commit ac56504

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

src/chunks.jl

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ affinity(c::Chunk) = affinity(c.handle)
6565
is_task_or_chunk(c::Chunk) = true
6666

6767
Base.:(==)(c1::Chunk, c2::Chunk) = c1.handle == c2.handle
68-
Base.hash(c::Chunk, x::UInt64) = hash(c.handle, x)
68+
Base.hash(c::Chunk, x::UInt64) = hash(c.handle, hash(Chunk, x))
6969

7070
collect_remote(chunk::Chunk) =
7171
move(chunk.processor, OSProc(), poolget(chunk.handle))
@@ -281,16 +281,22 @@ function savechunk(data, dir, f)
281281
end
282282

283283
struct WeakChunk
284+
wid::Int
285+
id::Int
284286
x::WeakRef
287+
function WeakChunk(c::Chunk)
288+
return new(c.handle.owner, c.handle.id, WeakRef(c))
289+
end
285290
end
286-
WeakChunk(c::Chunk) = WeakChunk(WeakRef(c))
287291
unwrap_weak(c::WeakChunk) = c.x.value
288292
function unwrap_weak_checked(c::WeakChunk)
289-
c = unwrap_weak(c)
290-
@assert c !== nothing
291-
return c
293+
cw = unwrap_weak(c)
294+
@assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))"
295+
return cw
292296
end
293297
is_task_or_chunk(c::WeakChunk) = true
298+
Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) =
299+
error("Cannot serialize a WeakChunk")
294300

295301
Base.@deprecate_binding AbstractPart Union{Chunk, Thunk}
296302
Base.@deprecate_binding Part Chunk

src/sch/util.jl

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,18 @@ function reschedule_syncdeps!(state, thunk, seen=Set{Thunk}())
107107
if haskey(state.cache, thunk) || (thunk in state.ready) || (thunk in state.running)
108108
continue
109109
end
110+
for (_,input) in thunk.inputs
111+
if input isa WeakChunk
112+
input = unwrap_weak_checked(input)
113+
end
114+
if input isa Chunk
115+
# N.B. Different Chunks with the same DRef handle will hash to the same slot,
116+
# so we just pick an equivalent Chunk as our upstream
117+
if !haskey(state.waiting_data, input)
118+
push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk)
119+
end
120+
end
121+
end
110122
w = get!(()->Set{Thunk}(), state.waiting, thunk)
111123
for input in thunk.syncdeps
112124
input = unwrap_weak_checked(input)

src/submission.jl

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
# Remote
22
function eager_submit_internal!(@nospecialize(payload))
3-
Sch.init_eager()
4-
53
ctx = Dagger.Sch.eager_context()
64
state = Dagger.Sch.EAGER_STATE[]
75
task = current_task()
@@ -12,8 +10,6 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
1210
@nospecialize payload
1311
ntasks, uid, future, ref, f, args, options, reschedule = payload
1412

15-
Sch.init_eager()
16-
1713
if uid isa Vector
1814
thunk_ids = Sch.ThunkID[]
1915
for i in 1:ntasks
@@ -31,6 +27,7 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
3127
timespan_start(ctx, :add_thunk, tid, 0)
3228

3329
# Lookup EagerThunk/ThunkID -> Thunk
30+
old_args = copy(args)
3431
args::Vector{Any}
3532
syncdeps = if haskey(options, :syncdeps)
3633
collect(options.syncdeps)
@@ -51,6 +48,13 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
5148
elseif arg isa Sch.ThunkID
5249
arg_tid = arg.id
5350
state.thunk_dict[arg_tid]
51+
elseif arg isa Chunk
52+
# N.B. Different Chunks with the same DRef handle will hash to the same slot,
53+
# so we just pick an equivalent Chunk as our upstream
54+
if haskey(state.waiting_data, arg)
55+
arg = only(filter(o->o isa Chunk && o.handle == arg.handle, keys(state.waiting_data)))::Chunk
56+
end
57+
WeakChunk(arg)
5458
else
5559
arg
5660
end
@@ -80,7 +84,7 @@ function eager_submit_internal!(ctx, state, task, tid, payload; uid_to_tid=Dict{
8084
options = merge(options, (;syncdeps))
8185
end
8286

83-
GC.@preserve args begin
87+
GC.@preserve old_args args begin
8488
# Create the `Thunk`
8589
thunk = Thunk(f, args...; options...)
8690

@@ -125,7 +129,14 @@ function eager_submit!(ntasks, uid, future, finalizer_ref, f, args, options)
125129
h = Dagger.sch_handle()
126130
return exec!(eager_submit_internal!, h, ntasks, uid, future, finalizer_ref, f, args, options, true)
127131
elseif myid() != 1
128-
return remotecall_fetch(eager_submit_internal!, 1, (ntasks, uid, future, finalizer_ref, f, args, options, true))
132+
return remotecall_fetch(1, (ntasks, uid, future, finalizer_ref, f, args, options, true)) do payload
133+
@nospecialize payload
134+
Sch.init_eager()
135+
state = Dagger.Sch.EAGER_STATE[]
136+
lock(state.lock) do
137+
eager_submit_internal!(payload)
138+
end
139+
end
129140
else
130141
Sch.init_eager()
131142
state = Dagger.Sch.EAGER_STATE[]
@@ -143,8 +154,6 @@ function eager_process_elem_submission_to_local(id_map, x)
143154
@assert !isa(x, Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`"
144155
if x isa Dagger.EagerThunk && haskey(id_map, x.uid)
145156
return Sch.ThunkID(id_map[x.uid], x.thunk_ref)
146-
elseif x isa Dagger.Chunk
147-
return WeakChunk(x)
148157
else
149158
return x
150159
end

0 commit comments

Comments
 (0)