Skip to content

Commit bf097de

Browse files
committed
Imaging-mode PLT codegen and relocatable TLS
1 parent 531dccd commit bf097de

File tree

4 files changed

+76
-4
lines changed

4 files changed

+76
-4
lines changed

src/interface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,10 @@ end
222222
# Has the runtime available and does not require special handling
223223
# Query whether `job` has the Julia runtime available and needs no special
# handling. This generic `CompilerJob` method answers `false`.
function uses_julia_runtime(@nospecialize(job::CompilerJob))
    return false
end
224224

225+
# Should we emit code in imaging mode (i.e. without embedding concrete runtime addresses)?
226+
# Job-level imaging-mode query: forwards the question to the method defined for
# the target stored in the job's configuration.
function imaging_mode(@nospecialize(job::CompilerJob))
    return imaging_mode(job.config.target)
end
227+
# Target-level imaging-mode query. This method on the abstract target type
# returns `false`, i.e. imaging mode is off unless a more specific method
# says otherwise.
function imaging_mode(@nospecialize(target::AbstractCompilerTarget))
    return false
end
228+
225229
# Is it legal to run vectorization passes on this target
226230
# Query whether vectorization passes may legally run for `job`'s target.
# This generic `CompilerJob` method answers `false`.
function can_vectorize(@nospecialize(job::CompilerJob))
    return false
end
227231

src/jlgen.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
717717

718718
# set-up the compiler interface
719719
debug_info_kind = llvm_debug_info(job)
720+
imaging = imaging_mode(job)
721+
720722
cgparams = (;
721723
track_allocations = false,
722724
code_coverage = false,
@@ -725,6 +727,9 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
725727
debug_info_kind = Cint(debug_info_kind),
726728
safepoint_on_entry = can_safepoint(job),
727729
gcstack_arg = false)
730+
if :use_jlplt in fieldnames(Base.CodegenParams)
731+
cgparams = (; cgparams..., use_jlplt = imaging)
732+
end
728733
if VERSION < v"1.12.0-DEV.1667"
729734
cgparams = (; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb), cgparams... )
730735
end
@@ -748,6 +753,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
748753
Metadata(ConstantInt(DEBUG_METADATA_VERSION()))
749754
end
750755

