Skip to content

Commit f202cab

Browse files
committed
Refactor wakeup of threads and scheduler file logic
1 parent 34070fa commit f202cab

File tree

9 files changed

+182
-76
lines changed

9 files changed

+182
-76
lines changed

base/Base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ const liblapack_name = libblas_name
149149
# Note that `atomics.jl` here should be deprecated
150150
Core.eval(Threads, :(include("atomics.jl")))
151151
include("channels.jl")
152-
include("partr.jl")
152+
include("scheduler/scheduler.jl")
153153
include("task.jl")
154154
include("threads_overloads.jl")
155155
include("weakkeydict.jl")

base/partr.jl renamed to base/scheduler/partr.jl

Lines changed: 12 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,63 +19,6 @@ const heap_d = UInt32(8)
1919
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
2020
const heaps_lock = [SpinLock(), SpinLock()]
2121

22-
23-
"""
24-
cong(max::UInt32)
25-
26-
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
27-
"""
28-
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check
29-
30-
get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())
31-
32-
set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
33-
34-
"""
35-
rand_ptls(max::UInt32)
36-
37-
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
38-
state. Max must be greater than 0.
39-
"""
40-
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
41-
rngseed = get_ptls_rng()
42-
val, seed = rand_uniform_max_int32(max, rngseed)
43-
set_ptls_rng(seed)
44-
return val % UInt32
45-
end
46-
47-
# This implementation is based on OpenSSLs implementation of rand_uniform
48-
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
49-
# Comments are vendored from their implementation as well.
50-
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.
51-
52-
# Essentially it boils down to incrementally generating a fixed point
53-
# number on the interval [0, 1) and multiplying this number by the upper
54-
# range limit. Once it is certain what the fractional part contributes to
55-
# the integral part of the product, the algorithm has produced a definitive
56-
# result.
57-
"""
58-
rand_uniform_max_int32(max::UInt32, seed::UInt64)
59-
60-
Return a random UInt32 in the range `0:max-1` using the given seed.
61-
Max must be greater than 0.
62-
"""
63-
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
64-
if max == UInt32(1)
65-
return UInt32(0), seed
66-
end
67-
# We are generating a fixed point number on the interval [0, 1).
68-
# Multiplying this by the range gives us a number on [0, upper).
69-
# The high word of the multiplication result represents the integral part
70-
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
71-
seed = UInt64(69069) * seed + UInt64(362437)
72-
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
73-
i = prod >> 32 % UInt32 # integral part
74-
return i % UInt32, seed
75-
end
76-
77-
78-
7922
function multiq_sift_up(heap::taskheap, idx::Int32)
8023
while idx > Int32(1)
8124
parent = (idx - Int32(2)) ÷ heap_d + Int32(1)
@@ -147,10 +90,10 @@ function multiq_insert(task::Task, priority::UInt16)
14790

14891
task.priority = priority
14992

150-
rn = cong(heap_p)
93+
rn = Base.Scheduler.cong(heap_p)
15194
tpheaps = heaps[tp]
15295
while !trylock(tpheaps[rn].lock)
153-
rn = cong(heap_p)
96+
rn = Base.Scheduler.cong(heap_p)
15497
end
15598

15699
heap = tpheaps[rn]
@@ -190,8 +133,8 @@ function multiq_deletemin()
190133
if i == heap_p
191134
return nothing
192135
end
193-
rn1 = cong(heap_p)
194-
rn2 = cong(heap_p)
136+
rn1 = Base.Scheduler.cong(heap_p)
137+
rn2 = Base.Scheduler.cong(heap_p)
195138
prio1 = tpheaps[rn1].priority
196139
prio2 = tpheaps[rn2].priority
197140
if prio1 > prio2
@@ -235,6 +178,9 @@ function multiq_check_empty()
235178
if tp == 0 # Foreign thread
236179
return true
237180
end
181+
if !isempty(Base.workqueue_for(tid))
182+
return false
183+
end
238184
for i = UInt32(1):length(heaps[tp])
239185
if heaps[tp][i].ntasks != 0
240186
return false
@@ -243,4 +189,9 @@ function multiq_check_empty()
243189
return true
244190
end
245191

192+
193+
enqueue!(t::Task) = multiq_insert(t, t.priority)
194+
dequeue!() = multiq_deletemin()
195+
checktaskempty() = multiq_check_empty()
196+
246197
end

