Skip to content

Commit 73d4f93

Browse files
authored
Support typetag in absint (#2800)
* Support typetag in absint * fix * simplevector * Update absint.jl * fix * fix * fix
1 parent 8527858 commit 73d4f93

File tree

10 files changed

+283
-59
lines changed

10 files changed

+283
-59
lines changed

src/absint.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
# Abstractly interpret julia from LLVM
22

33
# Returns a tuple `(ok, value)`: `ok` is true if `arg` could be interpreted, and `value` is the Julia object it was interpreted to
4-
function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool=false)::Tuple{Bool, Any}
4+
5+
6+
const JL_MAX_TAGS = 64 # see `enum jl_small_typeof_tags` in julia.h
7+
8+
function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked::Bool=false, typetag::Bool=false)::Tuple{Bool, Any}
59
if (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Tracked)) || (value_type(arg) == LLVM.PointerType(LLVM.StructType(LLVMType[]), Derived)) || istracked
610
ce, _ = get_base_and_offset(arg; offsetAllowed = false, inttoptr = true)
711
if isa(ce, GlobalVariable)
@@ -37,18 +41,26 @@ function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false, istracked
3741
end
3842
end
3943
if isa(ce, LLVM.ConstantInt)
40-
ptr = reinterpret(Ptr{Cvoid}, convert(UInt, ce))
44+
ce = convert(UInt, ce)
45+
# "small" type tags are indices into a special array
46+
ptr = if typetag && ce < (JL_MAX_TAGS << 4)
47+
jl_small_typeof = Ptr{Ptr{Cvoid}}(cglobal(:jl_small_typeof))
48+
type_idx = ce ÷ Core.sizeof(Ptr{Cvoid})
49+
unsafe_load(jl_small_typeof, type_idx + 1)
50+
else
51+
reinterpret(Ptr{Cvoid}, ce)
52+
end
4153
val = Base.unsafe_pointer_to_objref(ptr)
4254
return (true, val)
4355
end
4456
end
4557
if isa(arg, ConstantExpr)
4658
if opcode(arg) == LLVM.API.LLVMAddrSpaceCast || opcode(arg) == LLVM.API.LLVMBitCast
47-
return absint(operands(arg)[1], partial)
59+
return absint(operands(arg)[1], partial, false, typetag)
4860
end
4961
end
5062
if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst) || isa(arg, LLVM.IntToPtrInst)
51-
return absint(operands(arg)[1], partial)
63+
return absint(operands(arg)[1], partial, false, typetag)
5264
end
5365
if isa(arg, LLVM.CallInst)
5466
fn = LLVM.called_operand(arg)
@@ -334,6 +346,7 @@ function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Boo
334346
return larg, offset
335347
end
336348

