diff --git a/src/orcv2.jl b/src/orcv2.jl index 7434fbaf..5edffe19 100644 --- a/src/orcv2.jl +++ b/src/orcv2.jl @@ -176,12 +176,13 @@ function JITDylib(es::ExecutionSession, name; bare=false) JITDylib(ref) end -@checked struct DefinitionGenerator +abstract type AbstractDefinitionGenerator end +@checked struct DefinitionGenerator <: AbstractDefinitionGenerator ref::API.LLVMOrcDefinitionGeneratorRef end Base.unsafe_convert(::Type{API.LLVMOrcDefinitionGeneratorRef}, dg::DefinitionGenerator) = dg.ref -function add!(jd::JITDylib, dg::DefinitionGenerator) +function add!(jd::JITDylib, dg::AbstractDefinitionGenerator) API.LLVMOrcJITDylibAddGenerator(jd, dg) end @@ -191,7 +192,82 @@ function CreateDynamicLibrarySearchGeneratorForProcess(prefix) DefinitionGenerator(ref[]) end -# LLVMOrcCreateCustomCAPIDefinitionGenerator(F, Ctx) +# We can do this async by copying content of `LookupState` and setting it to +# C_NULL and returning ErrorSuccess. We then would need to call `LookupContinue` +# but that function was only added in LLVM 15 to the API. +# +# Note LookupSet get's destroyed when we return here +function DefinitionGeneratorTryToGenerateFunction( + GeneratorObj::API.LLVMOrcDefinitionGeneratorRef, ctx::Ptr{Cvoid}, + LookupState::Ptr{API.LLVMOrcLookupStateRef}, Kind::API.LLVMOrcLookupKind, + JD::API.LLVMOrcJITDylibRef, JDLookupFlags::API.LLVMOrcJITDylibLookupFlags, + LookupSet::API.LLVMOrcCLookupSet, LookupSetSize::Csize_t)::API.LLVMErrorRef + + dg = Base.unsafe_pointer_to_objref(ctx)::CustomDefinitionGenerator + @assert dg.dg.ref === GeneratorObj + lookupSet = Base.unsafe_wrap(Array, LookupSet, LookupSetSize, own=false) + return dg.callback(Kind, JITDylib(JD), JDLookupFlags, lookupSet)::API.LLVMErrorRef +end + +mutable struct CustomDefinitionGenerator <: AbstractDefinitionGenerator + callback + dg::DefinitionGenerator + function CustomDefinitionGenerator(callback) + this = new(callback) + push!(CUSTOM_DG_ROOTS, this) # Globally root DefinitionGenerator + + ref = API.LLVMOrcCreateCustomCAPIDefinitionGenerator( + @cfunction(DefinitionGeneratorTryToGenerateFunction, + API.LLVMErrorRef, + (API.LLVMOrcDefinitionGeneratorRef, Ptr{Cvoid}, + Ptr{API.LLVMOrcLookupStateRef}, API.LLVMOrcLookupKind, + API.LLVMOrcJITDylibRef, API.LLVMOrcJITDylibLookupFlags, + API.LLVMOrcCLookupSet, Csize_t)), + Base.pointer_from_objref(this) + ) + + this.dg = DefinitionGenerator(ref) + return this + end +end +Base.cconvert(::Type{API.LLVMOrcDefinitionGeneratorRef}, dg::CustomDefinitionGenerator) = dg.dg + +# todo: Delete +const CUSTOM_DG_ROOTS = Base.IdSet{CustomDefinitionGenerator}() + +function DynamicLibDefinitionGenerator(path) + handle = Libdl.dlopen(path) + + function libdl_definitions(kind, JD, lookupFlags, lookupSet) + @assert kind == API.LLVMOrcLookupKindStatic + @assert lookupFlags == API.LLVMOrcJITDylibLookupFlagsMatchAllSymbols + + symbols = API.LLVMJITCSymbolMapPair[] + for lookup in lookupSet + if lookup.LookupFlags == API.LLVMOrcSymbolLookupFlagsRequiredSymbol + name = LLVMSymbol(lookup.Name) + ptr = Libdl.dlsym(handle, name; throw_error=false) + + if ptr !== C_NULL + LLVM.retain(name) + address = API.LLVMOrcJITTargetAddress( + reinterpret(UInt, ptr)) + flags = API.LLVMJITSymbolFlags( + API.LLVMJITSymbolGenericFlagsCallable, 0) + symbol = API.LLVMJITEvaluatedSymbol(address, flags) + push!(symbols, API.LLVMJITCSymbolMapPair(name, symbol)) + end + else + @warn "Unkown" lookup.LookupFlags + end + end + mu = absolute_symbols(symbols) + define(JD, mu) + # TODO: API.LLVMErrorSuccess is not a LLVMErrorRef + return reinterpret(API.LLVMErrorRef, API.LLVMErrorSuccess) + end + return LLVM.CustomDefinitionGenerator(libdl_definitions) +end function lookup_dylib(es::ExecutionSession, name) ref = API.LLVMOrcExecutionSessionGetJITDylibByName(es, name) diff --git a/test/Project.toml b/test/Project.toml index 6dd5450d..19351b9e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LLVMExtra_jll = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +NUMA_jll = "7f51dc2b-bb24-59f8-b771-bb1490e4195d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/test/orcv2.jl b/test/orcv2.jl index df0a51c4..5bff2ff2 100644 --- a/test/orcv2.jl +++ b/test/orcv2.jl @@ -302,4 +302,21 @@ end end end +import NUMA_jll + +@testset "CustomDefinitionGenerator" begin + @dispose lljit=LLJIT() begin + if NUMA_jll.is_available() + @test_throws ErrorException lookup(lljit, "numa_available") + + dg = LLVM.DynamicLibDefinitionGenerator(NUMA_jll.libnuma) + jd = JITDylib(lljit) + LLVM.add!(jd, dg) + + addr = lookup(lljit, "numa_available") + @test addr !== C_NULL + end + end +end + end