Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Convert/Convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ..Core
using ..Core:
C,
Utils,
Lockable,
@autopy,
getptr,
incref,
Expand Down
57 changes: 30 additions & 27 deletions src/Convert/pyconvert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ struct PyConvertRule
priority::PyConvertPriority
end

const PYCONVERT_RULES = Dict{String,Vector{PyConvertRule}}()
const PYCONVERT_EXTRATYPES = Py[]
const PYCONVERT_RULES = Lockable(Dict{String,Vector{PyConvertRule}}())
const PYCONVERT_EXTRATYPES = Lockable(Py[])

"""
pyconvert_add_rule(tname::String, T::Type, func::Function, priority::PyConvertPriority=PYCONVERT_PRIORITY_NORMAL)
Expand Down Expand Up @@ -69,11 +69,11 @@ function pyconvert_add_rule(
priority::PyConvertPriority = PYCONVERT_PRIORITY_NORMAL,
)
@nospecialize type func
push!(
get!(Vector{PyConvertRule}, PYCONVERT_RULES, pytypename),
Base.@lock PYCONVERT_RULES push!(
get!(Vector{PyConvertRule}, PYCONVERT_RULES[], pytypename),
PyConvertRule(type, func, priority),
)
empty!.(values(PYCONVERT_RULES_CACHE))
Base.@lock PYCONVERT_RULES_CACHE empty!.(values(PYCONVERT_RULES_CACHE[]))
return
end

Expand Down Expand Up @@ -163,7 +163,7 @@ function _pyconvert_get_rules(pytype::Py)
omro = collect(pytype.__mro__)
basetypes = Py[pytype]
basemros = Vector{Py}[omro]
for xtype in PYCONVERT_EXTRATYPES
Base.@lock PYCONVERT_EXTRATYPES for xtype in PYCONVERT_EXTRATYPES[]
# find the topmost supertype of
xbase = PyNULL
for base in omro
Expand Down Expand Up @@ -248,9 +248,9 @@ function _pyconvert_get_rules(pytype::Py)
mro = String[x for xs in xmro for x in xs]

# get corresponding rules
rules = PyConvertRule[
rules = Base.@lock PYCONVERT_RULES PyConvertRule[
rule for tname in mro for
rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES, tname)
rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES[], tname)
]

# order the rules by priority, then by original order
Expand All @@ -261,10 +261,10 @@ function _pyconvert_get_rules(pytype::Py)
return rules
end

const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}()
const PYCONVERT_PREFERRED_TYPE = Lockable(Dict{Py,Type}())

pyconvert_preferred_type(pytype::Py) =
get!(PYCONVERT_PREFERRED_TYPE, pytype) do
Base.@lock PYCONVERT_PREFERRED_TYPE get!(PYCONVERT_PREFERRED_TYPE[], pytype) do
if pyissubclass(pytype, pybuiltins.int)
Union{Int,BigInt}
else
Expand Down Expand Up @@ -307,10 +307,10 @@ end

pyconvert_fix(::Type{T}, func) where {T} = x -> func(T, x)

const PYCONVERT_RULES_CACHE = Dict{Type,Dict{C.PyPtr,Vector{Function}}}()
const PYCONVERT_RULES_CACHE = Lockable(Dict{Type,Dict{C.PyPtr,Vector{Function}}}())

@generated pyconvert_rules_cache(::Type{T}) where {T} =
get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T)
Base.@lock PYCONVERT_RULES_CACHE get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE[], T)

function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
if T isa Union
Expand Down Expand Up @@ -351,12 +351,13 @@ function pytryconvert(::Type{T}, x_) where {T}
# get rules from the cache
# TODO: we should hold weak references and clear the cache if types get deleted
tptr = C.Py_Type(x)
trules = pyconvert_rules_cache(T)
rules = get!(trules, tptr) do
t = pynew(incref(tptr))
ans = pyconvert_get_rules(T, t)::Vector{Function}
pydel!(t)
ans
rules = Base.@lock PYCONVERT_RULES_CACHE let trules = pyconvert_rules_cache(T)
get!(trules, tptr) do
t = pynew(incref(tptr))
ans = pyconvert_get_rules(T, t)::Vector{Function}
pydel!(t)
ans
end
end

# apply the rules
Expand Down Expand Up @@ -418,15 +419,17 @@ pyconvertarg(::Type{T}, x, name) where {T} = @autopy x @pyconvert T x_ begin
end

function init_pyconvert()
push!(PYCONVERT_EXTRATYPES, pyimport("io" => "IOBase"))
push!(
PYCONVERT_EXTRATYPES,
pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))...,
)
push!(
PYCONVERT_EXTRATYPES,
pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))...,
)
Base.@lock PYCONVERT_EXTRATYPES begin
push!(PYCONVERT_EXTRATYPES[], pyimport("io" => "IOBase"))
push!(
PYCONVERT_EXTRATYPES[],
pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))...,
)
push!(
PYCONVERT_EXTRATYPES[],
pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))...,
)
end

priority = PYCONVERT_PRIORITY_CANONICAL
pyconvert_add_rule("builtins:NoneType", Nothing, pyconvert_rule_none, priority)
Expand Down
2 changes: 1 addition & 1 deletion src/Core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const ROOT_DIR = dirname(dirname(@__DIR__))
using ..PythonCall: PythonCall # needed for docstring cross-refs
using ..C: C
using ..GC: GC
using ..Utils: Utils
using ..Utils: Utils, Lockable
using Base: @propagate_inbounds, @kwdef
using Dates:
Date,
Expand Down
11 changes: 6 additions & 5 deletions src/Core/Py.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x)

Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x)

const PYNULL_CACHE = Py[]
const PYNULL_CACHE = Lockable(Py[])

"""
pynew([ptr])
Expand All @@ -69,12 +69,13 @@ points at, i.e. the new `Py` object owns a reference.
Note that NULL Python objects are not safe in the sense that most API functions will probably
crash your Julia session if you pass a NULL argument.
"""
pynew() =
if isempty(PYNULL_CACHE)
pynew() = Base.@lock PYNULL_CACHE begin
if isempty(PYNULL_CACHE[])
Py(Val(:new), C.PyNULL)
else
pop!(PYNULL_CACHE)
pop!(PYNULL_CACHE[])
end
end

const PyNULL = pynew()

Expand Down Expand Up @@ -119,7 +120,7 @@ function pydel!(x::Py)
C.Py_DecRef(ptr)
setptr!(x, C.PyNULL)
end
push!(PYNULL_CACHE, x)
Base.@lock PYNULL_CACHE push!(PYNULL_CACHE[], x)
return
end

Expand Down
4 changes: 2 additions & 2 deletions src/Core/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ export pyfraction

### eval/exec

const MODULE_GLOBALS = Dict{Module,Py}()
const MODULE_GLOBALS = Lockable(Dict{Module,Py}())

function _pyeval_args(code, globals, locals)
if code isa AbstractString
Expand All @@ -1217,7 +1217,7 @@ function _pyeval_args(code, globals, locals)
throw(ArgumentError("code must be a string or Python code"))
end
if globals isa Module
globals_ = get!(pydict, MODULE_GLOBALS, globals)
globals_ = Base.@lock MODULE_GLOBALS get!(pydict, MODULE_GLOBALS[], globals)
elseif ispy(globals)
globals_ = globals
else
Expand Down
53 changes: 31 additions & 22 deletions src/JlWrap/C.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Cjl

using ...C: C
using ...Utils: Utils
using ...Utils: Utils, Lockable
using Base: @kwdef
using UnsafePointers: UnsafePtr
using Serialization: serialize, deserialize
Expand All @@ -16,9 +16,7 @@ const PyJuliaBase_Type = Ref(C.PyNULL)

# we store the actual julia values here
# the `value` field of `PyJuliaValueObject` indexes into here
const PYJLVALUES = []
# unused indices in PYJLVALUES
const PYJLFREEVALUES = Int[]
const PYJLVALUES = Lockable((; values=IdDict{Int,Any}(), free_slots=Int[], next_slot=Ref(1)))

function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
Expand All @@ -31,20 +29,24 @@ end
function _pyjl_dealloc(o::C.PyPtr)
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
if idx != 0
PYJLVALUES[idx] = nothing
push!(PYJLFREEVALUES, idx)
Base.@lock PYJLVALUES begin
delete!(PYJLVALUES[].values, idx)
push!(PYJLVALUES[].free_slots, idx)
end
end
UnsafePtr{PyJuliaValueObject}(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
ccall(UnsafePtr{C.PyTypeObject}(C.Py_Type(o)).free[!], Cvoid, (C.PyPtr,), o)
nothing
end

const PYJLMETHODS = Vector{Any}()
const PYJLMETHODS = Lockable([])

function PyJulia_MethodNum(f)
@nospecialize f
push!(PYJLMETHODS, f)
return length(PYJLMETHODS)
Base.@lock PYJLMETHODS begin
push!(PYJLMETHODS[], f)
return length(PYJLMETHODS[])
end
end

function _pyjl_isnull(o::C.PyPtr, ::C.PyPtr)
Expand All @@ -58,12 +60,12 @@ function _pyjl_callmethod(o::C.PyPtr, args::C.PyPtr)
@assert nargs > 0
num = C.PyLong_AsLongLong(C.PyTuple_GetItem(args, 0))
num == -1 && return C.PyNULL
f = PYJLMETHODS[num]
f = Base.@lock PYJLMETHODS PYJLMETHODS[][num]
# this form gets defined in jlwrap/base.jl
return _pyjl_callmethod(f, o, args, nargs)::C.PyPtr
end

const PYJLBUFCACHE = Dict{Ptr{Cvoid},Any}()
const PYJLBUFCACHE = Lockable(Dict{Ptr{Cvoid},Any}())

@kwdef struct PyBufferInfo{N}
# data
Expand Down Expand Up @@ -177,7 +179,7 @@ function _pyjl_get_buffer_impl(

# internal
cptr = Base.pointer_from_objref(c)
PYJLBUFCACHE[cptr] = c
Base.@lock PYJLBUFCACHE PYJLBUFCACHE[][cptr] = c
b.internal[] = cptr

# obj
Expand All @@ -195,7 +197,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
C.Py_DecRef(num_)
num == -1 && return Cint(-1)
try
f = PYJLMETHODS[num]
f = Base.@lock PYJLMETHODS PYJLMETHODS[][num]
x = PyJuliaValue_GetValue(o)
return _pyjl_get_buffer_impl(o, buf, flags, x, f)::Cint
catch exc
Expand All @@ -209,7 +211,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
end

function _pyjl_release_buffer(xo::C.PyPtr, buf::Ptr{C.Py_buffer})
delete!(PYJLBUFCACHE, UnsafePtr(buf).internal[!])
Base.@lock PYJLBUFCACHE delete!(PYJLBUFCACHE[], UnsafePtr(buf).internal[!])
nothing
end

Expand Down Expand Up @@ -339,22 +341,29 @@ end

PyJuliaValue_IsNull(o) = Base.GC.@preserve o UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[] == 0

PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]]
PyJuliaValue_GetValue(o) = Base.GC.@preserve o begin
idx = UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]
Base.@lock PYJLVALUES PYJLVALUES[].values[idx]
end

PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
o = C.asptr(_o)
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
if idx == 0
if isempty(PYJLFREEVALUES)
push!(PYJLVALUES, v)
idx = length(PYJLVALUES)
else
idx = pop!(PYJLFREEVALUES)
PYJLVALUES[idx] = v
Base.@lock PYJLVALUES begin
if isempty(PYJLVALUES[].free_slots)
idx = PYJLVALUES[].next_slot[]
PYJLVALUES[].next_slot[] += 1
else
idx = pop!(PYJLVALUES[].free_slots)
end
PYJLVALUES[].values[idx] = v
end
UnsafePtr{PyJuliaValueObject}(o).value[] = idx
else
PYJLVALUES[idx] = v
Base.@lock PYJLVALUES begin
PYJLVALUES[].values[idx] = v
end
end
nothing
end
Expand Down
1 change: 1 addition & 0 deletions src/JlWrap/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ function Cjl._pyjl_callmethod(f, self_::C.PyPtr, args_::C.PyPtr, nargs::C.Py_ssi
pybuiltins.NotImplementedError,
"__jl_callmethod not implemented for this many arguments",
)
return C.PyNULL
end
return getptr(incref(ans))
catch exc
Expand Down
26 changes: 26 additions & 0 deletions src/Utils/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,4 +308,30 @@ function Base.iterate(x::StaticString{UInt32,N}, i::Int = 1) where {N}
end
end

@static if !isdefined(Base, :Lockable)
"""
Compat for `Base.Lockable` (introduced in Julia 1.11)
"""
struct Lockable{T,L}
value::T
lock::L
end

Lockable(value) = Lockable(value, ReentrantLock())

function Base.lock(f, l::Lockable)
lock(l.lock) do
f(l.value)
end
end

Base.lock(l::Lockable) = lock(l.lock)
Base.trylock(l::Lockable) = trylock(l.lock)
Base.unlock(l::Lockable) = unlock(l.lock)
Base.islocked(l::Lockable) = islocked(l.lock)
Base.getindex(l::Lockable) = (@assert islocked(l); l.value)
else
const Lockable = Base.Lockable
end

end
Loading