349+
337350
function abs_typeof(
338351
@nospecialize(arg::LLVM.Value),
339352
partial::Bool = false, seenphis = Set{LLVM.PHIInst}()
@@ -438,13 +451,13 @@ function abs_typeof(
438451
if nm == "julia.gc_alloc_obj" ||
439452
nm == "jl_gc_alloc_typed" ||
440453
nm == "ijl_gc_alloc_typed"
441-
vals = absint(operands(arg)[3], partial)
454+
vals = absint(operands(arg)[3], partial, false, #=typetag=#true)
442455
return (vals[1], vals[2], vals[1] ? GPUCompiler.BITS_REF : nothing)
443456
end
444457
# Type tag is arg 3
445458
if nm == "jl_alloc_genericmemory_unchecked" ||
446459
nm == "ijl_alloc_genericmemory_unchecked"
447-
vals = absint(operands(arg)[3], partial, true)
460+
vals = absint(operands(arg)[3], partial, true, #=typetag=#true)
448461
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
449462
end
450463
# Type tag is arg 1
@@ -458,12 +471,12 @@ function abs_typeof(
458471
nm == "ijl_new_array" ||
459472
nm == "jl_alloc_genericmemory" ||
460473
nm == "ijl_alloc_genericmemory"
461-
vals = absint(operands(arg)[1], partial)
474+
vals = absint(operands(arg)[1], partial, false, #=typetag=#true)
462475
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
463476
end
464477

465478
if nm == "jl_new_structt" || nm == "ijl_new_structt"
466-
vals = absint(operands(arg)[1], partial)
479+
vals = absint(operands(arg)[1], partial, false, #=typetag=#true)
467480
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
468481
end
469482

@@ -481,7 +494,7 @@ function abs_typeof(
481494

482495
if nm == "jl_new_structv" || nm == "ijl_new_structv"
483496
@assert index == 2
484-
vals = absint(operands(arg)[index], partial)
497+
vals = absint(operands(arg)[index], partial, false, #=typetag=#true)
485498
return (vals[1], vals[2], vals[1] ? GPUCompiler.MUT_REF : nothing)
486499
end
487500

src/analyses/activity.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ end
1212
# Unwrap the value carried as the type parameter of a `Val` instance.
@inline element(::Val{T}) where {T} = T
1313

1414
# `ptreltype(T)`: element type reported for a pointer-like or reference-like type `T`.
# For the parametric types below this is simply the type parameter `T`.
@inline ptreltype(::Type{Ptr{T}}) where {T} = T
15+
# `Core.SimpleVector` has no element-type parameter, so `Any` is reported.
@inline ptreltype(::Type{Core.SimpleVector}) = Any
1516
@inline ptreltype(::Type{Core.LLVMPtr{T,N}}) where {T,N} = T
1617
# Same as above, but for an `LLVMPtr` whose address-space parameter is left free.
@inline ptreltype(::Type{Core.LLVMPtr{T} where N}) where {T} = T
1718
@inline ptreltype(::Type{Base.RefValue{T}}) where {T} = T
@@ -29,6 +30,7 @@ end
2930

3031
# Trait on types: the generic fallback answers `false`; the more specific methods
# below opt individual types in by answering `true`.
@inline is_arrayorvararg_ty(::Type) = false
3132
# A tuple type made only of a `Vararg` tail.
@inline is_arrayorvararg_ty(::Type{Tuple{Vararg{T2}}}) where {T2} = true
33+
# `Core.SimpleVector` is opted in as well.
@inline is_arrayorvararg_ty(::Type{Core.SimpleVector}) = true
3234
# Raw pointers and LLVM pointers (with a fixed or a free address-space parameter).
@inline is_arrayorvararg_ty(::Type{Ptr{T}}) where {T} = true
3335
@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N}}) where {T,N} = true
3436
@inline is_arrayorvararg_ty(::Type{Core.LLVMPtr{T,N} where N}) where {T} = true

src/api.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,14 @@ EnzymeGradientUtilsErase(gutils, a) = ccall(
594594
gutils,
595595
a,
596596
)
597+
# Thin binding for the `EnzymeReplaceOriginalToNew` symbol of `libEnzyme`.
#
# `gutils`, `orig` and `rep` are forwarded unchanged and converted by `ccall` to
# `EnzymeGradientUtilsRef`, `LLVMValueRef` and `LLVMValueRef` respectively.
# The foreign function is declared `Cvoid`, so the call yields `nothing`.
#
# NOTE(review): judging by the name, this presumably makes `rep` the "new" value
# associated with the original value `orig` inside `gutils` -- confirm against
# the Enzyme C API before relying on it.
function EnzymeReplaceOriginalToNew(gutils, orig, rep)
    return ccall(
        (:EnzymeReplaceOriginalToNew, libEnzyme),
        Cvoid,
        (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef),
        gutils, orig, rep
    )
end
597605
EnzymeGradientUtilsEraseWithPlaceholder(gutils, a, orig, erase) = ccall(
598606
(:EnzymeGradientUtilsEraseWithPlaceholder, libEnzyme),
599607
Cvoid,

src/compiler.jl

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,12 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
11551155
parameter_attributes(f, arg.codegen.i),
11561156
StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc))),
11571157
)
1158+
if arg.rooted_typ !== nothing
1159+
push!(
1160+
parameter_attributes(f, arg.codegen.i),
1161+
StringAttribute("enzymejl_rooted_typ", string(convert(UInt, unsafe_to_pointer(arg.rooted_typ))))
1162+
)
1163+
end
11581164

11591165
byref = arg.cc
11601166

@@ -1559,6 +1565,23 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType),
15591565
nothing
15601566
end
15611567
else
1568+
if Ty == Core.SimpleVector
1569+
@assert count === nothing
1570+
@assert isa(prev, LLVM.CallInst)
1571+
@assert LLVM.name(LLVM.called_operand(prev)::LLVM.Function) == "julia.gc_alloc_obj"
1572+
sz = operands(prev)[2]
1573+
sz = sub!(B, sz, LLVM.ConstantInt(Int(sizeof(Ptr{Cvoid}))))
1574+
T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
1575+
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
1576+
prev = addrspacecast!(B, prev, LLVM.PointerType(T_jlvalue, Derived))
1577+
prev = bitcast!(B, prev, LLVM.PointerType(T_prjlvalue, Derived))
1578+
gep = LLVM.gep!(B, T_prjlvalue, prev, LLVM.Value[LLVM.ConstantInt(Int64(1))])
1579+
zeroAll = false
1580+
atomic = true
1581+
zero_allocation(B, Any, T_prjlvalue, prev, LLVM.ConstantInt(sizeof(Ptr{Cvoid})), sz, zeroAll, atomic)
1582+
return
1583+
end
1584+
15621585
if fieldcount(Ty) == 0
15631586
error("Error handling recursive stores for $Ty which has a fieldcount of 0")
15641587
end
@@ -3895,7 +3918,7 @@ function recombine_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, roots::LLVM
38953918
jltype = value_type(sret)
38963919
tracked = CountTrackedPointers(jltype)
38973920
@assert tracked.count > 0
3898-
@assert !tracked.all
3921+
@assert !tracked.all "Not tracked.all, jltype ($(string(jltype)))"
38993922
root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
39003923
move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, RootPointerToSRetValue)
39013924
end
@@ -3904,11 +3927,74 @@ function extract_roots_from_value!(builder::LLVM.IRBuilder, sret::LLVM.Value, ro
39043927
jltype = value_type(sret)
39053928
tracked = CountTrackedPointers(jltype)
39063929
@assert tracked.count > 0
3907-
@assert !tracked.all
3930+
@assert !tracked.all "Not tracked.all, jltype ($(string(jltype)))"
39083931
root_ty = convert(LLVMType, AnyArray(Int(tracked.count)))
39093932
move_sret_tofrom_roots!(builder, jltype, sret, root_ty, roots, SRetValueToRootPointer)
39103933
end
39113934

3935+
"""
    copy_floats_into!(builder, jltype, dst, src)

Emit, through `builder`, a load/store pair for every floating-point leaf of the
LLVM type `jltype`, copying that leaf from `src` to `dst`.

`jltype` is walked breadth-first. Array, vector and struct types are expanded
into their elements; floating-point leaves are copied; pointer and integer
leaves (and any other kind of type) are left untouched.

# Arguments
- `builder`: IR builder positioned where the copies should be emitted.
- `jltype`: LLVM type used as the source element type of every emitted GEP.
- `dst`, `src`: base values of the destination and source GEPs.
  NOTE(review): presumably both are pointers to a `jltype` -- confirm at callers.

# Returns
`nothing`.
"""
function copy_floats_into!(
    builder::LLVM.IRBuilder,
    jltype::LLVM.LLVMType,
    dst::LLVM.Value,
    src::LLVM.Value,
)
    # GEP operands for a leaf: a leading i64 zero (the index applied to the base
    # pointer itself) followed by one i32 index per level of nesting.
    function gep_indices(path::Vector{Cuint})
        idxs = LLVM.Value[LLVM.ConstantInt(LLVM.IntType(64), 0)]
        for i in path
            push!(idxs, LLVM.ConstantInt(LLVM.IntType(32), i))
        end
        return idxs
    end

    # Path of the `idx`-th (zero-based) child of the aggregate found at `path`.
    child_path(path::Vector{Cuint}, idx) = push!(copy(path), idx)

    # FIFO work list of (index path from the root, type found at that path).
    todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(Cuint[], jltype)]

    while !isempty(todo)
        path, ty = popfirst!(todo)

        if isa(ty, LLVM.FloatingPointType)
            # Build the index list once and use it for both addresses.
            idxs = gep_indices(path)
            dstloc = inbounds_gep!(builder, jltype, dst, idxs, "dstloc")
            srcloc = inbounds_gep!(builder, jltype, src, idxs, "srcloc")
            store!(builder, load!(builder, ty, srcloc), dstloc)
        elseif isa(ty, LLVM.ArrayType)
            for i in 1:length(ty)
                push!(todo, (child_path(path, i - 1), eltype(ty)))
            end
        elseif isa(ty, LLVM.VectorType)
            for i in 1:size(ty)
                push!(todo, (child_path(path, i - 1), eltype(ty)))
            end
        elseif isa(ty, LLVM.StructType)
            for (i, t) in enumerate(LLVM.elements(ty))
                push!(todo, (child_path(path, i - 1), t))
            end
        end
        # Pointer, integer and any other leaf types: nothing to copy.
    end

    return nothing
end
3997+
39123998

39133999
# Modified from GPUCompiler/src/irgen.jl:365 lower_byval
39144000
function lower_convention(
@@ -4314,6 +4400,15 @@ function lower_convention(
43144400
string(UInt(GPUCompiler.BITS_VALUE)),
43154401
),
43164402
)
4403+
if arg.rooted_typ !== nothing
4404+
push!(
4405+
parameter_attributes(wrapper_f, wrapper_idx - 1),
4406+
StringAttribute(
4407+
"enzymejl_rooted_typ",
4408+
string(convert(UInt, unsafe_to_pointer(arg.rooted_typ))),
4409+
),
4410+
)
4411+
end
43174412
elseif arg.arg_i in raisedArgs
43184413
wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm)
43194414
ctx = LLVM.context(wrapparm)
@@ -4342,6 +4437,15 @@ function lower_convention(
43424437
string(UInt(GPUCompiler.BITS_REF)),
43434438
),
43444439
)
4440+
if arg.rooted_typ !== nothing
4441+
push!(
4442+
parameter_attributes(wrapper_f, wrapper_idx - 1),
4443+
StringAttribute(
4444+
"enzymejl_rooted_typ",
4445+
string(convert(UInt, unsafe_to_pointer(arg.rooted_typ)))
4446+
),
4447+
)
4448+
end
43454449
else
43464450
push!(wrapper_args, wrapparm)
43474451
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))

src/llvm/transforms.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,10 @@ function fix_decayaddr!(mod::LLVM.Module)
12891289
sret_elty = sret_ty(fop, i)
12901290
t_sret = true
12911291
end
1292+
if kind(a) == kind(StringAttribute("enzymejl_rooted_typ"))
1293+
sret_elty = get_rooted_typ(fop, i)
1294+
t_sret = true
1295+
end
12921296
# if kind(a) == kind(StringAttribute("enzyme_sret_v"))
12931297
# t_sret = true
12941298
# end

src/rules/customrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1567,7 +1567,7 @@ function enzyme_custom_common_rev(
15671567
if tape_roots != 0
15681568
roots_ty = convert(LLVMType, AnyArray(tape_roots))
15691569
tape_al = alloca!(B, roots_ty)
1570-
extract_roots_from_value!(B, tape, ral)
1570+
extract_roots_from_value!(B, tape, tape_al)
15711571
end
15721572

15731573
al0 = al = emit_allocobj!(B, TapeT, "tape.$TapeT")

0 commit comments

Comments
 (0)