Skip to content

Commit 30b6b2d

Browse files
authored
Fanon2 (#2136)
* Fewer anonymous funcs * minor cleanup * cleanup * fix * Aggressively noinfer * more type annotations
1 parent b78ec7b commit 30b6b2d

File tree

7 files changed

+118
-72
lines changed

7 files changed

+118
-72
lines changed

src/absint.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Abstractly interpret julia from LLVM
22

33
# Return (bool if could interpret, julia object interpreted to)
4-
function absint(arg::LLVM.Value, partial::Bool = false)
4+
function absint(@nospecialize(arg::LLVM.Value), partial::Bool = false)::Tuple{Bool,Any}
55
if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst)
66
return absint(operands(arg)[1], partial)
77
end
@@ -165,7 +165,7 @@ function absint(arg::LLVM.Value, partial::Bool = false)
165165
return (false, nothing)
166166
end
167167

168-
function actual_size(@nospecialize(typ2))
168+
function actual_size(@nospecialize(typ2))::Int
169169
@static if VERSION < v"1.11-"
170170
if typ2 <: Array
171171
return sizeof(Ptr{Cvoid}) + 2 + 2 + 4 + 2 * sizeof(Csize_t) + sizeof(Csize_t)
@@ -184,10 +184,10 @@ function actual_size(@nospecialize(typ2))
184184
end
185185
end
186186

187-
@inline function first_non_ghost(@nospecialize(typ2))
187+
@inline function first_non_ghost(@nospecialize(typ2))::Tuple{Int, Int}
188188
@static if VERSION < v"1.11-"
189189
if typ2 <: Array
190-
return (1, typed_fieldtype(typ2, 1))
190+
return (1, 0)
191191
end
192192
end
193193
fc = fieldcount(typ2)
@@ -204,7 +204,7 @@ end
204204
return (-1, 0)
205205
end
206206

207-
function should_recurse(@nospecialize(typ2), arg_t, byref, dl)
207+
function should_recurse(@nospecialize(typ2), @nospecialize(arg_t::LLVM.LLVMType), byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)::Bool
208208
sz = sizeof(dl, arg_t)
209209
if byref != GPUCompiler.BITS_VALUE
210210
if sz != sizeof(Int)
@@ -228,7 +228,7 @@ function should_recurse(@nospecialize(typ2), arg_t, byref, dl)
228228
end
229229
end
230230

