Skip to content

Commit e273141

Browse files
authored
Fix 1.11 vcat analysis (#2602)
1 parent 62fcfea commit e273141

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

src/absint.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,12 +255,16 @@ function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType)
255255
end
256256
end
257257

258-
function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool = true, inttoptr::Bool = false, inst::Union{LLVM.Instruction, Nothing} = nothing)::Tuple{LLVM.Value, Int}
258+
function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool = true, inttoptr::Bool = false, inst::Union{LLVM.Instruction, Nothing} = nothing, addrcast::Bool=true)::Tuple{LLVM.Value, Int}
259259
offset = 0
260260
pinst = isa(larg, LLVM.Instruction) ? larg::LLVM.Instruction : inst
261261
while true
262262
if isa(larg, LLVM.ConstantExpr)
263-
if opcode(larg) == LLVM.API.LLVMBitCast || opcode(larg) == LLVM.API.LLVMAddrSpaceCast || opcode(larg) == LLVM.API.LLVMPtrToInt
263+
if opcode(larg) == LLVM.API.LLVMBitCast || opcode(larg) == LLVM.API.LLVMPtrToInt
264+
larg = operands(larg)[1]
265+
continue
266+
end
267+
if addrcast && opcode(larg) == LLVM.API.LLVMAddrSpaceCast
264268
larg = operands(larg)[1]
265269
continue
266270
end
@@ -287,7 +291,11 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
287291
end
288292
end
289293
end
290-
if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.AddrSpaceCastInst) || isa(larg, LLVM.IntToPtrInst)
294+
if isa(larg, LLVM.BitCastInst) || isa(larg, LLVM.IntToPtrInst)
295+
larg = operands(larg)[1]
296+
continue
297+
end
298+
if addrcast && isa(larg, LLVM.AddrSpaceCastInst)
291299
larg = operands(larg)[1]
292300
continue
293301
end
@@ -332,7 +340,7 @@ function abs_typeof(
332340
)::Union{Tuple{Bool, Type, GPUCompiler.ArgumentCC}, Tuple{Bool, Nothing, Nothing}}
333341
if (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked)) || (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Derived))
334342
ce, _ = get_base_and_offset(arg; offsetAllowed = false, inttoptr = true)
335-
if isa(ce, GlobalVariable)
343+
if isa(ce, GlobalVariable)
336344
gname = LLVM.name(ce)
337345
for (k, v) in JuliaGlobalNameMap
338346
if gname == k
@@ -778,7 +786,7 @@ function abs_typeof(
778786
end
779787
push!(seen, cur)
780788
for (v, _) in LLVM.incoming(cur)
781-
v2, off = get_base_and_offset(v)
789+
v2, off = get_base_and_offset(v, inttoptr=false, addrcast=false)
782790
if off != 0
783791
if isa(v, LLVM.Instruction) && arg in collect(operands(v))
784792
legal = false

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,18 @@ end
817817
@test res[2] 1.0
818818
end
819819

820+
@testset "1.11 vcat" begin
821+
822+
function fcat(x)
823+
r = vcat(Any[1])
824+
return x
825+
end
826+
827+
res = Enzyme.autodiff(Reverse, fcat, Active(2.0))
828+
@test res[1][1] 1.0
829+
830+
end
831+
820832
@testset "Taylor series tests" begin
821833

822834
# Taylor series for `-log(1-x)`

0 commit comments

Comments
 (0)