Skip to content

Commit 0e70137

Browse files
authored
Fix parameter removal bug (#2572)
* Fix parameter removal bug * fix
1 parent a991cd4 commit 0e70137

File tree

4 files changed

+132
-37
lines changed

4 files changed

+132
-37
lines changed

src/compiler/optimize.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,6 @@ function jl_inst_simplify!(PM::LLVM.ModulePassManager)
446446
)
447447
end
448448

449-
function post_attr!(mod::LLVM.Module) end
450-
451449
cse!(pm) = LLVM.API.LLVMAddEarlyCSEPass(pm)
452450

453451
function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)

src/compiler/validation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
407407
newf = operands(newf)[1]
408408
end
409409
push!(function_attributes(newf), StringAttribute("enzyme_math", fname))
410+
push!(function_attributes(newf), StringAttribute("enzyme_preserve_primal", "*"))
410411
# TODO we can make this relocatable if desired by having restore lookups re-create this got initializer/etc
411412
# metadata(newf)["enzymejl_flib"] = flib
412413
# metadata(newf)["enzymejl_flib"] = flib

src/llvm/transforms.jl

Lines changed: 100 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,44 @@ function fix_decayaddr!(mod::LLVM.Module)
13111311
return nothing
13121312
end
13131313

1314-
function pre_attr!(mod::LLVM.Module)
1314+
function pre_attr!(mod::LLVM.Module, run_attr)
1315+
if run_attr
1316+
for fn in functions(mod)
1317+
if isempty(blocks(fn))
1318+
continue
1319+
end
1320+
attrs = collect(function_attributes(fn))
1321+
prevent = any(
1322+
kind(attr) == kind(StringAttribute("enzyme_preserve_primal")) for attr in attrs
1323+
)
1324+
if !prevent
1325+
continue
1326+
end
1327+
1328+
if linkage(fn) == LLVM.API.LLVMInternalLinkage
1329+
push!(LLVM.function_attributes(fn), StringAttribute("restorelinkage_internal"))
1330+
linkage!(fn, LLVM.API.LLVMExternalLinkage)
1331+
end
1332+
1333+
if linkage(fn) == LLVM.API.LLVMPrivateLinkage
1334+
push!(LLVM.function_attributes(fn), StringAttribute("restorelinkage_private"))
1335+
linkage!(fn, LLVM.API.LLVMExternalLinkage)
1336+
end
1337+
continue
1338+
1339+
if !has_fn_attr(fn, EnumAttribute("noinline"))
1340+
push!(LLVM.function_attributes(fn), EnumAttribute("noinline"))
1341+
push!(LLVM.function_attributes(fn), StringAttribute("remove_noinline"))
1342+
end
1343+
1344+
if !has_fn_attr(fn, EnumAttribute("optnone"))
1345+
push!(LLVM.function_attributes(fn), EnumAttribute("optnone"))
1346+
push!(LLVM.function_attributes(fn), StringAttribute("remove_optnone"))
1347+
end
1348+
end
1349+
end
13151350
return nothing
1316-
tofinalize = Tuple{LLVM.Function,Bool,Vector{Int64}}[]
1351+
13171352
for fn in collect(functions(mod))
13181353
if isempty(blocks(fn))
13191354
continue
@@ -1337,6 +1372,32 @@ function pre_attr!(mod::LLVM.Module)
13371372
end
13381373
end
13391374
end
1375+
end
1376+
1377+
"""
    post_attr!(mod::LLVM.Module, run_attr)

Undo the temporary per-function changes that are recorded through marker string
attributes on the functions of `mod`.

When `run_attr` is true, every function in `mod` is visited and, in this order:

- a `"restorelinkage_internal"` marker is removed and the linkage is set to internal;
- a `"restorelinkage_private"` marker is removed and the linkage is set to private;
- a `"remove_noinline"` marker is removed together with the `noinline` attribute;
- a `"remove_optnone"` marker is removed together with the `optnone` attribute.

When `run_attr` is false the module is left untouched. Always returns `nothing`.
"""
function post_attr!(mod::LLVM.Module, run_attr)
    run_attr || return nothing

    # Marker attribute => linkage to put back on the function.
    linkage_markers = (
        ("restorelinkage_internal", LLVM.API.LLVMInternalLinkage),
        ("restorelinkage_private", LLVM.API.LLVMPrivateLinkage),
    )
    # Marker attribute => enum attribute that has to be dropped again.
    enum_markers = (
        ("remove_noinline", "noinline"),
        ("remove_optnone", "optnone"),
    )

    for fn in functions(mod)
        for (marker, restored_linkage) in linkage_markers
            if has_fn_attr(fn, StringAttribute(marker))
                delete!(LLVM.function_attributes(fn), StringAttribute(marker))
                linkage!(fn, restored_linkage)
            end
        end

        for (marker, enum_name) in enum_markers
            if has_fn_attr(fn, StringAttribute(marker))
                delete!(LLVM.function_attributes(fn), EnumAttribute(enum_name))
                delete!(LLVM.function_attributes(fn), StringAttribute(marker))
            end
        end
    end

    return nothing
end
13421403

