|
| 1 | +struct TaskState |
| 2 | + task::Task |
| 3 | + sticky::Bool # original stickiness of the task |
| 4 | + state::C.PyGILState_STATE |
| 5 | +end |
| 6 | + |
| 7 | +struct TaskStack |
| 8 | + stack::Vector{TaskState} |
| 9 | + count::IdDict{Task,Int} |
| 10 | + condvar::Threads.Condition |
| 11 | + function TaskStack() |
| 12 | + return new(TaskState[], IdDict{Task,Int}(), Threads.Condition()) |
| 13 | + end |
| 14 | +end |
| 15 | +function Base.last(task_stack::TaskStack)::Task |
| 16 | + return last(task_stack.stack).task |
| 17 | +end |
| 18 | +function Base.push!(task_stack::TaskStack, task::Task) |
| 19 | + original_sticky = task.sticky |
| 20 | + # The task should not migrate threads while acquiring or holding the GIL |
| 21 | + task.sticky = true |
| 22 | + gil_state = C.PyGILState_Ensure() |
| 23 | + |
| 24 | + # Save the stickiness and state for when we release |
| 25 | + state = TaskState(task, original_sticky, gil_state) |
| 26 | + push!(task_stack.stack, state) |
| 27 | + |
| 28 | + # Increment the count for this task |
| 29 | + count = get(task_stack.count, task, 0) |
| 30 | + task_stack.count[task] = count + 1 |
| 31 | + |
| 32 | + return task_stack |
| 33 | +end |
| 34 | +function Base.pop!(task_stack::TaskStack)::Task |
| 35 | + state = pop!(task_stack.stack) |
| 36 | + task = state.task |
| 37 | + sticky = state.sticky |
| 38 | + gil_state = state.state |
| 39 | + |
| 40 | + # Decrement the count for this task |
| 41 | + count = task_stack.count[task] - 1 |
| 42 | + if count == 0 |
| 43 | + # If 0, remove it from the key set |
| 44 | + pop!(task_stack.count, task) |
| 45 | + else |
| 46 | + task_stack[task] = count |
| 47 | + end |
| 48 | + |
| 49 | + C.PyGILState_Release(gil_state) |
| 50 | + |
| 51 | + # Restore sticky state after releasing the GIL |
| 52 | + task.sticky = sticky |
| 53 | + |
| 54 | + Base.lock(task_stack.condvar) do |
| 55 | + notify(task_stack.condvar) |
| 56 | + end |
| 57 | + |
| 58 | + return task |
| 59 | +end |
| 60 | +Base.isempty(task_stack::TaskStack) = isempty(task_stack.stack) |
| 61 | + |
| 62 | +if !isdefined(Base, :OncePerThread) |
| 63 | + |
| 64 | + # OncePerThread is implemented in full in Julia 1.12 |
| 65 | + # This implementation is meant for compatibility with Julia 1.10 and 1.11 |
| 66 | + # and only supports a static number of threads. Use Julia 1.12 for dynamic |
| 67 | + # thread usage. |
| 68 | + mutable struct OncePerThread{T,F} <: Function |
| 69 | + @atomic xs::Vector{T} # values |
| 70 | + @atomic ss::Vector{UInt8} # states: 0=initial, 1=hasrun, 2=error, 3==concurrent |
| 71 | + const initializer::F |
| 72 | + function OncePerThread{T,F}(initializer::F) where {T,F} |
| 73 | + nt = Threads.maxthreadid() |
| 74 | + return new{T,F}(Vector{T}(undef, nt), zeros(UInt8, nt), initializer) |
| 75 | + end |
| 76 | + end |
| 77 | + OncePerThread{T}(initializer::Type{U}) where {T, U} = OncePerThread{T,Type{U}}(initializer) |
| 78 | + (once::OncePerThread{T,F})() where {T,F} = once[Threads.threadid()] |
| 79 | + function Base.getindex(once::OncePerThread, tid::Integer) |
| 80 | + tid = Threads.threadid() |
| 81 | + ss = @atomic :acquire once.ss |
| 82 | + xs = @atomic :monotonic once.xs |
| 83 | + if checkbounds(Bool, xs, tid) |
| 84 | + if ss[tid] == 0 |
| 85 | + xs[tid] = once.initializer() |
| 86 | + ss[tid] = 1 |
| 87 | + end |
| 88 | + return xs[tid] |
| 89 | + else |
| 90 | + throw(ErrorException("Thread id $tid is out of bounds as initially allocated. Use Julia 1.12 for dynamic thread usage.")) |
| 91 | + end |
| 92 | + end |
| 93 | + |
| 94 | +end |
| 95 | + |
| 96 | +struct GlobalInterpreterLock <: Base.AbstractLock |
| 97 | + lock_owners::OncePerThread{TaskStack} |
| 98 | + function GlobalInterpreterLock() |
| 99 | + return new(OncePerThread{TaskStack}(TaskStack)) |
| 100 | + end |
| 101 | +end |
| 102 | +function Base.lock(gil::GlobalInterpreterLock) |
| 103 | + push!(gil.lock_owners(), current_task()) |
| 104 | + return nothing |
| 105 | +end |
| 106 | +function Base.unlock(gil::GlobalInterpreterLock) |
| 107 | + lock_owner::TaskStack = gil.lock_owners() |
| 108 | + while last(lock_owner) != current_task() |
| 109 | + wait(lock_owner.condvar) |
| 110 | + end |
| 111 | + task = pop!(lock_owner) |
| 112 | + @assert task == current_task() |
| 113 | + return nothing |
| 114 | +end |
| 115 | +function Base.islocked(gil::GlobalInterpreterLock) |
| 116 | + # TODO: handle Julia 1.10 and 1.11 case when have not allocated up to maxthreadid |
| 117 | + return any(!isempty(gil.lock_owners[thread_index]) for thread_index in 1:Threads.maxthreadid()) |
| 118 | +end |
| 119 | + |
| 120 | +const _GIL = GlobalInterpreterLock() |
0 commit comments