231-
function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=false)::Tuple{LLVM.Value, Int}
231+
function get_base_and_offset(@nospecialize(larg::LLVM.Value); offsetAllowed::Bool=true, inttoptr::Bool=false)::Tuple{LLVM.Value, Int}
232232
offset = 0
233233
while true
234234
if isa(larg, LLVM.ConstantExpr)
@@ -277,7 +277,7 @@ function get_base_and_offset(larg::LLVM.Value; offsetAllowed=true, inttoptr=fals
277277
end
278278

279279
function abs_typeof(
280-
arg::LLVM.Value,
280+
@nospecialize(arg::LLVM.Value),
281281
partial::Bool = false, seenphis=Set{LLVM.PHIInst}()
282282
)::Union{Tuple{Bool,Type,GPUCompiler.ArgumentCC},Tuple{Bool,Nothing,Nothing}}
283283
if isa(arg, LLVM.BitCastInst) || isa(arg, LLVM.AddrSpaceCastInst)
@@ -729,7 +729,7 @@ function abs_typeof(
729729
return (false, nothing, nothing)
730730
end
731731

732-
function abs_cstring(arg::LLVM.Value)::Tuple{Bool,String}
732+
function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String}
733733
if isa(arg, ConstantExpr)
734734
ce = arg
735735
while isa(ce, ConstantExpr)

src/compiler.jl

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,7 +1125,11 @@ struct Return2
11251125
end
11261126

11271127
function force_recompute!(mod::LLVM.Module)
1128-
for f in functions(mod), bb in blocks(f), inst in collect(instructions(bb))
1128+
for f in functions(mod), bb in blocks(f)
1129+
iter = LLVM.API.LLVMGetFirstInstruction(bb)
1130+
while iter != C_NULL
1131+
inst = LLVM.Instruction(iter)
1132+
iter = LLVM.API.LLVMGetNextInstruction(iter)
11291133
if isa(inst, LLVM.LoadInst)
11301134
has_loaded = false
11311135
for u in LLVM.uses(inst)
@@ -1170,6 +1174,7 @@ function force_recompute!(mod::LLVM.Module)
11701174
end
11711175
end
11721176
end
1177+
end
11731178
end
11741179

11751180
function permit_inlining!(f::LLVM.Function)
@@ -3275,7 +3280,7 @@ end
32753280
# Enzyme compiler step
32763281
##
32773282

3278-
function annotate!(mod, mode)
3283+
function annotate!(mod::LLVM.Module)
32793284
inactive = LLVM.StringAttribute("enzyme_inactive", "")
32803285
active = LLVM.StringAttribute("enzyme_active", "")
32813286
no_escaping_alloc = LLVM.StringAttribute("enzyme_no_escaping_allocation")
@@ -3891,7 +3896,7 @@ function enzyme_extract_world(fn::LLVM.Function)::UInt
38913896
throw(AssertionError("Enzyme: could not find world in $(string(fn))"))
38923897
end
38933898

3894-
function enzyme_custom_extract_mi(orig::LLVM.Instruction, error::Bool = true)
3899+
function enzyme_custom_extract_mi(orig::LLVM.CallInst, error::Bool = true)
38953900
operand = LLVM.called_operand(orig)
38963901
if isa(operand, LLVM.Function)
38973902
return enzyme_custom_extract_mi(operand::LLVM.Function, error)
@@ -6144,7 +6149,7 @@ end
61446149

61456150
using Random
61466151
# returns arg, return
6147-
function no_type_setting(@nospecialize(specTypes); world = nothing)
6152+
function no_type_setting(@nospecialize(specTypes::Type{<:Tuple}); world = nothing)
61486153
# Even though the julia type here is ptr{int8}, the actual data can be something else
61496154
if specTypes.parameters[1] == typeof(Random.XoshiroSimd.xoshiro_bulk_simd)
61506155
return (true, false)
@@ -7037,7 +7042,7 @@ end
70377042
end
70387043

70397044
# annotate
7040-
annotate!(mod, mode)
7045+
annotate!(mod)
70417046
for name in ("gpu_report_exception", "report_exception")
70427047
if haskey(functions(mod), name)
70437048
exc = functions(mod)[name]
@@ -8012,9 +8017,6 @@ end
80128017
::Type{TapeType},
80138018
args::Vararg{Any,N},
80148019
) where {RawCall,PT,FA,T,RT,TapeType,N,CC,width,returnPrimal}
8015-
8016-
JuliaContext() do ctx
8017-
Base.@_inline_meta
80188020
F = eltype(FA)
80198021
is_forward =
80208022
CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk
@@ -8263,6 +8265,10 @@ end
82638265
i += 1
82648266
end
82658267

8268+
ts_ctx = JuliaContext()
8269+
ctx = context(ts_ctx)
8270+
activate(ctx)
8271+
(ir, fn, combinedReturn) = try
82668272

82678273
if is_adjoint
82688274
NT = Tuple{ActiveRetTypes...}
@@ -8441,31 +8447,35 @@ end
84418447

84428448
ir = string(mod)
84438449
fn = LLVM.name(llvm_f)
8450+
(ir, fn, combinedReturn)
8451+
finally
8452+
deactivate(ctx)
8453+
dispose(ts_ctx)
8454+
end
84448455

8445-
@assert length(types) == length(ccexprs)
8456+
@assert length(types) == length(ccexprs)
84468457

84478458

