Skip to content

Commit f013866

Browse files
authored
1.12: Fix absint inlining (#2701)
* 1.12: Fix absint inlining * More future * fix * more fix * fix
1 parent cf1f006 commit f013866

File tree

1 file changed

+59
-53
lines changed

1 file changed

+59
-53
lines changed

src/compiler/interpreter.jl

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -351,47 +351,13 @@ end
351351

352352
import .EnzymeRules: FwdConfig, RevConfig, Annotation
353353
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
354-
function Core.Compiler.abstract_call_gf_by_type(
355-
@nospecialize(interp::EnzymeInterpreter),
356-
@nospecialize(f),
357-
arginfo::ArgInfo,
358-
si::StmtInfo,
359-
@nospecialize(atype),
360-
sv::AbsIntState,
361-
max_methods::Int,
362-
)
363-
364-
ret = @invoke Core.Compiler.abstract_call_gf_by_type(
365-
interp::AbstractInterpreter,
366-
f::Any,
367-
arginfo::ArgInfo,
368-
si::StmtInfo,
369-
atype::Any,
370-
sv::AbsIntState,
371-
max_methods::Int,
372-
)
373-
if isdefined(Core.Compiler, :Future) # if stackless inference
374-
return Core.Compiler.Future{Core.Compiler.CallMeta}(ret, interp, sv) do ret, interp, sv
375-
callinfo = ret.info
376-
specTypes = simplify_kw(atype)
377-
378-
if is_primitive_func(specTypes)
379-
callinfo = NoInlineCallInfo(callinfo, atype, :primitive)
380-
elseif is_alwaysinline_func(specTypes)
381-
callinfo = AlwaysInlineCallInfo(callinfo, atype)
382-
else
383-
method_table = Core.Compiler.method_table(interp)
384-
if interp.inactive_rules && EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
385-
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
386-
elseif interp.forward_rules && EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
387-
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
388-
elseif interp.reverse_rules && EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
389-
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
390-
end
391-
end
392-
return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
393-
end
394-
end
354+
355+
struct FutureCallinfoByType
356+
atype::Any
357+
end
358+
359+
@inline function (closure::FutureCallinfoByType)(ret::Core.Compiler.CallMeta, @nospecialize(interp::AbstractInterpreter), sv::AbsIntState)
360+
atype = closure.atype
395361
callinfo = ret.info
396362
specTypes = simplify_kw(atype)
397363

@@ -416,6 +382,34 @@ function Core.Compiler.abstract_call_gf_by_type(
416382
end
417383
end
418384

385+
function Core.Compiler.abstract_call_gf_by_type(
386+
@nospecialize(interp::EnzymeInterpreter),
387+
@nospecialize(f),
388+
arginfo::ArgInfo,
389+
si::StmtInfo,
390+
@nospecialize(atype),
391+
sv::AbsIntState,
392+
max_methods::Int,
393+
)
394+
395+
ret = @invoke Core.Compiler.abstract_call_gf_by_type(
396+
interp::AbstractInterpreter,
397+
f::Any,
398+
arginfo::ArgInfo,
399+
si::StmtInfo,
400+
atype::Any,
401+
sv::AbsIntState,
402+
max_methods::Int,
403+
)
404+
405+
if isdefined(Core.Compiler, :Future) # if stackless inference
406+
return Core.Compiler.Future{Core.Compiler.CallMeta}(FutureCallinfoByType(atype), ret, interp, sv)
407+
end
408+
409+
return FutureCallinfoByType(atype)(ret, interp, sv)
410+
end
411+
412+
419413
let # overload `inlining_policy`
420414
@static if VERSION v"1.11.0-DEV.879"
421415
sigs_ex = :(
@@ -481,10 +475,12 @@ let # overload `inlining_policy`
481475
@assert info.kind === :rrule
482476
@safe_debug "Blocking inlining due to rrule" info.tt
483477
end
484-
return nothing
478+
479+
return false
485480
elseif info isa AlwaysInlineCallInfo
486481
@safe_debug "Forcing inlining for primitive func" info.tt
487-
return src
482+
483+
return true
488484
end
489485
return @invoke Core.Compiler.src_inlining_policy($(args_ex.args...))
490486
end
@@ -1063,8 +1059,10 @@ function abstract_call_known(
10631059
if length(argtypes) != 1
10641060
@static if VERSION < v"1.11.0-"
10651061
return CallMeta(Union{}, Effects(), NoCallInfo())
1066-
else
1062+
elseif VERSION < v"1.12.0-"
10671063
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
1064+
else
1065+
return Core.Compiler.Future{Core.Compiler.CallMeta}(CallMeta(Union{}, Union{}, Effects(), NoCallInfo()))
10681066
end
10691067
end
10701068
@static if VERSION < v"1.11.0-"
@@ -1073,28 +1071,35 @@ function abstract_call_known(
10731071
Core.Compiler.EFFECTS_TOTAL,
10741072
MethodResultPure(),
10751073
)
1076-
else
1074+
elseif VERSION < v"1.12.0-"
10771075
return CallMeta(
10781076
Core.Const(true),
10791077
Union{},
10801078
Core.Compiler.EFFECTS_TOTAL,
10811079
MethodResultPure(),
10821080
)
1081+
else
1082+
return Core.Compiler.Future{Core.Compiler.CallMeta}(CallMeta(
1083+
Core.Const(true),
1084+
Union{},
1085+
Core.Compiler.EFFECTS_TOTAL,
1086+
MethodResultPure(),
1087+
))
10831088
end
10841089
end
10851090

10861091
if interp.broadcast_rewrite
10871092
if f === Base.materialize && length(argtypes) == 2
10881093
bcty = widenconst(argtypes[2])
1089-
if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && bc_or_array_or_number_ty(bcty) && has_array(bcty)
1090-
ElType = ty_broadcast_getindex_eltype(interp, bcty)
1091-
if ElType !== Union{} && Base.isconcretetype(ElType)
1092-
fn2 = Enzyme.Compiler.Interpreter.OverrideBCMaterialize{ElType}()
1094+
if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && bc_or_array_or_number_ty(bcty) && has_array(bcty)
1095+
ElType = ty_broadcast_getindex_eltype(interp, bcty)
1096+
if ElType !== Union{} && Base.isconcretetype(ElType)
1097+
fn2 = Enzyme.Compiler.Interpreter.OverrideBCMaterialize{ElType}()
10931098
arginfo2 = ArgInfo(
1094-
fargs isa Nothing ? nothing :
1095-
[:(fn2), fargs[2:end]...],
1096-
[Core.Const(fn2), argtypes[2:end]...],
1099+
fargs isa Nothing ? nothing : [:(fn2), fargs[2:end]...],
1100+
[Core.Const(fn2), argtypes[2:end]...],
10971101
)
1102+
10981103
return Base.@invoke abstract_call_known(
10991104
interp::AbstractInterpreter,
11001105
fn2::Any,
@@ -1103,7 +1108,7 @@ function abstract_call_known(
11031108
sv::AbsIntState,
11041109
max_methods::Int,
11051110
)
1106-
end
1111+
end
11071112
end
11081113
end
11091114

@@ -1119,6 +1124,7 @@ function abstract_call_known(
11191124
[:(Enzyme.Compiler.Interpreter.override_bc_copyto!), fargs[2:end]...],
11201125
[Core.Const(Enzyme.Compiler.Interpreter.override_bc_copyto!), argtypes[2:end]...],
11211126
)
1127+
11221128
return Base.@invoke abstract_call_known(
11231129
interp::AbstractInterpreter,
11241130
Enzyme.Compiler.Interpreter.override_bc_copyto!::Any,

0 commit comments

Comments
 (0)