Skip to content

Commit 29ab58f

Browse files
committed
Metal: Replace unreachable control flow with exit block branches.
1 parent dd6a8eb commit 29ab58f

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2020
[compat]
2121
ExprTools = "0.1"
2222
InteractiveUtils = "1"
23-
LLVM = "8, 9"
23+
LLVM = "9"
2424
Libdl = "1"
2525
Logging = "1"
2626
PrecompileTools = "1"

src/metal.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
134134
entry::LLVM.Function)
135135
entry_fn = LLVM.name(entry)
136136

137+
# get rid of unreachable control flow (JuliaLang/Metal.jl#370)
138+
if job.config.target.macos < v"15"
139+
for f in functions(mod)
140+
replace_unreachable!(job, f)
141+
end
142+
end
143+
137144
# add kernel metadata
138145
if job.config.kernel
139146
entry = add_address_spaces!(job, mod, entry)
@@ -142,6 +149,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
142149

143150
add_module_metadata!(job, mod)
144151

152+
# JuliaLang/Metal.jl#113
145153
hide_noreturn!(mod)
146154
end
147155

@@ -1075,3 +1083,91 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod
10751083

10761084
return changed
10771085
end
1086+
1087+
# replace unreachable control flow with branches to the exit block
1088+
#
1089+
# before macOS 15, code generated by Julia 1.11 causes compilation failures in the back-end.
1090+
# the reduced example contains unreachable control flow executed divergently, so this is a
1091+
# similar issue as encountered with NVIDIA, albeit causing crashes instead of miscompiles.
1092+
#
1093+
# the proposed solution is to avoid (divergent) unreachable control flow, instead replacing
1094+
# it by branches to the exit block. since `unreachable` doesn't lower to anything that
1095+
# aborts the kernel anyway (can we fix this?), this transformation should be safe.
1096+
function replace_unreachable!(@nospecialize(job::CompilerJob), f::LLVM.Function)
1097+
# find unreachable instructions and exit blocks
1098+
unreachables = Instruction[]
1099+
exit_blocks = BasicBlock[]
1100+
for bb in blocks(f), inst in instructions(bb)
1101+
if isa(inst, LLVM.UnreachableInst)
1102+
push!(unreachables, inst)
1103+
end
1104+
if isa(inst, LLVM.RetInst)
1105+
push!(exit_blocks, bb)
1106+
end
1107+
end
1108+
isempty(unreachables) && return false
1109+
1110+
# if we don't have an exit block, we can't do much. we could insert a return, but that
1111+
# would probably keep the problematic control flow just as it is.
1112+
isempty(exit_blocks) && return false
1113+
1114+
@dispose builder=IRBuilder() begin
1115+
# if we have multiple exit blocks, take the last one, which is hopefully the least
1116+
# divergent (assuming divergent control flow is the root of the problem here).
1117+
exit_block = last(exit_blocks)
1118+
1119+
ret = terminator(exit_block)
1120+
if first(instructions(exit_block)) == ret
1121+
return_block = exit_block
1122+
else
1123+
# split the exit block right before the ret, so that we only have to care about
1124+
# the value that's returned, and not about any other SSA value in the block.
1125+
return_block = BasicBlock(f, "ret")
1126+
move_after(return_block, exit_block)
1127+
1128+
# emit a return
1129+
position!(builder, return_block)
1130+
if isempty(operands(ret))
1131+
ret!(builder)
1132+
else
1133+
# XXX: support aggregate returns?
1134+
val = only(operands(ret))
1135+
phi = phi!(builder, value_type(val))
1136+
push!(incoming(phi), (val, exit_block))
1137+
ret!(builder, phi)
1138+
end
1139+
1140+
# replace with a branch
1141+
position!(builder, ret)
1142+
br!(builder, return_block)
1143+
unsafe_delete!(exit_block, ret)
1144+
end
1145+
1146+
# replace the unreachable with a branch to the return block
1147+
for unreachable in unreachables
1148+
bb = LLVM.parent(unreachable)
1149+
1150+
# remove preceding traps to avoid reconstructing unreachable control flow
1151+
prev = previnst(unreachable)
1152+
if isa(prev, LLVM.CallInst) && name(called_operand(prev)) == "llvm.trap"
1153+
unsafe_delete!(bb, prev)
1154+
end
1155+
1156+
# replace the unreachable with a branch to the return block
1157+
position!(builder, unreachable)
1158+
br!(builder, return_block)
1159+
unsafe_delete!(bb, unreachable)
1160+
1161+
# patch up any phi nodes in the return block
1162+
for inst in instructions(return_block)
1163+
if isa(inst, LLVM.PHIInst)
1164+
undef = UndefValue(value_type(inst))
1165+
vals = incoming(inst)
1166+
push!(vals, (undef, bb))
1167+
end
1168+
end
1169+
end
1170+
end
1171+
1172+
return true
1173+
end

0 commit comments

Comments
 (0)