Skip to content

Commit 15bb7e6

Browse files
committed
feat: create ErrorLockable for erroring upon multithreaded access
1 parent d798710 commit 15bb7e6

File tree

7 files changed

+42
-34
lines changed

7 files changed

+42
-34
lines changed

src/Convert/Convert.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using ..Core
99
using ..Core:
1010
C,
1111
Utils,
12-
Lockable,
12+
ErrorLockable,
1313
@autopy,
1414
getptr,
1515
incref,

src/Convert/pyconvert.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ struct PyConvertRule
1212
priority::PyConvertPriority
1313
end
1414

15-
const PYCONVERT_RULES = Lockable(Dict{String,Vector{PyConvertRule}}())
16-
const PYCONVERT_EXTRATYPES = Lockable(Py[])
15+
const PYCONVERT_RULES = ErrorLockable(Dict{String,Vector{PyConvertRule}}())
16+
const PYCONVERT_EXTRATYPES = ErrorLockable(Py[])
1717

1818
"""
1919
pyconvert_add_rule(tname::String, T::Type, func::Function, priority::PyConvertPriority=PYCONVERT_PRIORITY_NORMAL)
@@ -261,7 +261,7 @@ function _pyconvert_get_rules(pytype::Py)
261261
return rules
262262
end
263263

264-
const PYCONVERT_PREFERRED_TYPE = Lockable(Dict{Py,Type}())
264+
const PYCONVERT_PREFERRED_TYPE = ErrorLockable(Dict{Py,Type}())
265265

266266
pyconvert_preferred_type(pytype::Py) =
267267
Base.@lock PYCONVERT_PREFERRED_TYPE get!(PYCONVERT_PREFERRED_TYPE[], pytype) do
@@ -307,7 +307,7 @@ end
307307

308308
pyconvert_fix(::Type{T}, func) where {T} = x -> func(T, x)
309309

310-
const PYCONVERT_RULES_CACHE = Lockable(IdDict{Any,Dict{C.PyPtr,Vector{Function}}}())
310+
const PYCONVERT_RULES_CACHE = ErrorLockable(IdDict{Any,Dict{C.PyPtr,Vector{Function}}}())
311311

312312
function pyconvert_rules_cache(::Type{T}) where {T}
313313
Base.@lock PYCONVERT_RULES_CACHE get!(

src/Core/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const ROOT_DIR = dirname(dirname(@__DIR__))
1111
using ..PythonCall: PythonCall # needed for docstring cross-refs
1212
using ..C: C
1313
using ..GC: GC
14-
using ..Utils: Utils, Lockable
14+
using ..Utils: Utils, ErrorLockable
1515
using Base: @propagate_inbounds, @kwdef
1616
using Dates:
1717
Date,

src/Core/Py.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x)
5656

5757
Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x)
5858

59-
const PYNULL_CACHE = Lockable(Py[])
59+
const PYNULL_CACHE = ErrorLockable(Py[])
6060

6161
"""
6262
pynew([ptr])

src/Core/builtins.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ export pyfraction
12061206

12071207
### eval/exec
12081208

1209-
const MODULE_GLOBALS = Lockable(Dict{Module,Py}())
1209+
const MODULE_GLOBALS = ErrorLockable(Dict{Module,Py}())
12101210

12111211
function _pyeval_args(code, globals, locals)
12121212
if code isa AbstractString

src/JlWrap/C.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module Cjl
22

33
using ...C: C
4-
using ...Utils: Utils, Lockable
4+
using ...Utils: Utils, ErrorLockable
55
using Base: @kwdef
66
using UnsafePointers: UnsafePtr
77
using Serialization: serialize, deserialize
@@ -16,7 +16,7 @@ const PyJuliaBase_Type = Ref(C.PyNULL)
1616

1717
# we store the actual julia values here
1818
# the `value` field of `PyJuliaValueObject` indexes into here
19-
const PYJLVALUES = Lockable((; values=IdDict{Int,Any}(), free_slots=Int[], next_slot=Ref(1)))
19+
const PYJLVALUES = ErrorLockable((; values=IdDict{Int,Any}(), free_slots=Int[], next_slot=Ref(1)))
2020

2121
function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
2222
o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
@@ -39,7 +39,7 @@ function _pyjl_dealloc(o::C.PyPtr)
3939
nothing
4040
end
4141

42-
const PYJLMETHODS = Lockable([])
42+
const PYJLMETHODS = ErrorLockable([])
4343

4444
function PyJulia_MethodNum(f)
4545
@nospecialize f
@@ -65,7 +65,7 @@ function _pyjl_callmethod(o::C.PyPtr, args::C.PyPtr)
6565
return _pyjl_callmethod(f, o, args, nargs)::C.PyPtr
6666
end
6767

68-
const PYJLBUFCACHE = Lockable(Dict{Ptr{Cvoid},Any}())
68+
const PYJLBUFCACHE = ErrorLockable(Dict{Ptr{Cvoid},Any}())
6969

7070
@kwdef struct PyBufferInfo{N}
7171
# data

src/Utils/Utils.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -308,30 +308,38 @@ function Base.iterate(x::StaticString{UInt32,N}, i::Int = 1) where {N}
308308
end
309309
end
310310

311-
@static if !isdefined(Base, :Lockable)
312-
"""
313-
Compat for `Base.Lockable` (introduced in Julia 1.11)
314-
"""
315-
struct Lockable{T,L}
316-
value::T
317-
lock::L
311+
struct RaceConditionError <: Exception
312+
msg::String
313+
end
314+
Base.showerror(io::IO, e::RaceConditionError) = print(io, e.msg)
315+
316+
317+
struct ErrorLock
318+
x::ReentrantLock
319+
ErrorLock() = new(ReentrantLock())
320+
end
321+
Base.trylock(l::ErrorLock) = trylock(l.x)
322+
Base.unlock(l::ErrorLock) = unlock(l.x)
323+
Base.islocked(l::ErrorLock) = islocked(l.x)
324+
325+
function Base.lock(l::ErrorLock)
326+
did_lock = trylock(l.x)
327+
if !did_lock
328+
throw(RaceConditionError("unsafe concurrent access to global mutable"))
318329
end
330+
return l
331+
end
319332

320-
Lockable(value) = Lockable(value, ReentrantLock())
321-
322-
# function Base.lock(f, l::Lockable)
323-
# lock(l.lock) do
324-
# f(l.value)
325-
# end
326-
# end
327-
328-
Base.lock(l::Lockable) = lock(l.lock)
329-
# Base.trylock(l::Lockable) = trylock(l.lock)
330-
Base.unlock(l::Lockable) = unlock(l.lock)
331-
Base.islocked(l::Lockable) = islocked(l.lock)
332-
Base.getindex(l::Lockable) = (@assert islocked(l); l.value)
333-
else
334-
const Lockable = Base.Lockable
333+
struct ErrorLockable{T,L}
334+
value::T
335+
lock::L
335336
end
336337

338+
ErrorLockable(value) = ErrorLockable(value, ErrorLock())
339+
340+
Base.lock(l::ErrorLockable) = lock(l.lock)
341+
Base.unlock(l::ErrorLockable) = unlock(l.lock)
342+
Base.islocked(l::ErrorLockable) = islocked(l.lock)
343+
Base.getindex(l::ErrorLockable) = (@assert islocked(l); l.value)
344+
337345
end

0 commit comments

Comments
 (0)