@@ -1781,37 +1842,40 @@ function propagate_returned!(mod::LLVM.Module)
17811842
LLVM.replace_uses!(arg, val)
17821843
end
17831844
end
1784-
# see if there are no users of the value (excluding recursive/return)
1785-
baduse = false
1786-
for u in LLVM.uses(arg)
1787-
u = LLVM.user(u)
1788-
if argn == i && LLVM.API.LLVMIsAReturnInst(u) != C_NULL
1789-
continue
1790-
end
1791-
if !isa(u, LLVM.CallInst)
1792-
baduse = true
1793-
break
1794-
end
1795-
if LLVM.called_operand(u) != fn
1796-
baduse = true
1797-
break
1798-
end
1799-
for (si, op) in enumerate(operands(u))
1800-
if si == i
1801-
continue
1802-
end
1803-
if op == arg
1804-
baduse = true
1805-
break
1806-
end
1807-
end
1808-
if baduse
1809-
break
1810-
end
1811-
end
1812-
if !baduse
1813-
push!(toremove, i - 1)
1814-
end
1845+
1846+
# see if there are no users of the value (excluding recursive/return)
1847+
if !prevent
1848+
baduse = false
1849+
for u in LLVM.uses(arg)
1850+
u = LLVM.user(u)
1851+
if argn == i && LLVM.API.LLVMIsAReturnInst(u) != C_NULL
1852+
continue
1853+
end
1854+
if !isa(u, LLVM.CallInst)
1855+
baduse = true
1856+
break
1857+
end
1858+
if LLVM.called_operand(u) != fn
1859+
baduse = true
1860+
break
1861+
end
1862+
for (si, op) in enumerate(operands(u))
1863+
if si == i
1864+
continue
1865+
end
1866+
if op == arg
1867+
baduse = true
1868+
break
1869+
end
1870+
end
1871+
if baduse
1872+
break
1873+
end
1874+
end
1875+
if !baduse
1876+
push!(toremove, i - 1)
1877+
end
1878+
end
18151879
end
18161880
illegalUse = !(
18171881
linkage(fn) == LLVM.API.LLVMInternalLinkage ||
@@ -2498,7 +2562,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
24982562
LLVM.run!(pm, mod)
24992563
end
25002564
propagate_returned!(mod)
2501-
pre_attr!(mod)
2565+
pre_attr!(mod, RunAttributor[])
25022566
if RunAttributor[]
25032567
if LLVM.version().major >= 13
25042568
ModulePassManager() do pm
@@ -2521,8 +2585,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
25212585
cse!(pm)
25222586
LLVM.run!(pm, mod)
25232587
end
2524-
post_attr!(mod)
2588+
post_attr!(mod, RunAttributor[])
25252589
propagate_returned!(mod)
2590+
25262591

25272592
for u in LLVM.uses(rfunc)
25282593
u = LLVM.user(u)

test/optimize.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,34 @@ end
205205
@testset "Indirect function call return type analysis" begin
206206
RetTypeMod.main()
207207
end
208+
209+
# Dot product of two complex vectors obtained by calling the CBLAS routine
# `cblas_zdotc_sub` (the conjugated dot product) directly through libblastrampoline.
# The routine writes its result into a `Ref{ComplexF64}`, whose content is returned.
# Marked `@noinline` so the call stays a separate function in the generated code.
@noinline function blasdot(x, y)
    len = length(x)
    # Keep `x` and `y` rooted while raw pointers into them are handed to BLAS.
    return GC.@preserve x y begin
        xptr, xinc = LinearAlgebra.BLAS.vec_pointer_stride(x)
        yptr, yinc = LinearAlgebra.BLAS.vec_pointer_stride(y)
        out = Ref{ComplexF64}()
        ccall(
            (LinearAlgebra.BLAS.@blasfunc(cblas_zdotc_sub), LinearAlgebra.BLAS.libblastrampoline),
            Cvoid,
            (LinearAlgebra.BLAS.BlasInt, Ptr{ComplexF64}, LinearAlgebra.BLAS.BlasInt,
             Ptr{ComplexF64}, LinearAlgebra.BLAS.BlasInt, Ptr{ComplexF64}),
            len, xptr, xinc, yptr, yinc, out
        )
        out[]
    end
end
222+
223+
# Thin entry point that forwards both arguments to `blasdot` and returns its result.
fwd(x, y) = blasdot(x, y)
226+
227+
@testset "Parameter removal" begin
    # Parameters of externally linked code (even when it is replaced via BLAS) must
    # neither be removed nor be replaced with `undef`/`poison`.
    llvm_ir = sprint() do io
        Enzyme.Compiler.enzyme_code_llvm(
            io, fwd, Const, Tuple{Const{Vector{ComplexF64}},Const{Vector{ComplexF64}}};
            dump_module=true
        )
    end

    for line in split(llvm_ir, "\n")
        # `occursin(needle, haystack)`: the needle comes first. With the arguments
        # swapped, the check asked whether the whole IR line was a substring of
        # "ejlstr", so no relevant line was ever inspected.
        if occursin("ejlstr", line)
            @test !(occursin(" undef", line) || occursin(" poison", line))
        end
    end
end

0 commit comments

Comments
 (0)