8448-
if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT))
8449-
return quote
8450-
Base.@_inline_meta
8451-
Base.llvmcall(
8452-
($ir, $fn),
8453-
$combinedReturn,
8454-
Tuple{$PT,$(types...)},
8455-
fptr,
8456-
$(ccexprs...),
8457-
)
8458-
end
8459-
else
8460-
return quote
8461-
Base.@_inline_meta
8462-
Base.llvmcall(
8463-
($ir, $fn),
8464-
$combinedReturn,
8465-
Tuple{$(types...)},
8466-
$(ccexprs...),
8467-
)
8468-
end
8459+
if !(GPUCompiler.isghosttype(PT) || Core.Compiler.isconstType(PT))
8460+
return quote
8461+
Base.@_inline_meta
8462+
Base.llvmcall(
8463+
($ir, $fn),
8464+
$combinedReturn,
8465+
Tuple{$PT,$(types...)},
8466+
fptr,
8467+
$(ccexprs...),
8468+
)
8469+
end
8470+
else
8471+
return quote
8472+
Base.@_inline_meta
8473+
Base.llvmcall(
8474+
($ir, $fn),
8475+
$combinedReturn,
8476+
Tuple{$(types...)},
8477+
$(ccexprs...),
8478+
)
84698479
end
84708480
end
84718481
end
@@ -9071,7 +9081,10 @@ include("compiler/reflection.jl")
90719081
)
90729082
copysetfn = meta.entry
90739083
blk = first(blocks(copysetfn))
9074-
for inst in collect(instructions(blk))
9084+
iter = LLVM.API.LLVMGetFirstInstruction(blk)
9085+
while iter != C_NULL
9086+
inst = LLVM.Instruction(iter)
9087+
iter = LLVM.API.LLVMGetNextInstruction(iter)
90759088
if isa(inst, LLVM.FenceInst)
90769089
eraseInst(blk, inst)
90779090
end

src/compiler/interpreter.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ Core.Compiler.verbose_stmt_info(@nospecialize(::EnzymeInterpreter)) = false
123123
Core.Compiler.method_table(@nospecialize(interp::EnzymeInterpreter), sv::InferenceState) =
124124
Core.Compiler.OverlayMethodTable(interp.world, interp.method_table)
125125

126-
function is_alwaysinline_func(@nospecialize(TT))
126+
function is_alwaysinline_func(@nospecialize(TT))::Bool
127127
isa(TT, DataType) || return false
128128
@static if VERSION >= v"1.11-"
129129
if TT.parameters[1] == typeof(Core.memoryref)
@@ -133,7 +133,7 @@ function is_alwaysinline_func(@nospecialize(TT))
133133
return false
134134
end
135135

136-
function is_primitive_func(@nospecialize(TT))
136+
function is_primitive_func(@nospecialize(TT))::Bool
137137
isa(TT, DataType) || return false
138138
ft = TT.parameters[1]
139139
if ft == typeof(Enzyme.pmap)
@@ -156,11 +156,11 @@ function is_primitive_func(@nospecialize(TT))
156156
return false
157157
end
158158

159-
function isKWCallSignature(@nospecialize(TT))
159+
function isKWCallSignature(@nospecialize(TT))::Bool
160160
return TT <: Tuple{typeof(Core.kwcall),Any,Any,Vararg}
161161
end
162162

163-
function simplify_kw(@nospecialize specTypes)
163+
function simplify_kw(@nospecialize(specTypes))
164164
if isKWCallSignature(specTypes)
165165
return Base.tuple_type_tail(Base.tuple_type_tail(specTypes))
166166
else
@@ -742,15 +742,15 @@ end
742742
end
743743
end
744744

745-
@inline function array_or_number(@nospecialize(Ty))
745+
@inline function array_or_number(@nospecialize(Ty))::Bool
746746
return Ty <: AbstractArray || Ty <: Number
747747
end
748748

749-
@inline function isa_array_or_number(@nospecialize(x))
749+
@inline function isa_array_or_number(@nospecialize(x))::Bool
750750
return x isa AbstractArray || x isa Number
751751
end
752752

753-
@inline function num_or_eltype(@nospecialize(Ty))
753+
@inline function num_or_eltype(@nospecialize(Ty))::Type
754754
if Ty <: AbstractArray
755755
eltype(Ty)
756756
else

src/compiler/validation.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function get_blas_symbols()
2020
return symbols
2121
end
2222

