@@ -1311,9 +1311,44 @@ function fix_decayaddr!(mod::LLVM.Module)
13111311 return nothing
13121312end
13131313
1314- function pre_attr! (mod:: LLVM.Module )
1314+ function pre_attr! (mod:: LLVM.Module , run_attr)
1315+ if run_attr
1316+ for fn in functions (mod)
1317+ if isempty (blocks (fn))
1318+ continue
1319+ end
1320+ attrs = collect (function_attributes (fn))
1321+ prevent = any (
1322+ kind (attr) == kind (StringAttribute (" enzyme_preserve_primal" )) for attr in attrs
1323+ )
1324+ if ! prevent
1325+ continue
1326+ end
1327+
1328+ if linkage (fn) == LLVM. API. LLVMInternalLinkage
1329+ push! (LLVM. function_attributes (fn), StringAttribute (" restorelinkage_internal" ))
1330+ linkage! (fn, LLVM. API. LLVMExternalLinkage)
1331+ end
1332+
1333+ if linkage (fn) == LLVM. API. LLVMPrivateLinkage
1334+ push! (LLVM. function_attributes (fn), StringAttribute (" restorelinkage_private" ))
1335+ linkage! (fn, LLVM. API. LLVMExternalLinkage)
1336+ end
1337+ continue
1338+
1339+ if ! has_fn_attr (fn, EnumAttribute (" noinline" ))
1340+ push! (LLVM. function_attributes (fn), EnumAttribute (" noinline" ))
1341+ push! (LLVM. function_attributes (fn), StringAttribute (" remove_noinline" ))
1342+ end
1343+
1344+ if ! has_fn_attr (fn, EnumAttribute (" optnone" ))
1345+ push! (LLVM. function_attributes (fn), EnumAttribute (" optnone" ))
1346+ push! (LLVM. function_attributes (fn), StringAttribute (" remove_optnone" ))
1347+ end
1348+ end
1349+ end
13151350 return nothing
1316- tofinalize = Tuple{LLVM . Function,Bool,Vector{Int64}}[]
1351+
13171352 for fn in collect (functions (mod))
13181353 if isempty (blocks (fn))
13191354 continue
@@ -1337,6 +1372,32 @@ function pre_attr!(mod::LLVM.Module)
13371372 end
13381373 end
13391374 end
1375+ end
1376+
1377+ function post_attr! (mod:: LLVM.Module , run_attr)
1378+ if run_attr
1379+ for fn in functions (mod)
1380+ if has_fn_attr (fn, StringAttribute (" restorelinkage_internal" ))
1381+ delete! (LLVM. function_attributes (fn), StringAttribute (" restorelinkage_internal" ))
1382+ linkage! (fn, LLVM. API. LLVMInternalLinkage)
1383+ end
1384+
1385+ if has_fn_attr (fn, StringAttribute (" restorelinkage_private" ))
1386+ delete! (LLVM. function_attributes (fn), StringAttribute (" restorelinkage_private" ))
1387+ linkage! (fn, LLVM. API. LLVMPrivateLinkage)
1388+ end
1389+
1390+ if has_fn_attr (fn, StringAttribute (" remove_noinline" ))
1391+ delete! (LLVM. function_attributes (fn), EnumAttribute (" noinline" ))
1392+ delete! (LLVM. function_attributes (fn), StringAttribute (" remove_noinline" ))
1393+ end
1394+
1395+ if has_fn_attr (fn, StringAttribute (" remove_optnone" ))
1396+ delete! (LLVM. function_attributes (fn), EnumAttribute (" optnone" ))
1397+ delete! (LLVM. function_attributes (fn), StringAttribute (" remove_optnone" ))
1398+ end
1399+ end
1400+ end
13401401 return nothing
13411402end
13421403
@@ -1781,37 +1842,40 @@ function propagate_returned!(mod::LLVM.Module)
17811842 LLVM. replace_uses! (arg, val)
17821843 end
17831844 end
1784- # see if there are no users of the value (excluding recursive/return)
1785- baduse = false
1786- for u in LLVM. uses (arg)
1787- u = LLVM. user (u)
1788- if argn == i && LLVM. API. LLVMIsAReturnInst (u) != C_NULL
1789- continue
1790- end
1791- if ! isa (u, LLVM. CallInst)
1792- baduse = true
1793- break
1794- end
1795- if LLVM. called_operand (u) != fn
1796- baduse = true
1797- break
1798- end
1799- for (si, op) in enumerate (operands (u))
1800- if si == i
1801- continue
1802- end
1803- if op == arg
1804- baduse = true
1805- break
1806- end
1807- end
1808- if baduse
1809- break
1810- end
1811- end
1812- if ! baduse
1813- push! (toremove, i - 1 )
1814- end
1845+
1846+ # see if there are no users of the value (excluding recursive/return)
1847+ if ! prevent
1848+ baduse = false
1849+ for u in LLVM. uses (arg)
1850+ u = LLVM. user (u)
1851+ if argn == i && LLVM. API. LLVMIsAReturnInst (u) != C_NULL
1852+ continue
1853+ end
1854+ if ! isa (u, LLVM. CallInst)
1855+ baduse = true
1856+ break
1857+ end
1858+ if LLVM. called_operand (u) != fn
1859+ baduse = true
1860+ break
1861+ end
1862+ for (si, op) in enumerate (operands (u))
1863+ if si == i
1864+ continue
1865+ end
1866+ if op == arg
1867+ baduse = true
1868+ break
1869+ end
1870+ end
1871+ if baduse
1872+ break
1873+ end
1874+ end
1875+ if ! baduse
1876+ push! (toremove, i - 1 )
1877+ end
1878+ end
18151879 end
18161880 illegalUse = ! (
18171881 linkage (fn) == LLVM. API. LLVMInternalLinkage ||
@@ -2498,7 +2562,7 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
24982562 LLVM. run! (pm, mod)
24992563 end
25002564 propagate_returned! (mod)
2501- pre_attr! (mod)
2565+ pre_attr! (mod, RunAttributor[] )
25022566 if RunAttributor[]
25032567 if LLVM. version (). major >= 13
25042568 ModulePassManager () do pm
@@ -2521,8 +2585,9 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
25212585 cse! (pm)
25222586 LLVM. run! (pm, mod)
25232587 end
2524- post_attr! (mod)
2588+ post_attr! (mod, RunAttributor[] )
25252589 propagate_returned! (mod)
2590+
25262591
25272592 for u in LLVM. uses (rfunc)
25282593 u = LLVM. user (u)
0 commit comments