756+
imaging_flag = imaging ? 1 : 0
757+
751758
native_code = if VERSION >= v"1.12.0-DEV.1823"
752759
codeinfos = Any[]
753760
for (ci, src) in populated
@@ -760,11 +767,11 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
760767
elseif VERSION >= v"1.12.0-DEV.1667"
761768
ccall(:jl_create_native, Ptr{Cvoid},
762769
(Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint, Cint, Cint, Csize_t, Ptr{Cvoid}),
763-
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, #=imaging mode=# 0, #=external linkage=# 0, job.world, Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
770+
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, imaging_flag, #=external linkage=# 0, job.world, Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
764771
else
765772
ccall(:jl_create_native, Ptr{Cvoid},
766773
(Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint, Cint, Cint, Csize_t),
767-
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, #=imaging mode=# 0, #=external linkage=# 0, job.world)
774+
[job.source], ts_mod, Ref(params), CompilationPolicyExtern, imaging_flag, #=external linkage=# 0, job.world)
768775
end
769776
@assert native_code != C_NULL
770777

src/optim.jl

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module; opt_level=
1919
end
2020
run!(pb, mod, tm)
2121
end
22+
23+
# Make sure any lingering TLS getters are rewritten even if upstream LLVM passes
24+
# transformed them before the GPULowerPTLSPass had a chance to run.
25+
if occursin("StaticCompilerTarget", string(typeof(job.config.target))) &&
26+
uses_julia_runtime(job)
27+
lower_ptls!(mod)
28+
end
29+
2230
optimize_module!(job, mod)
2331
run!(DeadArgumentEliminationPass(), mod, tm)
2432
return
@@ -405,7 +413,31 @@ function lower_ptls!(mod::LLVM.Module)
405413

406414
intrinsic = "julia.get_pgcstack"
407415

408-
if haskey(functions(mod), intrinsic)
416+
# On host-style static targets we want a relocatable call into libjulia instead of
417+
# embedding the pointer to the TLS getter. Replace the intrinsic with a declared
418+
# libjulia call to avoid baking absolute addresses that crash in standalone binaries.
419+
if haskey(functions(mod), intrinsic) &&
420+
occursin("StaticCompilerTarget", string(typeof(job.config.target))) &&
421+
uses_julia_runtime(job)
422+
423+
pgc_fn = functions(mod)[intrinsic]
424+
jl_decl = if haskey(functions(mod), "jl_get_pgcstack")
425+
functions(mod)["jl_get_pgcstack"]
426+
else
427+
LLVM.Function(mod, "jl_get_pgcstack", LLVM.FunctionType(LLVM.PointerType()))
428+
end
429+
430+
for use in uses(pgc_fn)
431+
call = user(use)::LLVM.CallInst
432+
@dispose builder=IRBuilder() begin
433+
position!(builder, call)
434+
repl = call!(builder, function_type(jl_decl), jl_decl, LLVM.Value[])
435+
replace_uses!(call, repl)
436+
end
437+
erase!(call)
438+
changed = true
439+
end
440+
elseif haskey(functions(mod), intrinsic)
409441
ptls_getter = functions(mod)[intrinsic]
410442

411443
for use in uses(ptls_getter)
@@ -419,6 +451,34 @@ function lower_ptls!(mod::LLVM.Module)
419451
end
420452
end
421453

454+
# Newer Julia versions sometimes lower the TLS getter to an inttoptr call that bakes
455+
# the address of `jl_get_pgcstack_static` into the IR. Rewrite those calls as well to
456+
# make sure we always end up with a relocatable reference into libjulia when the
457+
# runtime is linked.
458+
if uses_julia_runtime(job) && occursin("StaticCompilerTarget", string(typeof(job.config.target)))
459+
jl_decl = if haskey(functions(mod), "jl_get_pgcstack")
460+
functions(mod)["jl_get_pgcstack"]
461+
else
462+
LLVM.Function(mod, "jl_get_pgcstack", LLVM.FunctionType(LLVM.PointerType()))
463+
end
464+
465+
for f in functions(mod), bb in blocks(f), inst in instructions(bb)
466+
inst isa LLVM.CallInst || continue
467+
468+
callee = LLVM.called_operand(inst)
469+
if callee isa LLVM.ConstantExpr && occursin("inttoptr", string(callee)) &&
470+
occursin("pgcstack", string(inst))
471+
@dispose builder=IRBuilder() begin
472+
position!(builder, inst)
473+
repl = call!(builder, function_type(jl_decl), jl_decl, LLVM.Value[])
474+
replace_uses!(inst, repl)
475+
end
476+
erase!(inst)
477+
changed = true
478+
end
479+
end
480+
end
481+
422482
return changed
423483
end
424484
# Build the module pass that wraps `lower_ptls!` in a `NewPMModulePass`
# named "GPULowerPTLS".
function GPULowerPTLSPass()
    return NewPMModulePass("GPULowerPTLS", lower_ptls!)
end

test/native.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ end
512512
Native.code_execution(mod.foobar, Tuple{Ptr{Int}})) do msg
513513
if VERSION >= v"1.11-"
514514
occursin("invalid LLVM IR", msg) &&
515-
occursin(GPUCompiler.LAZY_FUNCTION, msg) &&
515+
(occursin(GPUCompiler.LAZY_FUNCTION, msg) ||
516+
occursin(GPUCompiler.RUNTIME_FUNCTION, msg)) &&
516517
occursin("call to time", msg) &&
517518
occursin("[1] foobar", msg)
518519
else

0 commit comments

Comments
 (0)