Skip to content

Commit b16afad

Browse files
committed
Cancellation and test fixes
1 parent 6e35c76 commit b16afad

File tree

3 files changed

+59
-25
lines changed

3 files changed

+59
-25
lines changed

src/cancellation.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,30 @@ function _cancel!(state, tid, force, graceful, halt_sch)
100100
@dagdebug tid :cancel "Cancelling ready task"
101101
ex = DTaskFailedException(task, task, InterruptException())
102102
Sch.store_result!(state, task, ex; error=true)
103-
Sch.set_failed!(state, task)
103+
Sch.finish_failed!(state, task, task)
104+
end
105+
if tid === nothing
106+
empty!(state.ready)
107+
else
108+
idx = findfirst(t->t.id == tid, state.ready)
109+
idx !== nothing && deleteat!(state.ready, idx)
104110
end
105-
empty!(state.ready)
106111

107112
# Cancel waiting tasks
108113
for task in keys(state.waiting)
109114
tid !== nothing && task.id != tid && continue
110115
@dagdebug tid :cancel "Cancelling waiting task"
111116
ex = DTaskFailedException(task, task, InterruptException())
112117
Sch.store_result!(state, task, ex; error=true)
113-
Sch.set_failed!(state, task)
118+
Sch.finish_failed!(state, task, task)
119+
end
120+
if tid === nothing
121+
empty!(state.waiting)
122+
else
123+
if haskey(state.waiting, tid)
124+
delete!(state.waiting, tid)
125+
end
114126
end
115-
empty!(state.waiting)
116127

117128
# Cancel running tasks at the processor level
118129
wids = unique(map(root_worker_id, values(state.running_on)))
@@ -155,6 +166,7 @@ function _cancel!(state, tid, force, graceful, halt_sch)
155166
return
156167
end
157168
end
169+
put!(state.chan, Sch.RescheduleSignal())
158170

159171
if halt_sch
160172
unlock(state.lock)

src/sch/util.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,15 +587,12 @@ end
587587
# TODO: Measure and model processor move overhead
588588
tx_cost = impute_sum(affinity(chunk)[2] for chunk in chunks_filt)
589589

590-
# Estimate total cost to move data and get task running after currently-scheduled tasks
591-
est_business = get(state.worker_time_pressure[get_parent(proc).pid], proc, 0)
592-
593590
# Add fixed cost for cross-worker task transfer (esimated at 1ms)
594591
# TODO: Actually estimate/benchmark this
595592
task_xfer_cost = gproc.pid != myid() ? 1_000_000 : 0 # 1ms
596593

597594
# Compute final cost
598-
costs[proc] = est_time_util + est_business + (tx_cost/tx_rate) + task_xfer_cost
595+
costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost
599596
end
600597
chunks_cleanup()
601598

test/scheduler.jl

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -400,13 +400,18 @@ end
400400
end
401401

402402
state = Dagger.Sch.EAGER_STATE[]
403-
tproc1 = Dagger.ThreadProc(1, 1)
404-
tproc2 = Dagger.ThreadProc(first(workers()), 1)
405-
procs = [tproc1, tproc2]
403+
tproc1_1 = Dagger.ThreadProc(1, 1)
404+
tproc2_1 = Dagger.ThreadProc(first(workers()), 1)
405+
procs = [tproc1_1, tproc2_1]
406406

407-
pres1 = state.worker_time_pressure[1][tproc1]
408-
pres2 = state.worker_time_pressure[first(workers())][tproc2]
407+
# Ensure that this worker has been used at least once
408+
fetch(Dagger.@spawn scope=Dagger.ExactScope(tproc2_1) 1+1)
409+
410+
#pres1_1 = state.worker_time_pressure[1][tproc1_1]
411+
#pres2_1 = state.worker_time_pressure[first(workers())][tproc2_1]
409412
tx_rate = state.transfer_rate[]
413+
tx_xfer_cost = 1e6
414+
sig_unknown_cost = 1e9
410415

411416
for (args, tx_size) in [
412417
([1, 2], 0),
@@ -433,18 +438,18 @@ end
433438
Dagger.Sch.collect_task_inputs!(state, t)
434439
sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t)
435440

436-
@test tproc1 in sorted_procs
437-
@test tproc2 in sorted_procs
441+
@test tproc1_1 in sorted_procs
442+
@test tproc2_1 in sorted_procs
438443
if length(cargs) > 0
439-
@test sorted_procs[1] == tproc1
440-
@test sorted_procs[2] == tproc2
444+
@test sorted_procs[1] == tproc1_1
445+
@test sorted_procs[2] == tproc2_1
441446
end
442447

443-
@test haskey(costs, tproc1)
444-
@test haskey(costs, tproc2)
445-
@test costs[tproc1] pres1 # All chunks are local
448+
@test haskey(costs, tproc1_1)
449+
@test haskey(costs, tproc2_1)
450+
@test costs[tproc1_1] #=pres1_1 +=# sig_unknown_cost # All chunks are local, and this signature is unknown
446451
if nprocs() > 1
447-
@test costs[tproc2] (tx_size/tx_rate) + pres2 # All chunks are remote
452+
@test costs[tproc2_1] (tx_size/tx_rate) + tx_xfer_cost + #=pres2_1 +=# sig_unknown_cost # All chunks are remote, and this signature is unknown
448453
end
449454
end
450455
end
@@ -564,12 +569,32 @@ end
564569
end
565570

566571
@testset "Cancellation" begin
572+
# Ready task cancellation
573+
start_time = time_ns()
567574
t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100)
575+
Dagger.cancel!(t)
576+
@test timedwait(()->istaskdone(t), 10) == :ok
577+
if istaskdone(t)
578+
@test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t)
579+
@test (time_ns() - start_time) * 1e-9 < 100
580+
end
581+
582+
# Running task cancellation
568583
start_time = time_ns()
584+
t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100)
585+
sleep(0.1) # Give the scheduler a chance to schedule the task
569586
Dagger.cancel!(t)
570-
@test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t)
587+
@test timedwait(()->istaskdone(t), 10) == :ok
588+
if istaskdone(t)
589+
@test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t)
590+
@test (time_ns() - start_time) * 1e-9 < 100
591+
end
592+
593+
# Normal task execution
594+
start_time = time_ns()
571595
t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield()
572-
fetch(t)
573-
finish_time = time_ns()
574-
@test (finish_time - start_time) * 1e-9 < 100
596+
@test timedwait(()->istaskdone(t), 10) == :ok
597+
if istaskdone(t)
598+
@test (time_ns() - start_time) * 1e-9 < 100
599+
end
575600
end

0 commit comments

Comments
 (0)