base/scheduler/scheduler.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
module Scheduler
4+
5+
"""
6+
cong(max::UInt32)
7+
8+
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
9+
"""
10+
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check
11+
12+
get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ())
13+
14+
set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
15+
16+
"""
17+
rand_ptls(max::UInt32)
18+
19+
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
20+
state. Max must be greater than 0.
21+
"""
22+
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
23+
rngseed = get_ptls_rng()
24+
val, seed = rand_uniform_max_int32(max, rngseed)
25+
set_ptls_rng(seed)
26+
return val % UInt32
27+
end
28+
29+
# This implementation is based on OpenSSLs implementation of rand_uniform
30+
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
31+
# Comments are vendored from their implementation as well.
32+
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.
33+
34+
# Essentially it boils down to incrementally generating a fixed point
35+
# number on the interval [0, 1) and multiplying this number by the upper
36+
# range limit. Once it is certain what the fractional part contributes to
37+
# the integral part of the product, the algorithm has produced a definitive
38+
# result.
39+
"""
40+
rand_uniform_max_int32(max::UInt32, seed::UInt64)
41+
42+
Return a random UInt32 in the range `0:max-1` using the given seed.
43+
Max must be greater than 0.
44+
"""
45+
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
46+
if max == UInt32(1)
47+
return UInt32(0), seed
48+
end
49+
# We are generating a fixed point number on the interval [0, 1).
50+
# Multiplying this by the range gives us a number on [0, upper).
51+
# The high word of the multiplication result represents the integral part
52+
# This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes
53+
seed = UInt64(69069) * seed + UInt64(362437)
54+
prod = (UInt64(max)) * (seed % UInt32) # 64 bit product
55+
i = prod >> 32 % UInt32 # integral part
56+
return i % UInt32, seed
57+
end
58+
59+
include("scheduler/partr.jl")
60+
61+
const ChosenScheduler = Partr
62+
63+
64+
65+
# Scheduler interface:
66+
# enqueue! which pushes a runnable Task into it
67+
# dequeue! which pops a runnable Task from it
68+
# checktaskempty which returns true if the scheduler has no available Tasks
69+
70+
enqueue!(t::Task) = ChosenScheduler.enqueue!(t)
71+
dequeue!() = ChosenScheduler.dequeue!()
72+
checktaskempty() = ChosenScheduler.checktaskempty()
73+
74+
end

base/task.jl

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,6 @@ end
937937

938938
function enq_work(t::Task)
939939
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")
940-
941940
# Sticky tasks go into their thread's work queue.
942941
if t.sticky
943942
tid = Threads.threadid(t)
@@ -968,19 +967,40 @@ function enq_work(t::Task)
968967
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
969968
push!(workqueue_for(tid), t)
970969
else
971-
# Otherwise, put the task in the multiqueue.
972-
Partr.multiq_insert(t, t.priority)
970+
# Otherwise, push the task to the scheduler
971+
Scheduler.enqueue!(t)
973972
tid = 0
974973
end
975974
end
976-
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
975+
976+
if (tid == 0)
977+
ccall(:jl_wake_any_thread, Cvoid, (Any,), current_task())
978+
else
979+
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
980+
end
977981
return t
978982
end
979983

984+
const ChildFirst = false
985+
980986
function schedule(t::Task)
981987
# [task] created -scheduled-> wait_time
982988
maybe_record_enqueued!(t)
983-
enq_work(t)
989+
if ChildFirst
990+
ct = current_task()
991+
if ct.sticky || t.sticky
992+
maybe_record_enqueued!(t)
993+
enq_work(t)
994+
else
995+
maybe_record_enqueued!(t)
996+
enq_work(ct)
997+
yieldto(t)
998+
end
999+
else
1000+
maybe_record_enqueued!(t)
1001+
enq_work(t)
1002+
end
1003+
return t
9841004
end
9851005

9861006
"""
@@ -1186,10 +1206,10 @@ function trypoptask(W::StickyWorkqueue)
11861206
end
11871207
return t
11881208
end
1189-
return Partr.multiq_deletemin()
1209+
return Scheduler.dequeue!()
11901210
end
11911211

1192-
checktaskempty = Partr.multiq_check_empty
1212+
checktaskempty = Scheduler.checktaskempty
11931213

11941214
function wait()
11951215
ct = current_task()

src/jl_exported_funcs.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@
441441
XX(jl_tagged_gensym) \
442442
XX(jl_take_buffer) \
443443
XX(jl_task_get_next) \
444+
XX(jl_wake_any_thread) \
444445
XX(jl_termios_size) \
445446
XX(jl_test_cpu_feature) \
446447
XX(jl_threadid) \

src/julia_threads.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ typedef struct _jl_tls_states_t {
214214
uint64_t uv_run_leave;
215215
uint64_t sleep_enter;
216216
uint64_t sleep_leave;
217+
uint64_t woken_up;
217218
)
218219

219220
// some hidden state (usually just because we don't have the type's size declaration)

0 commit comments

Comments
 (0)