23-
function lookup_blas_symbol(name)
23+
function lookup_blas_symbol(name::String)
2424
Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false)
2525
end
2626
end
@@ -127,7 +127,7 @@ function __init__()
127127
end
128128
end
129129

130-
function memoize!(ptr, fn)
130+
function memoize!(ptr::Ptr{Cvoid}, fn::String)::String
131131
fn = get(ptr_map, ptr, fn)
132132
if !haskey(ptr_map, ptr)
133133
ptr_map[ptr] = fn
@@ -140,7 +140,7 @@ end
140140

141141
import GPUCompiler: IRError, InvalidIRError
142142

143-
function restore_lookups(mod::LLVM.Module)
143+
function restore_lookups(mod::LLVM.Module)::Nothing
144144
T_size_t = convert(LLVM.LLVMType, Int)
145145
for (v, k) in FFI.ptr_map
146146
if haskey(functions(mod), k)
@@ -421,7 +421,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
421421
calls = LLVM.CallInst[]
422422
isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0
423423
mod = LLVM.parent(f)
424-
for bb in blocks(f), inst in collect(instructions(bb))
424+
for bb in blocks(f)
425+
iter = LLVM.API.LLVMGetFirstInstruction(bb)
426+
while iter != C_NULL
427+
inst = LLVM.Instruction(iter)
428+
iter = LLVM.API.LLVMGetNextInstruction(iter)
425429
if isa(inst, LLVM.CallInst)
426430
push!(calls, inst)
427431
# remove illegal invariant.load and jtbaa_const invariants
@@ -489,7 +493,11 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
489493
newf, _ = get_function!(mod, fname, FT)
490494
else
491495
found = nothing
492-
for lbb in blocks(initfn), linst in collect(instructions(lbb))
496+
for lbb in blocks(initfn)
497+
liter = LLVM.API.LLVMGetFirstInstruction(lbb)
498+
while liter != C_NULL
499+
linst = LLVM.Instruction(liter)
500+
liter = LLVM.API.LLVMGetNextInstruction(liter)
493501
if !isa(linst, LLVM.CallInst)
494502
continue
495503
end
@@ -502,6 +510,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
502510
break
503511
end
504512
end
513+
end
505514
if found == nothing
506515
msg = sprint() do io::IO
507516
println(
@@ -630,6 +639,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp
630639
end
631640
end
632641
end
642+
end
633643

634644
while length(calls) > 0
635645
inst = pop!(calls)

src/jlrt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ function load_if_mixed(oval::OT, val::VT) where {OT, VT}
326326
end
327327
end
328328

329-
function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value))
329+
function val_from_byref_if_mixed(B::LLVM.IRBuilder, gutils::GradientUtils, @nospecialize(oval::LLVM.Value), @nospecialize(val::LLVM.Value))::LLVM.Value
330330
world = enzyme_extract_world(LLVM.parent(position(B)))
331331
legal, TT, _ = abs_typeof(oval)
332332
if !legal
@@ -374,7 +374,7 @@ function ref_if_mixed(val::VT) where {VT}
374374
end
375375
end
376376

377-
function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))
377+
function byref_from_val_if_mixed(B::LLVM.IRBuilder, @nospecialize(val::LLVM.Value))::LLVM.Value
378378
world = enzyme_extract_world(LLVM.parent(position(B)))
379379
legal, TT, _ = abs_typeof(val)
380380
if !legal

src/rules/allocrules.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
2-
function array_inner(::Type{<:Array{T}}) where {T}
3-
return T
4-
end
1+
@inline LLT_ALIGN(x::Int, sz::Int) = (((x) + (sz) - 1) & ~((sz) - 1))
52
function array_shadow_handler(
63
B::LLVM.API.LLVMBuilderRef,
74
OrigCI::LLVM.API.LLVMValueRef,
@@ -52,8 +49,6 @@ function array_shadow_handler(
5249

5350
isunion = typ isa Union
5451

55-
LLT_ALIGN(x, sz) = (((x) + (sz) - 1) & ~((sz) - 1))
56-
5752
if !isunboxed
5853
elsz = sizeof(Ptr{Cvoid})
5954
al = elsz

0 commit comments

Comments
 (0)