Skip to content

Commit 8f4cb60

Browse files
committed
When cloning, update constant uses with a materializer.
1 parent 3eabb5f commit 8f4cb60

File tree

1 file changed

+38
-33
lines changed

1 file changed

+38
-33
lines changed

src/irgen.jl

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -501,30 +501,42 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
501501
else
502502
changes = LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly
503503
end
504-
clone_into!(new_f, f; value_map, changes)
504+
505+
# use a value materializer for replacing uses of the function in constants
506+
# NOTE: we assume kernel functions can't be called. On-device kernel launches,
507+
# e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
508+
# and we update those constant expression arguments here.
509+
function materializer(val)
510+
opcodes = (LLVM.API.LLVMPtrToInt, LLVM.API.LLVMAddrSpaceCast, LLVM.API.LLVMBitCast)
511+
if val isa LLVM.ConstantExpr && opcode(val) in opcodes
512+
target = operands(val)[1]
513+
if target == f
514+
return if opcode(val) == LLVM.API.LLVMPtrToInt
515+
LLVM.const_ptrtoint(new_f, llvmtype(val))
516+
elseif opcode(val) == LLVM.API.LLVMAddrSpaceCast
517+
LLVM.const_addrspacecast(new_f, llvmtype(val))
518+
elseif opcode(val) == LLVM.API.LLVMBitCast
519+
LLVM.const_bitcast(new_f, llvmtype(val))
520+
end
521+
end
522+
end
523+
return val
524+
end
525+
526+
# we don't want module-level changes, because otherwise LLVM will clone metadata,
527+
# resulting in mismatching references between `!dbg` metadata and `dbg` instructions
528+
clone_into!(new_f, f; value_map, changes, materializer)
505529

506530
# fall through
507531
br!(builder, blocks(new_f)[2])
508532
end
509533

510-
# update uses of the kernel
511-
# NOTE: we assume kernel functions can't be called. on-device kernel launches,
512-
# e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
513-
# and we update those constant expressions arguments here.
534+
# drop unused constants that may still be referring to the old function
535+
# XXX: can we do this differently?
514536
for use in uses(f)
515537
val = user(use)
516-
if val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMPtrToInt
517-
target = operands(val)[1]
518-
if target == f
519-
new_val = LLVM.const_ptrtoint(new_f, llvmtype(val))
520-
replace_uses!(val, new_val)
521-
522-
# drop the old constant if it is unused
523-
# XXX: can we do this differently?
524-
if isempty(uses(val))
525-
LLVM.unsafe_destroy!(val)
526-
end
527-
end
538+
if val isa LLVM.ConstantExpr && isempty(uses(val))
539+
LLVM.unsafe_destroy!(val)
528540
end
529541
end
530542

@@ -608,10 +620,17 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
608620

609621
# use a value materializer for replacing uses of the function in constants
610622
function materializer(val)
611-
if val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMPtrToInt
623+
opcodes = (LLVM.API.LLVMPtrToInt, LLVM.API.LLVMAddrSpaceCast, LLVM.API.LLVMBitCast)
624+
if val isa LLVM.ConstantExpr && opcode(val) in opcodes
612625
src = operands(val)[1]
613626
if haskey(workmap, src)
614-
return LLVM.const_ptrtoint(workmap[src], llvmtype(val))
627+
return if opcode(val) == LLVM.API.LLVMPtrToInt
628+
LLVM.const_ptrtoint(workmap[src], llvmtype(val))
629+
elseif opcode(val) == LLVM.API.LLVMAddrSpaceCast
630+
LLVM.const_addrspacecast(workmap[src], llvmtype(val))
631+
elseif opcode(val) == LLVM.API.LLVMBitCast
632+
LLVM.const_bitcast(workmap[src], llvmtype(val))
633+
end
615634
end
616635
end
617636
return val
@@ -677,20 +696,6 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
677696
replace_uses!(val, new_val)
678697
@assert isempty(uses(val))
679698
unsafe_delete!(LLVM.parent(val), val)
680-
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
681-
# XXX: why isn't this caught by the value materializer above?
682-
target = operands(val)[1]
683-
@assert target == f
684-
new_val = LLVM.const_bitcast(new_f, llvmtype(val))
685-
rewrite_uses!(val, new_val)
686-
# we can't simply replace this constant expression, as it may be used
687-
# as a call, taking arguments (so we need to rewrite it to pass the state)
688-
689-
# drop the old constant if it is unused
690-
# XXX: can we do this differently?
691-
if isempty(uses(val))
692-
LLVM.unsafe_destroy!(val)
693-
end
694699
else
695700
error("Cannot rewrite unknown use of function: $val")
696701
end

0 commit comments

Comments
 (0)