Skip to content

Commit 9446bca

Browse files
committed
Sch: Fix incorrect task signature calculation
1 parent 0350acd commit 9446bca

File tree

2 files changed

+95
-54
lines changed

2 files changed

+95
-54
lines changed

src/sch/util.jl

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -285,13 +285,34 @@ function report_catch_error(err, desc=nothing)
285285
end
286286

287287
chunktype(x) = typeof(x)
288-
signature(state, task::Thunk) = signature(state, task.f, task.inputs)
289-
function signature(state, f, inputs::Vector)
290-
sig = Any[chunktype(f)]
291-
for (pos, input) in collect_task_inputs(state, inputs)
292-
# N.B. Skips kwargs
288+
signature(state, task::Thunk) =
289+
signature(task.f, collect_task_inputs(state, task.inputs))
290+
function signature(f, args)
291+
sig = DataType[chunktype(f)]
292+
sig_kwarg_names = Symbol[]
293+
sig_kwarg_types = []
294+
for (pos, arg) in args
295+
if arg isa Dagger.DTask
296+
# Only occurs via manual usage of signature
297+
arg = fetch(arg; raw=true)
298+
end
299+
T = chunktype(arg)
293300
if pos === nothing
294-
push!(sig, chunktype(input))
301+
push!(sig, T)
302+
else
303+
push!(sig_kwarg_names, pos)
304+
push!(sig_kwarg_types, T)
305+
end
306+
end
307+
if !isempty(sig_kwarg_names)
308+
NT = NamedTuple{(sig_kwarg_names...,), Base.to_tuple_type(sig_kwarg_types)}
309+
pushfirst!(sig, NT)
310+
@static if isdefined(Core, :kwcall)
311+
pushfirst!(sig, typeof(Core.kwcall))
312+
else
313+
f_instance = chunktype(f).instance
314+
kw_f = Core.kwfunc(f_instance)
315+
pushfirst!(sig, typeof(kw_f))
295316
end
296317
end
297318
return sig
@@ -423,12 +444,12 @@ end
423444
collect_task_inputs(state, task::Thunk) =
424445
collect_task_inputs(state, task.inputs)
425446
function collect_task_inputs(state, inputs)
426-
inputs = Pair{Union{Symbol,Nothing},Any}[]
447+
new_inputs = Pair{Union{Symbol,Nothing},Any}[]
427448
for (pos, input) in inputs
428449
input = unwrap_weak_checked(input)
429-
push!(inputs, pos => (istask(input) ? state.cache[input] : input))
450+
push!(new_inputs, pos => (istask(input) ? state.cache[input] : input))
430451
end
431-
return inputs
452+
return new_inputs
432453
end
433454

