Skip to content

Commit edef568

Browse files
authored
Merge pull request #299 from JuliaGPU/tb/constexpr_materializer
When cloning, update constant uses with a materializer.
2 parents 3eabb5f + efa4602 commit edef568

File tree

1 file changed

+71
-40
lines changed

1 file changed

+71
-40
lines changed

src/irgen.jl

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -501,30 +501,42 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
501501
else
502502
changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly
503503
end
504-
clone_into!(new_f, f; value_map, changes)
504+
505+
# use a value materializer for replacing uses of the function in constants
506+
# NOTE: we assume kernel functions can't be called. on-device kernel launches,
507+
# e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
508+
# and we update those constant expression arguments here.
509+
function materializer(val)
510+
opcodes = (LLVM.API.LLVMPtrToInt, LLVM.API.LLVMAddrSpaceCast, LLVM.API.LLVMBitCast)
511+
if val isa LLVM.ConstantExpr && opcode(val) in opcodes
512+
target = operands(val)[1]
513+
if target == f
514+
return if opcode(val) == LLVM.API.LLVMPtrToInt
515+
LLVM.const_ptrtoint(new_f, llvmtype(val))
516+
elseif opcode(val) == LLVM.API.LLVMAddrSpaceCast
517+
LLVM.const_addrspacecast(new_f, llvmtype(val))
518+
elseif opcode(val) == LLVM.API.LLVMBitCast
519+
LLVM.const_bitcast(new_f, llvmtype(val))
520+
end
521+
end
522+
end
523+
return val
524+
end
525+
526+
# we don't want module-level changes, because otherwise LLVM will clone metadata,
527+
# resulting in mismatching references between `!dbg` metadata and `dbg` instructions
528+
clone_into!(new_f, f; value_map, changes, materializer)
505529

506530
# fall through
507531
br!(builder, blocks(new_f)[2])
508532
end
509533

510-
# update uses of the kernel
511-
# NOTE: we assume kernel functions can't be called. on-device kernel launches,
512-
# e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
513-
# and we update those constant expressions arguments here.
534+
# drop unused constants that may be referring to the old functions
535+
# XXX: can we do this differently?
514536
for use in uses(f)
515537
val = user(use)
516-
if val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMPtrToInt
517-
target = operands(val)[1]
518-
if target == f
519-
new_val = LLVM.const_ptrtoint(new_f, llvmtype(val))
520-
replace_uses!(val, new_val)
521-
522-
# drop the old constant if it is unused
523-
# XXX: can we do this differently?
524-
if isempty(uses(val))
525-
LLVM.unsafe_destroy!(val)
526-
end
527-
end
538+
if val isa LLVM.ConstantExpr && isempty(uses(val))
539+
LLVM.unsafe_destroy!(val)
528540
end
529541
end
530542

@@ -576,8 +588,30 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
576588
# this is both for extern uses, and to make this transformation a two-step process.
577589
state_intr = kernel_state_intr(mod, T_state)
578590

579-
# add a state argument to every function
580-
worklist = filter(!isdeclaration, collect(functions(mod)))
591+
# determine which functions need a kernel state argument
592+
#
593+
# previously, we added the argument to every function and relied on unused arg elim to
594+
# clean up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
595+
# function pointers. such IR is hard to rewrite, so instead be more conservative.
596+
worklist = Set{LLVM.Function}([entry, state_intr])
597+
worklist_length = 0
598+
while worklist_length != length(worklist)
599+
# iteratively discover functions that use the intrinsic or any function calling it
600+
worklist_length = length(worklist)
601+
additions = LLVM.Function[]
602+
for f in worklist, use in uses(f)
603+
inst = user(use)::Instruction
604+
bb = LLVM.parent(inst)
605+
new_f = LLVM.parent(bb)
606+
in(new_f, worklist) || push!(additions, new_f)
607+
end
608+
for f in additions
609+
push!(worklist, f)
610+
end
611+
end
612+
delete!(worklist, state_intr)
613+
614+
# add a state argument
581615
workmap = Dict{LLVM.Function, LLVM.Function}()
582616
for f in worklist
583617
fn = LLVM.name(f)
@@ -608,10 +642,17 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
608642

609643
# use a value materializer for replacing uses of the function in constants
610644
function materializer(val)
611-
if val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMPtrToInt
645+
opcodes = (LLVM.API.LLVMPtrToInt, LLVM.API.LLVMAddrSpaceCast, LLVM.API.LLVMBitCast)
646+
if val isa LLVM.ConstantExpr && opcode(val) in opcodes
612647
src = operands(val)[1]
613648
if haskey(workmap, src)
614-
return LLVM.const_ptrtoint(workmap[src], llvmtype(val))
649+
return if opcode(val) == LLVM.API.LLVMPtrToInt
650+
LLVM.const_ptrtoint(workmap[src], llvmtype(val))
651+
elseif opcode(val) == LLVM.API.LLVMAddrSpaceCast
652+
LLVM.const_addrspacecast(workmap[src], llvmtype(val))
653+
elseif opcode(val) == LLVM.API.LLVMBitCast
654+
LLVM.const_bitcast(workmap[src], llvmtype(val))
655+
end
615656
end
616657
end
617658
return val
@@ -677,20 +718,6 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
677718
replace_uses!(val, new_val)
678719
@assert isempty(uses(val))
679720
unsafe_delete!(LLVM.parent(val), val)
680-
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
681-
# XXX: why isn't this caught by the value materializer above?
682-
target = operands(val)[1]
683-
@assert target == f
684-
new_val = LLVM.const_bitcast(new_f, llvmtype(val))
685-
rewrite_uses!(val, new_val)
686-
# we can't simply replace this constant expression, as it may be used
687-
# as a call, taking arguments (so we need to rewrite it to pass the state)
688-
689-
# drop the old constant if it is unused
690-
# XXX: can we do this differently?
691-
if isempty(uses(val))
692-
LLVM.unsafe_destroy!(val)
693-
end
694721
else
695722
error("Cannot rewrite unknown use of function: $val")
696723
end
@@ -721,14 +748,10 @@ function lower_kernel_state!(fun::LLVM.Function)
721748
return false
722749
end
723750

724-
# find the kernel state argument. this should be the first argument of the function.
725-
state_arg = parameters(fun)[1]
726-
T_state = convert(LLVMType, state; ctx)
727-
@assert llvmtype(state_arg) == T_state
728-
729751
# fixup all uses of the state getter to use the newly introduced function state argument
730752
if haskey(functions(mod), "julia.gpu.state_getter")
731753
state_intr = functions(mod)["julia.gpu.state_getter"]
754+
state_arg = nothing # only look-up when needed
732755

733756
Builder(ctx) do builder
734757
for use in uses(state_intr)
@@ -741,6 +764,14 @@ function lower_kernel_state!(fun::LLVM.Function)
741764
bb = LLVM.parent(inst)
742765
f = LLVM.parent(bb)
743766

767+
if state_arg === nothing
768+
# find the kernel state argument. this should be the first argument of
769+
# the function, but only when this function needs the state!
770+
state_arg = parameters(fun)[1]
771+
T_state = convert(LLVMType, state; ctx)
772+
@assert llvmtype(state_arg) == T_state
773+
end
774+
744775
replace_uses!(inst, state_arg)
745776

746777
@assert isempty(uses(inst))

0 commit comments

Comments
 (0)