434455
"""

test/scheduler.jl

Lines changed: 65 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -350,59 +350,79 @@ end
350350
end
351351

352352
@testset "Scheduler algorithms" begin
353-
# New function to hide from scheduler's function cost cache
354-
mynothing(args...) = nothing
353+
@testset "Signature Calculation" begin
354+
@test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) isa Vector{DataType}
355+
@test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) == [typeof(+), Int, Int]
356+
if isdefined(Core, :kwcall)
357+
@test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(Core.kwcall), @NamedTuple{a::Int64}, typeof(+), Int]
358+
else
359+
kw_f = Core.kwfunc(+)
360+
@test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(kw_f), @NamedTuple{a::Int64}, typeof(+), Int]
361+
end
362+
@test Dagger.Sch.signature(+, []) == [typeof(+)]
363+
@test Dagger.Sch.signature(+, [nothing=>1]) == [typeof(+), Int]
355364

356-
# New non-singleton struct to hide from `approx_size`
357-
struct MyStruct
358-
x::Int
365+
c = Dagger.tochunk(1.0)
366+
@test Dagger.Sch.signature(*, [nothing=>c, nothing=>3]) == [typeof(*), Float64, Int]
367+
t = Dagger.@spawn 1+2
368+
@test Dagger.Sch.signature(/, [nothing=>t, nothing=>c, nothing=>3]) == [typeof(/), Int, Float64, Int]
359369
end
360370

361-
state = Dagger.Sch.EAGER_STATE[]
362-
tproc1 = Dagger.ThreadProc(1, 1)
363-
tproc2 = Dagger.ThreadProc(first(workers()), 1)
364-
procs = [tproc1, tproc2]
365-
366-
pres1 = state.worker_time_pressure[1][tproc1]
367-
pres2 = state.worker_time_pressure[first(workers())][tproc2]
368-
tx_rate = state.transfer_rate[]
369-
370-
for (args, tx_size) in [
371-
([1, 2], 0),
372-
([Dagger.tochunk(1), 2], sizeof(Int)),
373-
([1, Dagger.tochunk(2)], sizeof(Int)),
374-
([Dagger.tochunk(1), Dagger.tochunk(2)], 2*sizeof(Int)),
375-
# TODO: Why does this work? Seems slow
376-
([Dagger.tochunk(MyStruct(1))], sizeof(MyStruct)),
377-
([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)),
378-
]
379-
for arg in args
380-
if arg isa Chunk
381-
aff = Dagger.affinity(arg)
382-
@test aff[1] == OSProc(1)
383-
@test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle))
384-
end
371+
@testset "Cost Estimation" begin
372+
# New function to hide from scheduler's function cost cache
373+
mynothing(args...) = nothing
374+
375+
# New non-singleton struct to hide from `approx_size`
376+
struct MyStruct
377+
x::Int
385378
end
386379

387-
cargs = map(arg->MemPool.poolget(arg.handle), filter(arg->isa(arg, Chunk), args))
388-
est_tx_size = Dagger.Sch.impute_sum(map(MemPool.approx_size, cargs))
389-
@test est_tx_size == tx_size
380+
state = Dagger.Sch.EAGER_STATE[]
381+
tproc1 = Dagger.ThreadProc(1, 1)
382+
tproc2 = Dagger.ThreadProc(first(workers()), 1)
383+
procs = [tproc1, tproc2]
384+
385+
pres1 = state.worker_time_pressure[1][tproc1]
386+
pres2 = state.worker_time_pressure[first(workers())][tproc2]
387+
tx_rate = state.transfer_rate[]
388+
389+
for (args, tx_size) in [
390+
([1, 2], 0),
391+
([Dagger.tochunk(1), 2], sizeof(Int)),
392+
([1, Dagger.tochunk(2)], sizeof(Int)),
393+
([Dagger.tochunk(1), Dagger.tochunk(2)], 2*sizeof(Int)),
394+
# TODO: Why does this work? Seems slow
395+
([Dagger.tochunk(MyStruct(1))], sizeof(MyStruct)),
396+
([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)),
397+
]
398+
for arg in args
399+
if arg isa Chunk
400+
aff = Dagger.affinity(arg)
401+
@test aff[1] == OSProc(1)
402+
@test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle))
403+
end
404+
end
390405

391-
t = delayed(mynothing)(args...)
392-
inputs = Dagger.Sch.collect_task_inputs(state, t)
393-
sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t, inputs)
406+
cargs = map(arg->MemPool.poolget(arg.handle), filter(arg->isa(arg, Chunk), args))
407+
est_tx_size = Dagger.Sch.impute_sum(map(MemPool.approx_size, cargs))
408+
@test est_tx_size == tx_size
394409

395-
@test tproc1 in sorted_procs
396-
@test tproc2 in sorted_procs
397-
if length(cargs) > 0
398-
@test sorted_procs[1] == tproc1
399-
@test sorted_procs[2] == tproc2
400-
end
410+
t = delayed(mynothing)(args...)
411+
inputs = Dagger.Sch.collect_task_inputs(state, t)
412+
sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t, inputs)
401413

402-
@test haskey(costs, tproc1)
403-
@test haskey(costs, tproc2)
404-
@test costs[tproc1] pres1 # All chunks are local
405-
@test costs[tproc2] (tx_size/tx_rate) + pres2 # All chunks are remote
414+
@test tproc1 in sorted_procs
415+
@test tproc2 in sorted_procs
416+
if length(cargs) > 0
417+
@test sorted_procs[1] == tproc1
418+
@test sorted_procs[2] == tproc2
419+
end
420+
421+
@test haskey(costs, tproc1)
422+
@test haskey(costs, tproc2)
423+
@test costs[tproc1] pres1 # All chunks are local
424+
@test costs[tproc2] (tx_size/tx_rate) + pres2 # All chunks are remote
425+
end
406426
end
407427
end
408428

0 commit comments

Comments
 (0)