From 1070f9436b619edfa595f577a11d7cf2826567b2 Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki
Date: Fri, 14 Jan 2022 05:21:10 +0900
Subject: [PATCH 1/3] optimizer: Julia-level escape analysis

This commit ports [EscapeAnalysis.jl](https://github.com/aviatesk/EscapeAnalysis.jl)
into Julia base. You can find the documentation of this escape analysis at
[this GitHub page](https://aviatesk.github.io/EscapeAnalysis.jl/dev/)[^1].

[^1]: The same documentation will be included in Julia's developer
documentation by this commit.

This escape analysis will hopefully be an enabling technology for various
memory-related optimizations in Julia's high-level compilation pipeline.
Possible target optimizations include alias-aware SROA (#43888), array SROA
(#43909), the `mutating_arrayfreeze` optimization (#42465), stack allocation
of mutables, finalizer elision, and so on[^2].

[^2]: It would also be interesting if LLVM-level optimizations could consume
the IPO information derived by this escape analysis to broaden optimization
possibilities.

The primary motivation for porting EA in this PR is to check its impact on
latency as well as to get feedback from a broader range of developers. The
plan is to first introduce EA in this commit, and then merge the dependent
PRs built on top of it, like #43888, #43909 and #42465.

This commit simply defines and runs EA inside the Julia base compiler and
enables the existing test suite with it. In this commit, we just run EA
before inlining to generate the IPO cache. In the dependent PRs, EA will be
invoked again after inlining to be used for various local optimizations.
---
 base/boot.jl                                  |   60 +-
 base/compiler/bootstrap.jl                    |   10 +-
 base/compiler/compiler.jl                     |    2 +
 base/compiler/optimize.jl                     |  103 +-
 .../ssair/EscapeAnalysis/EscapeAnalysis.jl    | 1909 ++++++++++++++
 .../ssair/EscapeAnalysis/disjoint_set.jl      |  143 ++
 .../ssair/EscapeAnalysis/interprocedural.jl   |  151 ++
 base/compiler/ssair/driver.jl                 |    6 +-
 base/compiler/tfuncs.jl                       |    2 +-
 base/compiler/typeinfer.jl                    |    2 +-
 base/compiler/types.jl                        |   12 +-
 base/compiler/utilities.jl                    |    4 +
 doc/make.jl                                   |    1 +
 doc/src/devdocs/EscapeAnalysis.md             |  363 +++
 doc/src/devdocs/llvm.md                       |    2 +-
 src/dump.c                                    |    4 +
 src/gf.c                                      |   19 +-
 src/jltypes.c                                 |   12 +-
 src/julia.h                                   |    1 +
 test/choosetests.jl                           |    5 +-
 test/compiler/EscapeAnalysis/EAUtils.jl       |  385 +++
 .../EscapeAnalysis/interprocedural.jl         |  264 ++
 test/compiler/EscapeAnalysis/local.jl         | 2203 +++++++++++++++++
 test/compiler/EscapeAnalysis/setup.jl         |   72 +
 24 files changed, 5648 insertions(+), 87 deletions(-)
 create mode 100644 base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
 create mode 100644 base/compiler/ssair/EscapeAnalysis/disjoint_set.jl
 create mode 100644 base/compiler/ssair/EscapeAnalysis/interprocedural.jl
 create mode 100644 doc/src/devdocs/EscapeAnalysis.md
 create mode 100644 test/compiler/EscapeAnalysis/EAUtils.jl
 create mode 100644 test/compiler/EscapeAnalysis/interprocedural.jl
 create mode 100644 test/compiler/EscapeAnalysis/local.jl
 create mode 100644 test/compiler/EscapeAnalysis/setup.jl

diff --git a/base/boot.jl b/base/boot.jl
index ecc037407685e..290a98cbf2bbd 100644
--- a/base/boot.jl
+++ b/base/boot.jl
@@ -401,33 +401,39 @@ _new(:QuoteNode, :Any)
 _new(:SSAValue, :Int)
 _new(:Argument, :Int)
 _new(:ReturnNode, :Any)
-eval(Core, :(ReturnNode() = $(Expr(:new, :ReturnNode)))) # unassigned val indicates unreachable
-eval(Core, :(GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))))
-eval(Core, :(LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))))
-eval(Core, :(LineNumberNode(l::Int, @nospecialize(f)) = $(Expr(:new, :LineNumberNode, :l, :f)))) -LineNumberNode(l::Int, f::String) = LineNumberNode(l, Symbol(f)) -eval(Core, :(GlobalRef(m::Module, s::Symbol) = $(Expr(:new, :GlobalRef, :m, :s)))) -eval(Core, :(SlotNumber(n::Int) = $(Expr(:new, :SlotNumber, :n)))) -eval(Core, :(TypedSlot(n::Int, @nospecialize(t)) = $(Expr(:new, :TypedSlot, :n, :t)))) -eval(Core, :(PhiNode(edges::Array{Int32, 1}, values::Array{Any, 1}) = $(Expr(:new, :PhiNode, :edges, :values)))) -eval(Core, :(PiNode(val, typ) = $(Expr(:new, :PiNode, :val, :typ)))) -eval(Core, :(PhiCNode(values::Array{Any, 1}) = $(Expr(:new, :PhiCNode, :values)))) -eval(Core, :(UpsilonNode(val) = $(Expr(:new, :UpsilonNode, :val)))) -eval(Core, :(UpsilonNode() = $(Expr(:new, :UpsilonNode)))) -eval(Core, :(LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int, inlined_at::Int) = - $(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at)))) -eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const), - @nospecialize(inferred), const_flags::Int32, - min_world::UInt, max_world::UInt, ipo_effects::UInt8, effects::UInt8, - relocatability::UInt8) = - ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt, UInt8, UInt8, UInt8), - mi, rettype, inferred_const, inferred, const_flags, min_world, max_world, ipo_effects, effects, relocatability))) -eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)))) -eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields)))) -eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source)))) -eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype)))) -eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = - $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)))) +eval(Core, quote + ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable + GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest)) + LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing)) + function LineNumberNode(l::Int, @nospecialize(f)) + isa(f, String) && (f = Symbol(f)) + return $(Expr(:new, :LineNumberNode, :l, :f)) + end + LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int, inlined_at::Int) = + $(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at)) + GlobalRef(m::Module, s::Symbol) = $(Expr(:new, :GlobalRef, :m, :s)) + SlotNumber(n::Int) = $(Expr(:new, :SlotNumber, :n)) + TypedSlot(n::Int, @nospecialize(t)) = $(Expr(:new, :TypedSlot, :n, :t)) + PhiNode(edges::Array{Int32, 1}, values::Array{Any, 1}) = $(Expr(:new, :PhiNode, :edges, :values)) + PiNode(@nospecialize(val), @nospecialize(typ)) = $(Expr(:new, :PiNode, :val, :typ)) + PhiCNode(values::Array{Any, 1}) = $(Expr(:new, :PhiCNode, :values)) + UpsilonNode(@nospecialize(val)) = $(Expr(:new, :UpsilonNode, :val)) + UpsilonNode() = $(Expr(:new, :UpsilonNode)) + function CodeInstance( + mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const), + @nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt, + ipo_effects::UInt8, effects::UInt8, 
@nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#), + relocatability::UInt8) + return ccall(:jl_new_codeinst, Ref{CodeInstance}, + (Any, Any, Any, Any, Int32, UInt, UInt, UInt8, UInt8, Any, UInt8), + mi, rettype, inferred_const, inferred, const_flags, min_world, max_world, ipo_effects, effects, argescapes, relocatability) + end + Const(@nospecialize(v)) = $(Expr(:new, :Const, :v)) + PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields)) + PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source)) + InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype)) + MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers)) +end) Module(name::Symbol=:anonymous, std_imports::Bool=true, default_names::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool, Bool), name, std_imports, default_names) diff --git a/base/compiler/bootstrap.jl b/base/compiler/bootstrap.jl index 2517b181d2804..1989d8aa57393 100644 --- a/base/compiler/bootstrap.jl +++ b/base/compiler/bootstrap.jl @@ -11,10 +11,11 @@ let world = get_world_counter() interp = NativeInterpreter(world) + analyze_escapes_tt = Tuple{typeof(analyze_escapes), IRCode, Int, Bool, typeof(get_escape_cache(code_cache(interp)))} fs = Any[ # we first create caches for the optimizer, because they contain many loop constructions # and they're better to not run in interpreter even during bootstrapping - run_passes, + analyze_escapes_tt, run_passes, # then we create caches for inference entries typeinf_ext, typeinf, typeinf_edge, ] @@ -32,7 +33,12 @@ let end starttime = time() for f in fs - for m in _methods_by_ftype(Tuple{typeof(f), Vararg{Any}}, 10, typemax(UInt)) + if isa(f, DataType) && f.name === typename(Tuple) + tt = f + else + tt = Tuple{typeof(f), Vararg{Any}} + end + for m in _methods_by_ftype(tt, 10, typemax(UInt)) # remove any TypeVars from the intersection typ = Any[m.spec_types.parameters...] for i = 1:length(typ) diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 41e045773fb06..d13fb9e21b483 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -97,6 +97,8 @@ ntuple(f, n) = (Any[f(i) for i = 1:n]...,) # core docsystem include("docs/core.jl") +import Core.Compiler.CoreDocs +Core.atdoc!(CoreDocs.docm) # sorting function sort end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 58f20b5ef2a0c..635e53a9e1f1d 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -1,5 +1,35 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +############# +# constants # +############# + +# The slot has uses that are not statically dominated by any assignment +# This is implied by `SLOT_USEDUNDEF`. +# If this is not set, all the uses are (statically) dominated by the defs. +# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. 
+const SLOT_STATICUNDEF  = 1 # slot might be used before it is defined (structurally)
+const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once
+const SLOT_USEDUNDEF    = 32 # slot has uses that might raise UndefVarError
+# const SLOT_CALLED     = 64
+
+# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c
+
+const IR_FLAG_NULL        = 0x00
+# This statement is marked as @inbounds by user.
+# If replaced by inlining, any contained boundschecks may be removed.
+const IR_FLAG_INBOUNDS    = 0x01 << 0
+# This statement is marked as @inline by user
+const IR_FLAG_INLINE      = 0x01 << 1
+# This statement is marked as @noinline by user
+const IR_FLAG_NOINLINE    = 0x01 << 2
+const IR_FLAG_THROW_BLOCK = 0x01 << 3
+# This statement may be removed if its result is unused. In particular it must
+# thus be both pure and effect free.
+const IR_FLAG_EFFECT_FREE = 0x01 << 4
+
+const TOP_TUPLE = GlobalRef(Core, :tuple)
+
 #####################
 # OptimizationState #
 #####################
@@ -21,10 +51,10 @@ function push!(et::EdgeTracker, ci::CodeInstance)
     push!(et, ci.def)
 end
 
-struct InliningState{S <: Union{EdgeTracker, Nothing}, T, I<:AbstractInterpreter}
+struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInterpreter}
     params::OptimizationParams
     et::S
-    mi_cache::T
+    mi_cache::MICache # TODO move this to `OptimizationState` (as used by EscapeAnalysis as well)
     interp::I
 end
 
@@ -52,7 +82,34 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f
     return nothing
 end
 
+function argextype end # imported by EscapeAnalysis
+function stmt_effect_free end # imported by EscapeAnalysis
+function alloc_array_ndims end # imported by EscapeAnalysis
 include("compiler/ssair/driver.jl")
+using .EscapeAnalysis
+import .EscapeAnalysis: EscapeState, ArgEscapeCache, is_ipo_profitable
+
+"""
+    cache_escapes!(caller::InferenceResult, estate::EscapeState)
+
+Transforms escape information of call arguments of `caller`,
+and then caches it into a global cache for later interprocedural propagation.
+"""
+cache_escapes!(caller::InferenceResult, estate::EscapeState) =
+    caller.argescapes = ArgEscapeCache(estate)
+
+function get_escape_cache(mi_cache::MICache) where MICache
+    return function (linfo::Union{InferenceResult,MethodInstance})
+        if isa(linfo, InferenceResult)
+            argescapes = linfo.argescapes
+        else
+            codeinst = get(mi_cache, linfo, nothing)
+            isa(codeinst, CodeInstance) || return nothing
+            argescapes = codeinst.argescapes
+        end
+        return argescapes !== nothing ? argescapes::ArgEscapeCache : nothing
+    end
+end
 
 mutable struct OptimizationState
     linfo::MethodInstance
@@ -121,36 +178,6 @@ function ir_to_codeinf!(opt::OptimizationState)
     return src
 end
 
-#############
-# constants #
-#############
-
-# The slot has uses that are not statically dominated by any assignment
-# This is implied by `SLOT_USEDUNDEF`.
-# If this is not set, all the uses are (statically) dominated by the defs.
-# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA.
-const SLOT_STATICUNDEF  = 1 # slot might be used before it is defined (structurally)
-const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once
-const SLOT_USEDUNDEF    = 32 # slot has uses that might raise UndefVarError
-# const SLOT_CALLED     = 64
-
-# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c
-
-const IR_FLAG_NULL        = 0x00
-# This statement is marked as @inbounds by user.
-# Ff replaced by inlining, any contained boundschecks may be removed.
-const IR_FLAG_INBOUNDS = 0x01 << 0 -# This statement is marked as @inline by user -const IR_FLAG_INLINE = 0x01 << 1 -# This statement is marked as @noinline by user -const IR_FLAG_NOINLINE = 0x01 << 2 -const IR_FLAG_THROW_BLOCK = 0x01 << 3 -# This statement may be removed if its result is unused. In particular it must -# thus be both pure and effect free. -const IR_FLAG_EFFECT_FREE = 0x01 << 4 - -const TOP_TUPLE = GlobalRef(Core, :tuple) - ######### # logic # ######### @@ -503,15 +530,23 @@ end # run the optimization work function optimize(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, caller::InferenceResult) - @timeit "optimizer" ir = run_passes(opt.src, opt) + @timeit "optimizer" ir = run_passes(opt.src, opt, caller) return finish(interp, opt, params, ir, caller) end -function run_passes(ci::CodeInfo, sv::OptimizationState) +function run_passes(ci::CodeInfo, sv::OptimizationState, caller::InferenceResult) @timeit "convert" ir = convert_to_ircode(ci, sv) @timeit "slot2reg" ir = slot2reg(ir, ci, sv) # TODO: Domsorting can produce an updated domtree - no need to recompute here @timeit "compact 1" ir = compact!(ir) + nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end + get_escape_cache = (@__MODULE__).get_escape_cache(sv.inlining.mi_cache) + if is_ipo_profitable(ir, nargs) + @timeit "IPO EA" begin + state = analyze_escapes(ir, nargs, false, get_escape_cache) + cache_escapes!(caller, state) + end + end @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl new file mode 100644 index 0000000000000..218afaefa431f --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -0,0 +1,1909 @@ +baremodule EscapeAnalysis + +export + analyze_escapes, + getaliases, + isaliased, + has_no_escape, + has_arg_escape, + has_return_escape, + has_thrown_escape, + has_all_escape + +const _TOP_MOD = ccall(:jl_base_relative_to, Any, (Any,), EscapeAnalysis)::Module + +# imports +import ._TOP_MOD: ==, getindex, setindex! +# usings +import Core: + MethodInstance, Const, Argument, SSAValue, PiNode, PhiNode, UpsilonNode, PhiCNode, + ReturnNode, GotoNode, GotoIfNot, SimpleVector, MethodMatch, CodeInstance, + sizeof, ifelse, arrayset, arrayref, arraysize +import ._TOP_MOD: # Base definitions + @__MODULE__, @eval, @assert, @specialize, @nospecialize, @inbounds, @inline, @noinline, + @label, @goto, !, !==, !=, ≠, +, -, *, ≤, <, ≥, >, &, |, <<, include, error, missing, copy, + Vector, BitSet, IdDict, IdSet, UnitRange, Csize_t, Callable, ∪, ⊆, ∩, :, ∈, ∉, =>, + in, length, get, first, last, haskey, keys, get!, isempty, isassigned, + pop!, push!, pushfirst!, empty!, delete!, max, min, enumerate, unwrap_unionall, + ismutabletype +import Core.Compiler: # Core.Compiler specific definitions + Bottom, InferenceResult, IRCode, IR_FLAG_EFFECT_FREE, + isbitstype, isexpr, is_meta_expr_head, println, widenconst, argextype, singleton_type, + fieldcount_noerror, try_compute_field, try_compute_fieldidx, hasintersect, ⊑, + intrinsic_nothrow, array_builtin_common_typecheck, arrayset_typecheck, + setfield!_nothrow, alloc_array_ndims, stmt_effect_free, check_effect_free! 
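
To make the intended usage of these exported queries concrete, here is a rough sketch of how the analysis can be driven through the `code_escapes` helper that this patch adds under `test/compiler/EscapeAnalysis/EAUtils.jl`. The exact helper interface, the `result.state` field, and the `SafeRef` wrapper below are assumptions based on the accompanying test suite, not part of the diff itself:

```julia
# Illustrative sketch only; assumes the EAUtils helper from this patch's test suite.
include("test/compiler/EscapeAnalysis/EAUtils.jl")
using .EAUtils: code_escapes

mutable struct SafeRef{T} # hypothetical mutable wrapper used for illustration
    x::T
end

result = code_escapes((String,)) do s
    obj = SafeRef(s) # allocation whose escapes we want to analyze
    return obj.x     # the field load aliases `s` with the returned value
end

# query the per-value lattice elements recorded in the returned EscapeState
arginfo = result.state[Core.Argument(2)] # EscapeInfo imposed on `s` (Argument(1) is the closure)
@assert has_return_escape(arginfo)  # `s` is reachable from the caller via the return value
@assert !has_all_escape(arginfo)    # ...but it does not escape anywhere else
```
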
+
+if _TOP_MOD === Core.Compiler
+    include(@__MODULE__, "compiler/ssair/EscapeAnalysis/disjoint_set.jl")
+else
+    include(@__MODULE__, "disjoint_set.jl")
+end
+
+const AInfo = IdSet{Any}
+const LivenessSet = BitSet
+
+"""
+    x::EscapeInfo
+
+A lattice for escape information, which holds the following properties:
+- `x.Analyzed::Bool`: not formally part of the lattice, only indicates whether `x` has been analyzed or not
+- `x.ReturnEscape::Bool`: indicates `x` can escape to the caller via return
+- `x.ThrownEscape::BitSet`: records SSA statement numbers where `x` can be thrown as an exception:
+  * `isempty(x.ThrownEscape)`: `x` will never be thrown in this call frame (the bottom)
+  * `pc ∈ x.ThrownEscape`: `x` may be thrown at the SSA statement at `pc`
+  * `-1 ∈ x.ThrownEscape`: `x` may be thrown at arbitrary points of this call frame (the top)
+  This information will be used by `escape_exception!` to propagate potential escapes via exceptions.
+- `x.AliasInfo::Union{Bool,IndexableFields,IndexableElements,Unindexable}`: maintains all possible values
+  that can be aliased to fields or array elements of `x`:
+  * `x.AliasInfo === false` indicates the fields/elements of `x` aren't analyzed yet
+  * `x.AliasInfo === true` indicates the fields/elements of `x` can't be analyzed,
+    e.g. the type of `x` is not known or is not concrete and thus its fields/elements
+    can't be known precisely
+  * `x.AliasInfo::IndexableFields` records all the possible values that can be aliased to fields of object `x` with precise index information
+  * `x.AliasInfo::IndexableElements` records all the possible values that can be aliased to elements of array `x` with precise index information
+  * `x.AliasInfo::Unindexable` records all the possible values that can be aliased to fields/elements of `x` without precise index information
+- `x.Liveness::BitSet`: records SSA statement numbers where `x` should be live, e.g.
+  to be used as a call argument, to be returned to a caller, or preserved for `:foreigncall`:
+  * `isempty(x.Liveness)`: `x` is never used in this call frame (the bottom)
+  * `0 ∈ x.Liveness` also has the special meaning that it's a call argument of the currently
+    analyzed call frame (and thus it's visible from the caller immediately).
+  * `pc ∈ x.Liveness`: `x` may be used at the SSA statement at `pc`
+  * `-1 ∈ x.Liveness`: `x` may be used at arbitrary points of this call frame (the top)
+
+There are utility constructors to create common `EscapeInfo`s, e.g.,
+- `NoEscape()`: the bottom(-like) element of this lattice, meaning it won't escape anywhere
+- `AllEscape()`: the topmost element of this lattice, meaning it will escape everywhere
+
+`analyze_escapes` will transition these elements from the bottom to the top,
+in the same direction as Julia's native type inference routine.
+An abstract state will be initialized with the bottom(-like) elements: +- the call arguments are initialized as `ArgEscape()`, whose `Liveness` property includes `0` + to indicate that it is passed as a call argument and visible from a caller immediately +- the other states are initialized as `NotAnalyzed()`, which is a special lattice element that + is slightly lower than `NoEscape`, but at the same time doesn't represent any meaning + other than it's not analyzed yet (thus it's not formally part of the lattice) +""" +struct EscapeInfo + Analyzed::Bool + ReturnEscape::Bool + ThrownEscape::LivenessSet + AliasInfo #::Union{IndexableFields,IndexableElements,Unindexable,Bool} + Liveness::LivenessSet + + function EscapeInfo( + Analyzed::Bool, + ReturnEscape::Bool, + ThrownEscape::LivenessSet, + AliasInfo#=::Union{IndexableFields,IndexableElements,Unindexable,Bool}=#, + Liveness::LivenessSet, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end + function EscapeInfo( + x::EscapeInfo, + # non-concrete fields should be passed as default arguments + # in order to avoid allocating non-concrete `NamedTuple`s + AliasInfo#=::Union{IndexableFields,IndexableElements,Unindexable,Bool}=# = x.AliasInfo; + Analyzed::Bool = x.Analyzed, + ReturnEscape::Bool = x.ReturnEscape, + ThrownEscape::LivenessSet = x.ThrownEscape, + Liveness::LivenessSet = x.Liveness, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end +end + +# precomputed default values in order to eliminate computations at each callsite + +const BOT_THROWN_ESCAPE = LivenessSet() +# NOTE the lattice operations should try to avoid actual set computations on this top value, +# and e.g. LivenessSet(0:1000000) should also work without incurring excessive computations +const TOP_THROWN_ESCAPE = LivenessSet(-1) + +const BOT_LIVENESS = LivenessSet() +# NOTE the lattice operations should try to avoid actual set computations on this top value, +# and e.g. 
LivenessSet(0:1000000) should also work without incurring excessive computations +const TOP_LIVENESS = LivenessSet(-1:0) +const ARG_LIVENESS = LivenessSet(0) + +# the constructors +NotAnalyzed() = EscapeInfo(false, false, BOT_THROWN_ESCAPE, false, BOT_LIVENESS) # not formally part of the lattice +NoEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, false, BOT_LIVENESS) +ArgEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, true, ARG_LIVENESS) +ReturnEscape(pc::Int) = EscapeInfo(true, true, BOT_THROWN_ESCAPE, false, LivenessSet(pc)) +AllReturnEscape() = EscapeInfo(true, true, BOT_THROWN_ESCAPE, false, TOP_LIVENESS) +ThrownEscape(pc::Int) = EscapeInfo(true, false, LivenessSet(pc), false, BOT_LIVENESS) +AllEscape() = EscapeInfo(true, true, TOP_THROWN_ESCAPE, true, TOP_LIVENESS) + +const ⊥, ⊤ = NotAnalyzed(), AllEscape() + +# Convenience names for some ⊑ₑ queries +has_no_escape(x::EscapeInfo) = !x.ReturnEscape && isempty(x.ThrownEscape) && 0 ∉ x.Liveness +has_arg_escape(x::EscapeInfo) = 0 in x.Liveness +has_return_escape(x::EscapeInfo) = x.ReturnEscape +has_return_escape(x::EscapeInfo, pc::Int) = x.ReturnEscape && (-1 ∈ x.Liveness || pc in x.Liveness) +has_thrown_escape(x::EscapeInfo) = !isempty(x.ThrownEscape) +has_thrown_escape(x::EscapeInfo, pc::Int) = -1 ∈ x.ThrownEscape || pc in x.ThrownEscape +has_all_escape(x::EscapeInfo) = ⊤ ⊑ₑ x + +# utility lattice constructors +ignore_argescape(x::EscapeInfo) = EscapeInfo(x; Liveness=delete!(copy(x.Liveness), 0)) +ignore_thrownescapes(x::EscapeInfo) = EscapeInfo(x; ThrownEscape=BOT_THROWN_ESCAPE) +ignore_aliasinfo(x::EscapeInfo) = EscapeInfo(x, false) +ignore_liveness(x::EscapeInfo) = EscapeInfo(x; Liveness=BOT_LIVENESS) + +# AliasInfo +struct IndexableFields + infos::Vector{AInfo} +end +struct IndexableElements + infos::IdDict{Int,AInfo} +end +struct Unindexable + info::AInfo +end +IndexableFields(nfields::Int) = IndexableFields(AInfo[AInfo() for _ in 1:nfields]) +Unindexable() = Unindexable(AInfo()) + +merge_to_unindexable(AliasInfo::IndexableFields) = Unindexable(merge_to_unindexable(AliasInfo.infos)) +merge_to_unindexable(AliasInfo::Unindexable, AliasInfos::IndexableFields) = Unindexable(merge_to_unindexable(AliasInfo.info, AliasInfos.infos)) +merge_to_unindexable(infos::Vector{AInfo}) = merge_to_unindexable(AInfo(), infos) +function merge_to_unindexable(info::AInfo, infos::Vector{AInfo}) + for i = 1:length(infos) + info = info ∪ infos[i] + end + return info +end +merge_to_unindexable(AliasInfo::IndexableElements) = Unindexable(merge_to_unindexable(AliasInfo.infos)) +merge_to_unindexable(AliasInfo::Unindexable, AliasInfos::IndexableElements) = Unindexable(merge_to_unindexable(AliasInfo.info, AliasInfos.infos)) +merge_to_unindexable(infos::IdDict{Int,AInfo}) = merge_to_unindexable(AInfo(), infos) +function merge_to_unindexable(info::AInfo, infos::IdDict{Int,AInfo}) + for idx in keys(infos) + info = info ∪ infos[idx] + end + return info +end + +# we need to make sure this `==` operator corresponds to lattice equality rather than object equality, +# otherwise `propagate_changes` can't detect the convergence +x::EscapeInfo == y::EscapeInfo = begin + # fast pass: better to avoid top comparison + x === y && return true + x.Analyzed === y.Analyzed || return false + x.ReturnEscape === y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt === TOP_THROWN_ESCAPE || return false + elseif yt === TOP_THROWN_ESCAPE + return false # x.ThrownEscape === TOP_THROWN_ESCAPE + else + xt == yt || return false 
+    end
+    xa, ya = x.AliasInfo, y.AliasInfo
+    if isa(xa, Bool)
+        xa === ya || return false
+    elseif isa(xa, IndexableFields)
+        isa(ya, IndexableFields) || return false
+        xa.infos == ya.infos || return false
+    elseif isa(xa, IndexableElements)
+        isa(ya, IndexableElements) || return false
+        xa.infos == ya.infos || return false
+    else
+        xa = xa::Unindexable
+        isa(ya, Unindexable) || return false
+        xa.info == ya.info || return false
+    end
+    xl, yl = x.Liveness, y.Liveness
+    if xl === TOP_LIVENESS
+        yl === TOP_LIVENESS || return false
+    elseif yl === TOP_LIVENESS
+        return false # x.Liveness === TOP_LIVENESS
+    else
+        xl == yl || return false
+    end
+    return true
+end
+
+"""
+    x::EscapeInfo ⊑ₑ y::EscapeInfo -> Bool
+
+The non-strict partial order over `EscapeInfo`.
+"""
+x::EscapeInfo ⊑ₑ y::EscapeInfo = begin
+    # fast pass: better to avoid top comparison
+    if y === ⊤
+        return true
+    elseif x === ⊤
+        return false # return y === ⊤
+    elseif x === ⊥
+        return true
+    elseif y === ⊥
+        return false # return x === ⊥
+    end
+    x.Analyzed ≤ y.Analyzed || return false
+    x.ReturnEscape ≤ y.ReturnEscape || return false
+    xt, yt = x.ThrownEscape, y.ThrownEscape
+    if xt === TOP_THROWN_ESCAPE
+        yt !== TOP_THROWN_ESCAPE && return false
+    elseif yt !== TOP_THROWN_ESCAPE
+        xt ⊆ yt || return false
+    end
+    xa, ya = x.AliasInfo, y.AliasInfo
+    if isa(xa, Bool)
+        xa && ya !== true && return false
+    elseif isa(xa, IndexableFields)
+        if isa(ya, IndexableFields)
+            xinfos, yinfos = xa.infos, ya.infos
+            xn, yn = length(xinfos), length(yinfos)
+            xn > yn && return false
+            for i in 1:xn
+                xinfos[i] ⊆ yinfos[i] || return false
+            end
+        elseif isa(ya, IndexableElements)
+            return false
+        elseif isa(ya, Unindexable)
+            xinfos, yinfo = xa.infos, ya.info
+            for i = 1:length(xinfos)
+                xinfos[i] ⊆ yinfo || return false
+            end
+        else
+            ya === true || return false
+        end
+    elseif isa(xa, IndexableElements)
+        if isa(ya, IndexableElements)
+            xinfos, yinfos = xa.infos, ya.infos
+            keys(xinfos) ⊆ keys(yinfos) || return false
+            for idx in keys(xinfos)
+                xinfos[idx] ⊆ yinfos[idx] || return false
+            end
+        elseif isa(ya, IndexableFields)
+            return false
+        elseif isa(ya, Unindexable)
+            xinfos, yinfo = xa.infos, ya.info
+            for idx in keys(xinfos)
+                xinfos[idx] ⊆ yinfo || return false
+            end
+        else
+            ya === true || return false
+        end
+    else
+        xa = xa::Unindexable
+        if isa(ya, Unindexable)
+            xinfo, yinfo = xa.info, ya.info
+            xinfo ⊆ yinfo || return false
+        else
+            ya === true || return false
+        end
+    end
+    xl, yl = x.Liveness, y.Liveness
+    if xl === TOP_LIVENESS
+        yl !== TOP_LIVENESS && return false
+    elseif yl !== TOP_LIVENESS
+        xl ⊆ yl || return false
+    end
+    return true
+end
+
+"""
+    x::EscapeInfo ⊏ₑ y::EscapeInfo -> Bool
+
+The strict partial order over `EscapeInfo`.
+This is defined as the irreflexive kernel of `⊑ₑ`.
+"""
+x::EscapeInfo ⊏ₑ y::EscapeInfo = x ⊑ₑ y && !(y ⊑ₑ x)
+
+"""
+    x::EscapeInfo ⋤ₑ y::EscapeInfo -> Bool
+
+This order could be used as a slightly more efficient version of the strict order `⊏ₑ`,
+where we can safely assume `x ⊑ₑ y` holds.
+"""
+x::EscapeInfo ⋤ₑ y::EscapeInfo = !(y ⊑ₑ x)
+
+"""
+    x::EscapeInfo ⊔ₑ y::EscapeInfo -> EscapeInfo
+
+Computes the join of `x` and `y` in the partial order defined by `EscapeInfo`.
+""" +x::EscapeInfo ⊔ₑ y::EscapeInfo = begin + # fast pass: better to avoid top join + if x === ⊤ || y === ⊤ + return ⊤ + elseif x === ⊥ + return y + elseif y === ⊥ + return x + end + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE || yt === TOP_THROWN_ESCAPE + ThrownEscape = TOP_THROWN_ESCAPE + elseif xt === BOT_THROWN_ESCAPE + ThrownEscape = yt + elseif yt === BOT_THROWN_ESCAPE + ThrownEscape = xt + else + ThrownEscape = xt ∪ yt + end + AliasInfo = merge_alias_info(x.AliasInfo, y.AliasInfo) + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS || yl === TOP_LIVENESS + Liveness = TOP_LIVENESS + elseif xl === BOT_LIVENESS + Liveness = yl + elseif yl === BOT_LIVENESS + Liveness = xl + else + Liveness = xl ∪ yl + end + return EscapeInfo( + x.Analyzed | y.Analyzed, + x.ReturnEscape | y.ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) +end + +function merge_alias_info(@nospecialize(xa), @nospecialize(ya)) + if xa === true || ya === true + return true + elseif xa === false + return ya + elseif ya === false + return xa + elseif isa(xa, IndexableFields) + if isa(ya, IndexableFields) + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + nmax, nmin = max(xn, yn), min(xn, yn) + infos = Vector{AInfo}(undef, nmax) + for i in 1:nmax + if i > nmin + infos[i] = (xn > yn ? xinfos : yinfos)[i] + else + infos[i] = xinfos[i] ∪ yinfos[i] + end + end + return IndexableFields(infos) + elseif isa(ya, Unindexable) + xinfos, yinfo = xa.infos, ya.info + return merge_to_unindexable(ya, xa) + else + return true # handle conflicting case conservatively + end + elseif isa(xa, IndexableElements) + if isa(ya, IndexableElements) + xinfos, yinfos = xa.infos, ya.infos + infos = IdDict{Int,AInfo}() + for idx in keys(xinfos) + if !haskey(yinfos, idx) + infos[idx] = xinfos[idx] + else + infos[idx] = xinfos[idx] ∪ yinfos[idx] + end + end + for idx in keys(yinfos) + haskey(xinfos, idx) && continue # unioned already + infos[idx] = yinfos[idx] + end + return IndexableElements(infos) + elseif isa(ya, Unindexable) + return merge_to_unindexable(ya, xa) + else + return true # handle conflicting case conservatively + end + else + xa = xa::Unindexable + if isa(ya, IndexableFields) + return merge_to_unindexable(xa, ya) + elseif isa(ya, IndexableElements) + return merge_to_unindexable(xa, ya) + else + ya = ya::Unindexable + xinfo, yinfo = xa.info, ya.info + info = xinfo ∪ yinfo + return Unindexable(info) + end + end +end + +const AliasSet = IntDisjointSet{Int} + +const ArrayInfo = IdDict{Int,Vector{Int}} + +""" + estate::EscapeState + +Extended lattice that maps arguments and SSA values to escape information represented as [`EscapeInfo`](@ref). +Escape information imposed on SSA IR element `x` can be retrieved by `estate[x]`. +""" +struct EscapeState + escapes::Vector{EscapeInfo} + aliasset::AliasSet + nargs::Int + arrayinfo::Union{Nothing,ArrayInfo} +end +function EscapeState(nargs::Int, nstmts::Int, arrayinfo::Union{Nothing,ArrayInfo}) + escapes = EscapeInfo[ + 1 ≤ i ≤ nargs ? ArgEscape() : ⊥ for i in 1:(nargs+nstmts)] + aliasset = AliasSet(nargs+nstmts) + return EscapeState(escapes, aliasset, nargs, arrayinfo) +end +function getindex(estate::EscapeState, @nospecialize(x)) + xidx = iridx(x, estate) + return xidx === nothing ? 
nothing : estate.escapes[xidx] +end +function setindex!(estate::EscapeState, v::EscapeInfo, @nospecialize(x)) + xidx = iridx(x, estate) + if xidx !== nothing + estate.escapes[xidx] = v + end + return estate +end + +""" + iridx(x, estate::EscapeState) -> xidx::Union{Int,Nothing} + +Tries to convert analyzable IR element `x::Union{Argument,SSAValue}` to +its unique identifier number `xidx` that is valid in the analysis context of `estate`. +Returns `nothing` if `x` isn't maintained by `estate` and thus unanalyzable (e.g. `x::GlobalRef`). + +`irval` is the inverse function of `iridx` (not formally), i.e. +`irval(iridx(x::Union{Argument,SSAValue}, state), state) === x`. +""" +function iridx(@nospecialize(x), estate::EscapeState) + if isa(x, Argument) + xidx = x.n + @assert 1 ≤ xidx ≤ estate.nargs "invalid Argument" + elseif isa(x, SSAValue) + xidx = x.id + estate.nargs + else + return nothing + end + return xidx +end + +""" + irval(xidx::Int, estate::EscapeState) -> x::Union{Argument,SSAValue} + +Converts its unique identifier number `xidx` to the original IR element `x::Union{Argument,SSAValue}` +that is analyzable in the context of `estate`. + +`iridx` is the inverse function of `irval` (not formally), i.e. +`iridx(irval(xidx, state), state) === xidx`. +""" +function irval(xidx::Int, estate::EscapeState) + x = xidx > estate.nargs ? SSAValue(xidx-estate.nargs) : Argument(xidx) + return x +end + +function getaliases(x::Union{Argument,SSAValue}, estate::EscapeState) + xidx = iridx(x, estate) + aliases = getaliases(xidx, estate) + aliases === nothing && return nothing + return Union{Argument,SSAValue}[irval(aidx, estate) for aidx in aliases] +end +function getaliases(xidx::Int, estate::EscapeState) + aliasset = estate.aliasset + root = find_root!(aliasset, xidx) + if xidx ≠ root || aliasset.ranks[xidx] > 0 + # the size of this alias set containing `key` is larger than 1, + # collect the entire alias set + aliases = Int[] + for aidx in 1:length(aliasset.parents) + if aliasset.parents[aidx] == root + push!(aliases, aidx) + end + end + return aliases + else + return nothing + end +end + +isaliased(x::Union{Argument,SSAValue}, y::Union{Argument,SSAValue}, estate::EscapeState) = + isaliased(iridx(x, estate), iridx(y, estate), estate) +isaliased(xidx::Int, yidx::Int, estate::EscapeState) = + in_same_set(estate.aliasset, xidx, yidx) + +struct ArgEscapeInfo + EscapeBits::UInt8 +end +function ArgEscapeInfo(x::EscapeInfo) + x === ⊤ && return ArgEscapeInfo(ARG_ALL_ESCAPE) + EscapeBits = 0x00 + has_return_escape(x) && (EscapeBits |= ARG_RETURN_ESCAPE) + has_thrown_escape(x) && (EscapeBits |= ARG_THROWN_ESCAPE) + return ArgEscapeInfo(EscapeBits) +end + +const ARG_ALL_ESCAPE = 0x01 << 0 +const ARG_RETURN_ESCAPE = 0x01 << 1 +const ARG_THROWN_ESCAPE = 0x01 << 2 + +has_no_escape(x::ArgEscapeInfo) = !has_all_escape(x) && !has_return_escape(x) && !has_thrown_escape(x) +has_all_escape(x::ArgEscapeInfo) = x.EscapeBits & ARG_ALL_ESCAPE ≠ 0 +has_return_escape(x::ArgEscapeInfo) = x.EscapeBits & ARG_RETURN_ESCAPE ≠ 0 +has_thrown_escape(x::ArgEscapeInfo) = x.EscapeBits & ARG_THROWN_ESCAPE ≠ 0 + +struct ArgAliasing + aidx::Int + bidx::Int +end + +struct ArgEscapeCache + argescapes::Vector{ArgEscapeInfo} + argaliases::Vector{ArgAliasing} +end + +function ArgEscapeCache(estate::EscapeState) + nargs = estate.nargs + argescapes = Vector{ArgEscapeInfo}(undef, nargs) + argaliases = ArgAliasing[] + for i = 1:nargs + info = estate.escapes[i] + @assert info.AliasInfo === true + argescapes[i] = ArgEscapeInfo(info) + for j = 
(i+1):nargs
+            if isaliased(i, j, estate)
+                push!(argaliases, ArgAliasing(i, j))
+            end
+        end
+    end
+    return ArgEscapeCache(argescapes, argaliases)
+end
+
+"""
+    is_ipo_profitable(ir::IRCode, nargs::Int) -> Bool
+
+Heuristically checks if there is any profitability to run the escape analysis on `ir`
+and generate an IPO escape information cache. Specifically, this function examines
+whether any call argument is "interesting" in terms of its escapability.
+"""
+function is_ipo_profitable(ir::IRCode, nargs::Int)
+    for i = 1:nargs
+        t = unwrap_unionall(widenconst(ir.argtypes[i]))
+        t <: IO && return false # bail out for IO-related functions
+        is_ipo_profitable_type(t) && return true
+    end
+    return false
+end
+function is_ipo_profitable_type(@nospecialize t)
+    if isa(t, Union)
+        return is_ipo_profitable_type(t.a) && is_ipo_profitable_type(t.b)
+    end
+    (t === String || t === Symbol || t === Module || t === SimpleVector) && return false
+    return ismutabletype(t)
+end
+
+abstract type Change end
+struct EscapeChange <: Change
+    xidx::Int
+    xinfo::EscapeInfo
+end
+struct AliasChange <: Change
+    xidx::Int
+    yidx::Int
+end
+struct ArgAliasChange <: Change
+    xidx::Int
+    yidx::Int
+end
+struct LivenessChange <: Change
+    xidx::Int
+    livepc::Int
+end
+const Changes = Vector{Change}
+
+struct AnalysisState{T<:Callable}
+    ir::IRCode
+    estate::EscapeState
+    changes::Changes
+    get_escape_cache::T
+end
+
+function getinst(ir::IRCode, idx::Int)
+    nstmts = length(ir.stmts)
+    if idx ≤ nstmts
+        return ir.stmts[idx]
+    else
+        return ir.new_nodes.stmts[idx - nstmts]
+    end
+end
+
+"""
+    analyze_escapes(ir::IRCode, nargs::Int, call_resolved::Bool, get_escape_cache::Callable)
+        -> estate::EscapeState
+
+Analyzes escape information in `ir`:
+- `nargs`: the number of actual arguments of the analyzed call
+- `call_resolved`: whether interprocedural calls are already resolved by `ssa_inlining_pass!`
+- `get_escape_cache(::Union{InferenceResult,MethodInstance}) -> Union{Nothing,ArgEscapeCache}`:
+  retrieves cached argument escape information
+"""
+function analyze_escapes(ir::IRCode, nargs::Int, call_resolved::Bool, get_escape_cache::T) where T<:Callable
+    stmts = ir.stmts
+    nstmts = length(stmts) + length(ir.new_nodes.stmts)
+
+    tryregions, arrayinfo, callinfo = compute_frameinfo(ir, call_resolved)
+    estate = EscapeState(nargs, nstmts, arrayinfo)
+    changes = Changes() # keeps changes that happen at the current statement
+    astate = AnalysisState(ir, estate, changes, get_escape_cache)
+
+    local debug_itr_counter = 0
+    while true
+        local anyupdate = false
+
+        for pc in nstmts:-1:1
+            stmt = getinst(ir, pc)[:inst]
+
+            # collect escape information
+            if isa(stmt, Expr)
+                head = stmt.head
+                if head === :call
+                    if callinfo !== nothing
+                        escape_call!(astate, pc, stmt.args, callinfo)
+                    else
+                        escape_call!(astate, pc, stmt.args)
+                    end
+                elseif head === :invoke
+                    escape_invoke!(astate, pc, stmt.args)
+                elseif head === :new || head === :splatnew
+                    escape_new!(astate, pc, stmt.args)
+                elseif head === :(=)
+                    lhs, rhs = stmt.args
+                    if isa(lhs, GlobalRef) # global store
+                        add_escape_change!(astate, rhs, ⊤)
+                    else
+                        unexpected_assignment!(ir, pc)
+                    end
+                elseif head === :foreigncall
+                    escape_foreigncall!(astate, pc, stmt.args)
+                elseif head === :throw_undef_if_not # XXX when is this expression inserted?
+                    add_escape_change!(astate, stmt.args[1], ThrownEscape(pc))
+                elseif is_meta_expr_head(head)
+                    # meta expressions don't account for any usages
+                    continue
+                elseif head === :enter || head === :leave || head === :the_exception || head === :pop_exception
+                    # ignore these expressions since escapes via exceptions are handled by `escape_exception!`
+                    # `escape_exception!` conservatively propagates `AllEscape` anyway,
+                    # and so escape information imposed on `:the_exception` isn't computed
+                    continue
+                elseif head === :static_parameter ||  # this exists statically, not interested in its escape
+                       head === :copyast ||           # XXX can this account for some escapes?
+                       head === :undefcheck ||        # XXX can this account for some escapes?
+                       head === :isdefined ||         # just returns `Bool`, nothing accounts for any escapes
+                       head === :gc_preserve_begin || # `GC.@preserve` expressions themselves won't be used anywhere
+                       head === :gc_preserve_end      # `GC.@preserve` expressions themselves won't be used anywhere
+                    continue
+                else
+                    add_conservative_changes!(astate, pc, stmt.args)
+                end
+            elseif isa(stmt, ReturnNode)
+                if isdefined(stmt, :val)
+                    add_escape_change!(astate, stmt.val, ReturnEscape(pc))
+                end
+            elseif isa(stmt, PhiNode)
+                escape_edges!(astate, pc, stmt.values)
+            elseif isa(stmt, PiNode)
+                escape_val_ifdefined!(astate, pc, stmt)
+            elseif isa(stmt, PhiCNode)
+                escape_edges!(astate, pc, stmt.values)
+            elseif isa(stmt, UpsilonNode)
+                escape_val_ifdefined!(astate, pc, stmt)
+            elseif isa(stmt, GlobalRef) # global load
+                add_escape_change!(astate, SSAValue(pc), ⊤)
+            elseif isa(stmt, SSAValue)
+                escape_val!(astate, pc, stmt)
+            elseif isa(stmt, Argument)
+                escape_val!(astate, pc, stmt)
+            else # otherwise `stmt` can be GotoNode, GotoIfNot, and inlined values etc.
+                continue
+            end
+
+            isempty(changes) && continue
+
+            anyupdate |= propagate_changes!(estate, changes)
+
+            empty!(changes)
+        end
+
+        tryregions !== nothing && escape_exception!(astate, tryregions)
+
+        debug_itr_counter += 1
+
+        anyupdate || break
+    end
+
+    # if debug_itr_counter > 2
+    #     println("[EA] excessive iteration count found ", debug_itr_counter, " (", singleton_type(ir.argtypes[1]), ")")
+    # end
+
+    return estate
+end
+
+"""
+    compute_frameinfo(ir::IRCode, call_resolved::Bool) -> (tryregions, arrayinfo, callinfo)
+
+A preparatory linear scan before the escape analysis on `ir` to find:
+- `tryregions::Union{Nothing,Vector{UnitRange{Int}}}`: regions in which potential `throw`s can be caught (used by `escape_exception!`)
+- `arrayinfo::Union{Nothing,IdDict{Int,Vector{Int}}}`: array allocations whose dimensions are known precisely (with some very simple local analysis)
+- `callinfo`: when `!call_resolved`, `compute_frameinfo` additionally returns `callinfo::Vector{Union{MethodInstance,InferenceResult}}`,
+  which contains information about statically resolved callsites.
+  The inliner will use essentially equivalent interprocedural information to inline callees as well as resolve static callsites;
+  this additional information won't be required when analyzing post-inlining IR.
+
+!!! note
+    This array dimension analysis to compute `arrayinfo` is very local and doesn't account
+    for flow-sensitivity or complex aliasing.
+    Ideally this dimension analysis should be done as a part of type inference that
+    propagates array dimensions in a flow-sensitive way.
+""" +function compute_frameinfo(ir::IRCode, call_resolved::Bool) + nstmts, nnewnodes = length(ir.stmts), length(ir.new_nodes.stmts) + tryregions, arrayinfo = nothing, nothing + if !call_resolved + callinfo = Vector{Any}(undef, nstmts+nnewnodes) + else + callinfo = nothing + end + for idx in 1:nstmts+nnewnodes + inst = getinst(ir, idx) + stmt = inst[:inst] + if !call_resolved + # TODO don't call `check_effect_free!` in the inlinear + check_effect_free!(ir, idx, stmt, inst[:type]) + end + if callinfo !== nothing && isexpr(stmt, :call) + callinfo[idx] = resolve_call(ir, stmt, inst[:info]) + elseif isexpr(stmt, :enter) + @assert idx ≤ nstmts "try/catch inside new_nodes unsupported" + tryregions === nothing && (tryregions = UnitRange{Int}[]) + leave_block = stmt.args[1]::Int + leave_pc = first(ir.cfg.blocks[leave_block].stmts) + push!(tryregions, idx:leave_pc) + elseif isexpr(stmt, :foreigncall) + args = stmt.args + name = args[1] + nn = normalize(name) + isa(nn, Symbol) || @goto next_stmt + ndims = alloc_array_ndims(nn) + ndims === nothing && @goto next_stmt + if ndims ≠ 0 + length(args) ≥ ndims+6 || @goto next_stmt + dims = Int[] + for i in 1:ndims + dim = argextype(args[i+6], ir) + isa(dim, Const) || @goto next_stmt + dim = dim.val + isa(dim, Int) || @goto next_stmt + push!(dims, dim) + end + else + length(args) ≥ 7 || @goto next_stmt + dims = argextype(args[7], ir) + if isa(dims, Const) + dims = dims.val + isa(dims, Tuple{Vararg{Int}}) || @goto next_stmt + dims = collect(Int, dims) + else + dims === Tuple{} || @goto next_stmt + dims = Int[] + end + end + if arrayinfo === nothing + arrayinfo = ArrayInfo() + end + arrayinfo[idx] = dims + elseif arrayinfo !== nothing + # TODO this super limited alias analysis is able to handle only very simple cases + # this should be replaced with a proper forward dimension analysis + if isa(stmt, PhiNode) + values = stmt.values + local dims = nothing + for i = 1:length(values) + if isassigned(values, i) + val = values[i] + if isa(val, SSAValue) && haskey(arrayinfo, val.id) + if dims === nothing + dims = arrayinfo[val.id] + continue + elseif dims == arrayinfo[val.id] + continue + end + end + end + @goto next_stmt + end + if dims !== nothing + arrayinfo[idx] = dims + end + elseif isa(stmt, PiNode) + if isdefined(stmt, :val) + val = stmt.val + if isa(val, SSAValue) && haskey(arrayinfo, val.id) + arrayinfo[idx] = arrayinfo[val.id] + end + end + end + end + @label next_stmt + end + return tryregions, arrayinfo, callinfo +end + +# define resolve_call +if _TOP_MOD === Core.Compiler + include(@__MODULE__, "compiler/ssair/EscapeAnalysis/interprocedural.jl") +else + include(@__MODULE__, "interprocedural.jl") +end + +# propagate changes, and check convergence +function propagate_changes!(estate::EscapeState, changes::Changes) + local anychanged = false + for change in changes + if isa(change, EscapeChange) + anychanged |= propagate_escape_change!(estate, change) + elseif isa(change, LivenessChange) + anychanged |= propagate_liveness_change!(estate, change) + else + change = change::AliasChange + anychanged |= propagate_alias_change!(estate, change) + end + end + return anychanged +end + +@inline propagate_escape_change!(estate::EscapeState, change::EscapeChange) = + propagate_escape_change!(⊔ₑ, estate, change) + +# allows this to work as lattice join as well as lattice meet +@inline function propagate_escape_change!(@specialize(op), + estate::EscapeState, change::EscapeChange) + (; xidx, xinfo) = change + anychanged = _propagate_escape_change!(op, estate, xidx, xinfo) 
+ # COMBAK is there a more efficient method of escape information equalization on aliasset? + aliases = getaliases(xidx, estate) + if aliases !== nothing + for aidx in aliases + anychanged |= _propagate_escape_change!(op, estate, aidx, xinfo) + end + end + return anychanged +end + +@inline function _propagate_escape_change!(@specialize(op), + estate::EscapeState, xidx::Int, info::EscapeInfo) + old = estate.escapes[xidx] + new = op(old, info) + if old ≠ new + estate.escapes[xidx] = new + return true + end + return false +end + +# propagate Liveness changes separately in order to avoid constructing too many LivenessSet +@inline function propagate_liveness_change!(estate::EscapeState, change::LivenessChange) + (; xidx, livepc) = change + info = estate.escapes[xidx] + Liveness = info.Liveness + Liveness === TOP_LIVENESS && return false + livepc in Liveness && return false + if Liveness === BOT_LIVENESS || Liveness === ARG_LIVENESS + # if this Liveness is a constant, we shouldn't modify it and propagate this change as a new EscapeInfo + Liveness = copy(Liveness) + push!(Liveness, livepc) + estate.escapes[xidx] = EscapeInfo(info; Liveness) + return true + else + # directly modify Liveness property in order to avoid excessive copies + push!(Liveness, livepc) + return true + end +end + +@inline function propagate_alias_change!(estate::EscapeState, change::AliasChange) + anychange = false + (; xidx, yidx) = change + aliasset = estate.aliasset + xroot = find_root!(aliasset, xidx) + yroot = find_root!(aliasset, yidx) + if xroot ≠ yroot + union!(aliasset, xroot, yroot) + return true + end + return false +end + +function add_escape_change!(astate::AnalysisState, @nospecialize(x), xinfo::EscapeInfo) + xinfo === ⊥ && return nothing # performance optimization + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, EscapeChange(xidx, xinfo)) + end + end + return nothing +end + +function add_liveness_change!(astate::AnalysisState, @nospecialize(x), livepc::Int) + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, LivenessChange(xidx, livepc)) + end + end + return nothing +end + +function add_alias_change!(astate::AnalysisState, @nospecialize(x), @nospecialize(y)) + if isa(x, GlobalRef) + return add_escape_change!(astate, y, ⊤) + elseif isa(y, GlobalRef) + return add_escape_change!(astate, x, ⊤) + end + estate = astate.estate + xidx = iridx(x, estate) + yidx = iridx(y, estate) + if xidx !== nothing && yidx !== nothing && !isaliased(xidx, yidx, astate.estate) + pushfirst!(astate.changes, AliasChange(xidx, yidx)) + # add new escape change here so that it's shared among the expanded `aliasset` in `propagate_escape_change!` + xinfo = estate.escapes[xidx] + yinfo = estate.escapes[yidx] + add_escape_change!(astate, x, xinfo ⊔ₑ yinfo) + end + return nothing +end + +struct LocalDef + idx::Int +end +struct LocalUse + idx::Int +end + +function add_alias_escapes!(astate::AnalysisState, @nospecialize(v), ainfo::AInfo) + estate = astate.estate + for x in ainfo + isa(x, LocalUse) || continue # ignore def + x = SSAValue(x.idx) # obviously this won't be true once we implement interprocedural AliasInfo + add_alias_change!(astate, v, x) + end +end + +function add_thrown_escapes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + info = ThrownEscape(pc) + for i in first_idx:last_idx + add_escape_change!(astate, args[i], 
info)
+    end
+end
+
+function add_liveness_changes!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    first_idx::Int = 1, last_idx::Int = length(args))
+    for i in first_idx:last_idx
+        arg = args[i]
+        add_liveness_change!(astate, arg, pc)
+    end
+end
+
+function add_fallback_changes!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    first_idx::Int = 1, last_idx::Int = length(args))
+    info = ThrownEscape(pc)
+    for i in first_idx:last_idx
+        arg = args[i]
+        add_escape_change!(astate, arg, info)
+        add_liveness_change!(astate, arg, pc)
+    end
+end
+
+function add_conservative_changes!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    first_idx::Int = 1, last_idx::Int = length(args))
+    for i in first_idx:last_idx
+        add_escape_change!(astate, args[i], ⊤)
+    end
+    add_escape_change!(astate, SSAValue(pc), ⊤) # it may return GlobalRef etc.
+    return nothing
+end
+
+function escape_edges!(astate::AnalysisState, pc::Int, edges::Vector{Any})
+    ret = SSAValue(pc)
+    for i in 1:length(edges)
+        if isassigned(edges, i)
+            v = edges[i]
+            add_alias_change!(astate, ret, v)
+        end
+    end
+end
+
+function escape_val_ifdefined!(astate::AnalysisState, pc::Int, x)
+    if isdefined(x, :val)
+        escape_val!(astate, pc, x.val)
+    end
+end
+
+function escape_val!(astate::AnalysisState, pc::Int, @nospecialize(val))
+    ret = SSAValue(pc)
+    add_alias_change!(astate, ret, val)
+end
+
+function escape_unanalyzable_obj!(astate::AnalysisState, @nospecialize(obj), objinfo::EscapeInfo)
+    objinfo = EscapeInfo(objinfo, true)
+    add_escape_change!(astate, obj, objinfo)
+    return objinfo
+end
+
+@noinline function unexpected_assignment!(ir::IRCode, pc::Int)
+    @eval Main (ir = $ir; pc = $pc)
+    error("unexpected assignment found: inspect `Main.ir` and `Main.pc`")
+end
+
+is_effect_free(ir::IRCode, pc::Int) = getinst(ir, pc)[:flag] & IR_FLAG_EFFECT_FREE ≠ 0
+
+# NOTE if we don't maintain the alias set that is separated from the lattice state, we can do
+# something like below: it essentially incorporates forward escape propagation in our default
+# backward propagation, and leads to inefficient convergence that requires more iterations
+# # lhs = rhs: propagate escape information of `rhs` to `lhs`
+# function escape_alias!(astate::AnalysisState, @nospecialize(lhs), @nospecialize(rhs))
+#     if isa(rhs, SSAValue) || isa(rhs, Argument)
+#         vinfo = astate.estate[rhs]
+#     else
+#         return
+#     end
+#     add_escape_change!(astate, lhs, vinfo)
+# end
+
+"""
+    escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}})
+
+Propagates escapes via exceptions that can happen in `tryregions`.
+
+Naively it seems enough to propagate the escape information imposed on the `:the_exception` object,
+but actually there are several other ways to access the exception object, such as
+`Base.current_exceptions` and a manual catch of a `rethrow`n object.
+For example, escape analysis needs to account for a potential escape of the allocated object
+via the `rethrow_escape!()` call in the example below:
+```julia
+const Gx = Ref{Any}()
+@noinline function rethrow_escape!()
+    try
+        rethrow()
+    catch err
+        Gx[] = err
+    end
+end
+unsafeget(x) = isassigned(x) ? x[] : throw(x)
+
+code_escapes() do
+    r = Ref{String}()
+    try
+        t = unsafeget(r)
+    catch err
+        t = typeof(err)   # `err` (which `r` may alias to) doesn't escape here
+        rethrow_escape!() # `r` can escape here
+    end
+    return t
+end
+```
+
+As indicated by the above example, it requires a global analysis in addition to a base escape
+analysis to reason about all possible escapes via existing exception interfaces correctly.
+For now we conservatively always propagate `AllEscape` to all potentially thrown objects,
+since such an additional analysis might not be worthwhile given that exception handling
+and error paths usually don't need to be very performance sensitive, and optimizations of
+error paths might be very ineffective anyway since they are sometimes "unoptimized"
+intentionally for latency reasons.
+"""
+function escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}})
+    estate = astate.estate
+    # NOTE if `:the_exception` is the only way to access the exception, we can do:
+    # exc = SSAValue(pc)
+    # excinfo = estate[exc]
+    excinfo = ⊤
+    escapes = estate.escapes
+    for i in 1:length(escapes)
+        x = escapes[i]
+        xt = x.ThrownEscape
+        xt === TOP_THROWN_ESCAPE && @goto propagate_exception_escape # fast pass
+        for pc in xt
+            for region in tryregions
+                pc in region && @goto propagate_exception_escape # early break because of AllEscape
+            end
+        end
+        continue
+        @label propagate_exception_escape
+        xval = irval(i, estate)
+        add_escape_change!(astate, xval, excinfo)
+    end
+end
+
+# escape statically-resolved call, i.e. `Expr(:invoke, ::MethodInstance, ...)`
+escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any}) =
+    escape_invoke!(astate, pc, args, first(args)::MethodInstance, 2)
+
+function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any},
+    linfo::Linfo, first_idx::Int, last_idx::Int = length(args))
+    if isa(linfo, InferenceResult)
+        cache = astate.get_escape_cache(linfo)
+        linfo = linfo.linfo
+    else
+        cache = astate.get_escape_cache(linfo)
+    end
+    if cache === nothing
+        return add_conservative_changes!(astate, pc, args, 2)
+    else
+        cache = cache::ArgEscapeCache
+    end
+    ret = SSAValue(pc)
+    retinfo = astate.estate[ret] # escape information imposed on the call statement
+    method = linfo.def::Method
+    nargs = Int(method.nargs)
+    for (i, argidx) in enumerate(first_idx:last_idx)
+        arg = args[argidx]
+        if i > nargs
+            # handle isva signature
+            # COMBAK will this be invalid once we take alias information into account?
+            i = nargs
+        end
+        arginfo = cache.argescapes[i]
+        info = from_interprocedural(arginfo, pc)
+        if has_return_escape(arginfo)
+            # if this argument can be "returned", in addition to propagating
+            # the escape information imposed on this call argument within the callee,
+            # we should also account for possible aliasing of this argument and the returned value
+            add_escape_change!(astate, arg, info)
+            add_alias_change!(astate, ret, arg)
+        else
+            # if this is simply passed as the call argument, we can just propagate
+            # the escape information imposed on this call argument within the callee
+            add_escape_change!(astate, arg, info)
+        end
+    end
+    for (; aidx, bidx) in cache.argaliases
+        add_alias_change!(astate, args[aidx-(first_idx-1)], args[bidx-(first_idx-1)])
+    end
+    # we should disable the alias analysis on this newly introduced object
+    add_escape_change!(astate, ret, EscapeInfo(retinfo, true))
+end
+
+"""
+    from_interprocedural(arginfo::ArgEscapeInfo, pc::Int) -> x::EscapeInfo
+
+Reinterprets the escape information imposed on the call argument which is cached as `arginfo`
+in the context of the caller frame, where `pc` is the SSA statement number of the return value.
+"""
+function from_interprocedural(arginfo::ArgEscapeInfo, pc::Int)
+    has_all_escape(arginfo) && return ⊤
+
+    ThrownEscape = has_thrown_escape(arginfo) ?
LivenessSet(pc) : BOT_THROWN_ESCAPE + + return EscapeInfo( + #=Analyzed=#true, #=ReturnEscape=#false, ThrownEscape, + # FIXME implement interprocedural memory effect-analysis + # currently, this essentially disables the entire field analysis + # it might be okay from the SROA point of view, since we can't remove the allocation + # as far as it's passed to a callee anyway, but still we may want some field analysis + # for e.g. stack allocation or some other IPO optimizations + #=AliasInfo=#true, #=Liveness=#LivenessSet(pc)) +end + +# escape every argument `(args[6:length(args[3])])` and the name `args[1]` +# TODO: we can apply a similar strategy like builtin calls to specialize some foreigncalls +function escape_foreigncall!(astate::AnalysisState, pc::Int, args::Vector{Any}) + nargs = length(args) + if nargs < 6 + # invalid foreigncall, just escape everything + add_conservative_changes!(astate, pc, args) + return + end + argtypes = args[3]::SimpleVector + nargs = length(argtypes) + name = args[1] + nn = normalize(name) + if isa(nn, Symbol) + boundserror_ninds = array_resize_info(nn) + if boundserror_ninds !== nothing + boundserror, ninds = boundserror_ninds + escape_array_resize!(boundserror, ninds, astate, pc, args) + return + end + if is_array_copy(nn) + escape_array_copy!(astate, pc, args) + return + elseif is_array_isassigned(nn) + escape_array_isassigned!(astate, pc, args) + return + end + # if nn === :jl_gc_add_finalizer_th + # # TODO add `FinalizerEscape` ? + # end + end + # NOTE array allocations might have been proven as nothrow (https://github.com/JuliaLang/julia/pull/43565) + nothrow = is_effect_free(astate.ir, pc) + name_info = nothrow ? ⊥ : ThrownEscape(pc) + add_escape_change!(astate, name, name_info) + add_liveness_change!(astate, name, pc) + for i = 1:nargs + # we should escape this argument if it is directly called, + # otherwise just impose ThrownEscape if not nothrow + if argtypes[i] === Any + arg_info = ⊤ + else + arg_info = nothrow ? ⊥ : ThrownEscape(pc) + end + add_escape_change!(astate, args[5+i], arg_info) + add_liveness_change!(astate, args[5+i], pc) + end + for i = (5+nargs):length(args) + arg = args[i] + add_escape_change!(astate, arg, ⊥) + add_liveness_change!(astate, arg, pc) + end +end + +normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x + +function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any}, callinfo::Vector{Any}) + info = callinfo[pc] + if isa(info, Bool) + info && return # known to be no escape + # now cascade to the builtin handling + escape_call!(astate, pc, args) + return + elseif isa(info, CallInfo) + for linfo in info.linfos + escape_invoke!(astate, pc, args, linfo, 1) + end + # accounts for a potential escape via MethodError + info.nothrow || add_thrown_escapes!(astate, pc, args) + return + else + @assert info === missing + # if this call couldn't be analyzed, escape it conservatively + add_conservative_changes!(astate, pc, args) + end +end + +function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any}) + ir = astate.ir + ft = argextype(first(args), ir, ir.sptypes, ir.argtypes) + f = singleton_type(ft) + if isa(f, Core.IntrinsicFunction) + # XXX somehow `:call` expression can creep in here, ideally we should be able to do: + # argtypes = Any[argextype(args[i], astate.ir) for i = 2:length(args)] + argtypes = Any[] + for i = 2:length(args) + arg = args[i] + push!(argtypes, isexpr(arg, :call) ? 
Any : argextype(arg, ir)) + end + if intrinsic_nothrow(f, argtypes) + add_liveness_changes!(astate, pc, args, 2) + else + add_fallback_changes!(astate, pc, args, 2) + end + return # TODO accounts for pointer operations? + end + result = escape_builtin!(f, astate, pc, args) + if result === missing + # if this call hasn't been handled by any of pre-defined handlers, escape it conservatively + add_conservative_changes!(astate, pc, args) + return + elseif result === true + add_liveness_changes!(astate, pc, args, 2) + return # ThrownEscape is already checked + else + # we escape statements with the `ThrownEscape` property using the effect-freeness + # computed by `stmt_effect_free` invoked within inlining + # TODO throwness ≠ "effect-free-ness" + if is_effect_free(astate.ir, pc) + add_liveness_changes!(astate, pc, args, 2) + else + add_fallback_changes!(astate, pc, args, 2) + end + return + end +end + +escape_builtin!(@nospecialize(f), _...) = return missing + +# safe builtins +escape_builtin!(::typeof(isa), _...) = return false +escape_builtin!(::typeof(typeof), _...) = return false +escape_builtin!(::typeof(sizeof), _...) = return false +escape_builtin!(::typeof(===), _...) = return false +# not really safe, but `ThrownEscape` will be imposed later +escape_builtin!(::typeof(isdefined), _...) = return false +escape_builtin!(::typeof(throw), _...) = return false + +function escape_builtin!(::typeof(ifelse), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 4 || return false + f, cond, th, el = args + ret = SSAValue(pc) + condt = argextype(cond, astate.ir) + if isa(condt, Const) && (cond = condt.val; isa(cond, Bool)) + if cond + add_alias_change!(astate, th, ret) + else + add_alias_change!(astate, el, ret) + end + else + add_alias_change!(astate, th, ret) + add_alias_change!(astate, el, ret) + end + return false +end + +function escape_builtin!(::typeof(typeassert), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + f, obj, typ = args + ret = SSAValue(pc) + add_alias_change!(astate, ret, obj) + return false +end + +function escape_new!(astate::AnalysisState, pc::Int, args::Vector{Any}) + obj = SSAValue(pc) + objinfo = astate.estate[obj] + AliasInfo = objinfo.AliasInfo + nargs = length(args) + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + typ = widenconst(argextype(obj, astate.ir)) + nfields = fieldcount_noerror(typ) + if nfields === nothing + AliasInfo = Unindexable() + @goto escape_unindexable_def + else + AliasInfo = IndexableFields(nfields) + @goto escape_indexable_def + end + elseif isa(AliasInfo, IndexableFields) + @label escape_indexable_def + # fields are known precisely: propagate escape information imposed on recorded possibilities to the exact field values + infos = AliasInfo.infos + nf = length(infos) + objinfo′ = ignore_aliasinfo(objinfo) + for i in 2:nargs + i-1 > nf && break # may happen when e.g. 
ϕ-node merges values with different types + arg = args[i] + add_alias_escapes!(astate, arg, infos[i-1]) + push!(infos[i-1], LocalDef(pc)) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo′) + add_liveness_change!(astate, arg, pc) + end + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + elseif isa(AliasInfo, Unindexable) + @label escape_unindexable_def + # fields are known partially: propagate escape information imposed on recorded possibilities to all fields values + info = AliasInfo.info + objinfo′ = ignore_aliasinfo(objinfo) + for i in 2:nargs + arg = args[i] + add_alias_escapes!(astate, arg, info) + push!(info, LocalDef(pc)) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo′) + add_liveness_change!(astate, arg, pc) + end + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + else + # this object has been used as array, but it is allocated as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the fields couldn't be analyzed precisely: propagate the entire escape information + # of this object to all its fields as the most conservative propagation + for i in 2:nargs + arg = args[i] + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + end + if !is_effect_free(astate.ir, pc) + add_thrown_escapes!(astate, pc, args) + end +end + +function escape_builtin!(::typeof(tuple), astate::AnalysisState, pc::Int, args::Vector{Any}) + escape_new!(astate, pc, args) + return false +end + +function analyze_fields(ir::IRCode, @nospecialize(typ), @nospecialize(fld)) + nfields = fieldcount_noerror(typ) + if nfields === nothing + return Unindexable(), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return Unindexable(), 0 + end + return IndexableFields(nfields), fidx +end + +function reanalyze_fields(ir::IRCode, AliasInfo::IndexableFields, @nospecialize(typ), @nospecialize(fld)) + nfields = fieldcount_noerror(typ) + if nfields === nothing + return merge_to_unindexable(AliasInfo), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return merge_to_unindexable(AliasInfo), 0 + end + infos = AliasInfo.infos + ninfos = length(infos) + if nfields > ninfos + for _ in 1:(nfields-ninfos) + push!(infos, AInfo()) + end + end + return AliasInfo, fidx +end + +function escape_builtin!(::typeof(getfield), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 3 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + typ = widenconst(argextype(obj, ir)) + if hasintersect(typ, Module) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + end + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + return false + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, IndexableFields) + @goto record_indexable_use + else + @goto 
record_unindexable_use + end + elseif isa(AliasInfo, IndexableFields) + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto record_unindexable_use + @label record_indexable_use + push!(AliasInfo.infos[fidx], LocalUse(pc)) + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + elseif isa(AliasInfo, Unindexable) + @label record_unindexable_use + push!(AliasInfo.info, LocalUse(pc)) + add_escape_change!(astate, obj, EscapeInfo(objinfo, AliasInfo)) # update with new AliasInfo + else + # this object has been used as array, but it is used as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # at the extreme case, a field of `obj` may point to `obj` itself + # so add the alias change here as the most conservative propagation + add_alias_change!(astate, obj, SSAValue(pc)) + end + return false +end + +function escape_builtin!(::typeof(setfield!), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + val = args[4] + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + @goto add_thrown_escapes + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, IndexableFields) + @goto escape_indexable_def + else + @goto escape_unindexable_def + end + elseif isa(AliasInfo, IndexableFields) + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto escape_unindexable_def + @label escape_indexable_def + add_alias_escapes!(astate, val, AliasInfo.infos[fidx]) + push!(AliasInfo.infos[fidx], LocalDef(pc)) + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) # update with new AliasInfo + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + elseif isa(AliasInfo, Unindexable) + info = AliasInfo.info + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, LocalDef(pc)) + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) # update with new AliasInfo + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + else + # this object has been used as array, but it is used as struct here (i.e. 
should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the field couldn't be analyzed: alias this object to the value being assigned + # as the most conservative propagation (as required for ArgAliasing) + add_alias_change!(astate, val, obj) + end + # also propagate escape information imposed on the return value of this `setfield!` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, val, ssainfo) + # compute the throwness of this setfield! call here since builtin_nothrow doesn't account for that + @label add_thrown_escapes + argtypes = Any[] + for i = 2:length(args) + push!(argtypes, argextype(args[i], ir)) + end + setfield!_nothrow(argtypes) || add_thrown_escapes!(astate, pc, args, 2) + return true +end + +function escape_builtin!(::typeof(arrayref), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + # check potential thrown escapes from this arrayref call + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + if !array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 3) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and thus don't know what values + # can be referenced here: directly propagate the escape information imposed on the return + # value of this `arrayref` call to the array itself as the most conservative propagation + # but also with updated index information + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this array hasn't been analyzed yet: set AliasInfo now + idx = array_nd_index(astate, ary, args[4:end]) + if isa(idx, Int) + AliasInfo = IndexableElements(IdDict{Int,AInfo}()) + @goto record_indexable_use + end + AliasInfo = Unindexable() + @goto record_unindexable_use + elseif isa(AliasInfo, IndexableElements) + idx = array_nd_index(astate, ary, args[4:end]) + if !isa(idx, Int) + AliasInfo = merge_to_unindexable(AliasInfo) + @goto record_unindexable_use + end + @label record_indexable_use + info = get!(()->AInfo(), AliasInfo.infos, idx) + push!(info, LocalUse(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + elseif isa(AliasInfo, Unindexable) + @label record_unindexable_use + push!(AliasInfo.info, LocalUse(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + # at the extreme case, an element of `ary` may point to `ary` itself + # so add the alias change here as the most conservative propagation + add_alias_change!(astate, ary, SSAValue(pc)) + end + return true +end + +function escape_builtin!(::typeof(arrayset), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 5 || return false + # check potential escapes from this arrayset call + # NOTE here we 
essentially only need to account for TypeError, assuming that + # UndefRefError or BoundsError don't capture any of the arguments here + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + valt = argtypes[3] + if !(array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 4) && + arrayset_typecheck(aryt, valt)) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + val = args[4] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and won't record what value + # is being assigned here: directly propagate the escape information of this array to + # the value being assigned as the most conservative propagation + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # AliasInfo of this array hasn't been analyzed yet: set AliasInfo now + idx = array_nd_index(astate, ary, args[5:end]) + if isa(idx, Int) + AliasInfo = IndexableElements(IdDict{Int,AInfo}()) + @goto escape_indexable_def + end + AliasInfo = Unindexable() + @goto escape_unindexable_def + elseif isa(AliasInfo, IndexableElements) + idx = array_nd_index(astate, ary, args[5:end]) + if !isa(idx, Int) + AliasInfo = merge_to_unindexable(AliasInfo) + @goto escape_unindexable_def + end + @label escape_indexable_def + info = get!(()->AInfo(), AliasInfo.infos, idx) + add_alias_escapes!(astate, val, info) + push!(info, LocalDef(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + # propagate the escape information of this array ignoring elements information + add_escape_change!(astate, val, ignore_aliasinfo(aryinfo)) + elseif isa(AliasInfo, Unindexable) + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, LocalDef(pc)) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) # update with new AliasInfo + # propagate the escape information of this array ignoring elements information + add_escape_change!(astate, val, ignore_aliasinfo(aryinfo)) + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + add_alias_change!(astate, val, ary) + end + # also propagate escape information imposed on the return value of this `arrayset` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, ary, ssainfo) + return true +end + +# NOTE this function models and thus should be synced with the implementation of: +# size_t array_nd_index(jl_array_t *a, jl_value_t **args, size_t nidxs, ...) 
+function array_nd_index(astate::AnalysisState, @nospecialize(ary), args::Vector{Any}, nidxs::Int = length(args)) + isa(ary, SSAValue) || return nothing + aryid = ary.id + arrayinfo = astate.estate.arrayinfo + isa(arrayinfo, ArrayInfo) || return nothing + haskey(arrayinfo, aryid) || return nothing + dims = arrayinfo[aryid] + local i = 0 + local k, stride = 0, 1 + local nd = length(dims) + while k < nidxs + arg = args[k+1] + argval = argextype(arg, astate.ir) + isa(argval, Const) || return nothing + argval = argval.val + isa(argval, Int) || return nothing + ii = argval - 1 + i += ii * stride + d = k ≥ nd ? 1 : dims[k+1] + k < nidxs - 1 && ii ≥ d && return nothing # BoundsError + stride *= d + k += 1 + end + while k < nd + stride *= dims[k+1] + k += 1 + end + i ≥ stride && return nothing # BoundsError + return i +end + +function escape_builtin!(::typeof(arraysize), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + ary = args[2] + dim = args[3] + if !arraysize_typecheck(ary, dim, astate.ir) + add_escape_change!(astate, ary, ThrownEscape(pc)) + add_escape_change!(astate, dim, ThrownEscape(pc)) + end + # NOTE we may still see "arraysize: dimension out of range", but it doesn't capture anything + return true +end + +function arraysize_typecheck(@nospecialize(ary), @nospecialize(dim), ir::IRCode) + aryt = argextype(ary, ir) + aryt ⊑ Array || return false + dimt = argextype(dim, ir) + dimt ⊑ Int || return false + return true +end + +# returns nothing if this isn't array resizing operation, +# otherwise returns true if it can throw BoundsError and false if not +function array_resize_info(name::Symbol) + if name === :jl_array_grow_beg || name === :jl_array_grow_end + return false, 1 + elseif name === :jl_array_del_beg || name === :jl_array_del_end + return true, 1 + elseif name === :jl_array_grow_at || name === :jl_array_del_at + return true, 2 + else + return nothing + end +end + +# NOTE may potentially throw "cannot resize array with shared data" error, +# but just ignore it since it doesn't capture anything +function escape_array_resize!(boundserror::Bool, ninds::Int, + astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 6+ninds || return add_fallback_changes!(astate, pc, args) + ary = args[6] + aryt = argextype(ary, astate.ir) + aryt ⊑ Array || return add_fallback_changes!(astate, pc, args) + for i in 1:ninds + ind = args[i+6] + indt = argextype(ind, astate.ir) + indt ⊑ Integer || return add_fallback_changes!(astate, pc, args) + end + if boundserror + # this array resizing can potentially throw `BoundsError`, impose it now + add_escape_change!(astate, ary, ThrownEscape(pc)) + end + # give up indexing analysis whenever we see array resizing + # (since we track array dimensions only globally) + mark_unindexable!(astate, ary) + add_liveness_changes!(astate, pc, args, 6) +end + +function mark_unindexable!(astate::AnalysisState, @nospecialize(ary)) + isa(ary, SSAValue) || return + aryinfo = astate.estate[ary] + AliasInfo = aryinfo.AliasInfo + isa(AliasInfo, IndexableElements) || return + AliasInfo = merge_to_unindexable(AliasInfo) + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) +end + +is_array_copy(name::Symbol) = name === :jl_array_copy + +# FIXME this implementation is very conservative, improve the accuracy and solve broken test cases +function escape_array_copy!(astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 6 || return add_fallback_changes!(astate, pc, args) + ary = args[6] + aryt = argextype(ary, 
astate.ir) + aryt ⊑ Array || return add_fallback_changes!(astate, pc, args) + if isa(ary, SSAValue) || isa(ary, Argument) + newary = SSAValue(pc) + aryinfo = astate.estate[ary] + newaryinfo = astate.estate[newary] + add_escape_change!(astate, newary, aryinfo) + add_escape_change!(astate, ary, newaryinfo) + end + add_liveness_changes!(astate, pc, args, 6) +end + +is_array_isassigned(name::Symbol) = name === :jl_array_isassigned + +function escape_array_isassigned!(astate::AnalysisState, pc::Int, args::Vector{Any}) + if !array_isassigned_nothrow(args, astate.ir) + add_thrown_escapes!(astate, pc, args) + end + add_liveness_changes!(astate, pc, args, 6) +end + +function array_isassigned_nothrow(args::Vector{Any}, src::IRCode) + # if !validate_foreigncall_args(args, + # :jl_array_isassigned, Cint, svec(Any,Csize_t), 0, :ccall) + # return false + # end + length(args) ≥ 7 || return false + arytype = argextype(args[6], src) + arytype ⊑ Array || return false + idxtype = argextype(args[7], src) + idxtype ⊑ Csize_t || return false + return true +end + +# # COMBAK do we want to enable this (and also backport this to Base for array allocations?) +# import Core.Compiler: Cint, svec +# function validate_foreigncall_args(args::Vector{Any}, +# name::Symbol, @nospecialize(rt), argtypes::SimpleVector, nreq::Int, convension::Symbol) +# length(args) ≥ 5 || return false +# normalize(args[1]) === name || return false +# args[2] === rt || return false +# args[3] === argtypes || return false +# args[4] === vararg || return false +# normalize(args[5]) === convension || return false +# return true +# end + +if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +escape_builtin!(::typeof(arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(Array, astate, args) +escape_builtin!(::typeof(mutating_arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(Array, astate, args) +escape_builtin!(::typeof(arraythaw), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(ImmutableArray, astate, args) +function is_safe_immutable_array_op(@nospecialize(arytype), astate::AnalysisState, args::Vector{Any}) + length(args) == 2 || return false + argextype(args[2], astate.ir) ⊑ arytype || return false + return true +end + +end # if isdefined(Core, :ImmutableArray) + +if _TOP_MOD !== Core.Compiler + # NOTE define fancy package utilities when developing EA as an external package + include(@__MODULE__, "EAUtils.jl") + using .EAUtils + export code_escapes, @code_escapes, __clear_cache! +end + +end # baremodule EscapeAnalysis diff --git a/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl new file mode 100644 index 0000000000000..915bc214d7c3c --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl @@ -0,0 +1,143 @@ +# A disjoint set implementation adapted from +# https://github.com/JuliaCollections/DataStructures.jl/blob/f57330a3b46f779b261e6c07f199c88936f28839/src/disjoint_set.jl +# under the MIT license: https://github.com/JuliaCollections/DataStructures.jl/blob/master/License.md + +# imports +import ._TOP_MOD: + length, + eltype, + union!, + push! 
+# usings
+import ._TOP_MOD:
+    OneTo, collect, zero, zeros, one, typemax
+
+# Disjoint-Set
+
+############################################################
+#
+#   A forest of disjoint sets of integers
+#
+#   Since each element is an integer, we can use arrays
+#   instead of a dictionary (for efficiency)
+#
+#   Disjoint sets over other key types can be implemented
+#   based on an IntDisjointSet through a map from the key
+#   to an integer index
+#
+############################################################
+
+_intdisjointset_bounds_err_msg(T) = "the maximum number of elements in IntDisjointSet{$T} is $(typemax(T))"
+
+"""
+    IntDisjointSet{T<:Integer}(n::Integer)
+
+A forest of disjoint sets of integers, which is a data structure
+(also called a union–find data structure or merge–find set)
+that tracks a set of elements partitioned
+into a number of disjoint (non-overlapping) subsets.
+"""
+mutable struct IntDisjointSet{T<:Integer}
+    parents::Vector{T}
+    ranks::Vector{T}
+    ngroups::T
+end
+
+IntDisjointSet(n::T) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(n)), zeros(T, n), n)
+IntDisjointSet{T}(n::Integer) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(T(n))), zeros(T, T(n)), T(n))
+length(s::IntDisjointSet) = length(s.parents)
+
+"""
+    num_groups(s::IntDisjointSet)
+
+Get the number of groups.
+"""
+num_groups(s::IntDisjointSet) = s.ngroups
+eltype(::Type{IntDisjointSet{T}}) where {T<:Integer} = T
+
+# find the root element of the subset that contains x
+# path compression is implemented here
+function find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer}
+    p = parents[x]
+    @inbounds if parents[p] != p
+        parents[x] = p = _find_root_impl!(parents, p)
+    end
+    return p
+end
+
+# unsafe version of the above
+function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer}
+    @inbounds p = parents[x]
+    @inbounds if parents[p] != p
+        parents[x] = p = _find_root_impl!(parents, p)
+    end
+    return p
+end
+
+"""
+    find_root!(s::IntDisjointSet{T}, x::T)
+
+Find the root element of the subset that contains a member `x`.
+Path compression happens here.
+"""
+find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x)
+
+"""
+    in_same_set(s::IntDisjointSet{T}, x::T, y::T)
+
+Returns `true` if `x` and `y` belong to the same subset in `s`, and `false` otherwise.
+"""
+in_same_set(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} = find_root!(s, x) == find_root!(s, y)
+
+"""
+    union!(s::IntDisjointSet{T}, x::T, y::T)
+
+Merge the subset containing `x` and that containing `y` into one
+and return the root of the new set.
+"""
+function union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer}
+    parents = s.parents
+    xroot = find_root_impl!(parents, x)
+    yroot = find_root_impl!(parents, y)
+    return xroot != yroot ? root_union!(s, xroot, yroot) : xroot
+end
+
+"""
+    root_union!(s::IntDisjointSet{T}, x::T, y::T)
+
+Form a new set that is the union of the two sets whose root elements are
+`x` and `y` and return the root of the new set.
+Assume `x ≠ y` (unsafe).
+"""
+function root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer}
+    parents = s.parents
+    rks = s.ranks
+    @inbounds xrank = rks[x]
+    @inbounds yrank = rks[y]
+
+    if xrank < yrank
+        x, y = y, x
+    elseif xrank == yrank
+        rks[x] += one(T)
+    end
+    @inbounds parents[y] = x
+    s.ngroups -= one(T)
+    return x
+end
+
+"""
+    push!(s::IntDisjointSet{T})
+
+Make a new subset with an automatically chosen new element `x`.
+Returns the new element. 
Throw an `ArgumentError` if the +capacity of the set would be exceeded. +""" +function push!(s::IntDisjointSet{T}) where {T<:Integer} + l = length(s) + l < typemax(T) || throw(ArgumentError(_intdisjointset_bounds_err_msg(T))) + x = l + one(T) + push!(s.parents, x) + push!(s.ranks, zero(T)) + s.ngroups += one(T) + return x +end diff --git a/base/compiler/ssair/EscapeAnalysis/interprocedural.jl b/base/compiler/ssair/EscapeAnalysis/interprocedural.jl new file mode 100644 index 0000000000000..9880c13db4ad1 --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/interprocedural.jl @@ -0,0 +1,151 @@ +# TODO this file contains many duplications with the inlining analysis code, factor them out + +import Core.Compiler: + MethodInstance, InferenceResult, Signature, ConstResult, + MethodResultPure, MethodMatchInfo, UnionSplitInfo, ConstCallInfo, InvokeCallInfo, + call_sig, argtypes_to_type, is_builtin, is_return_type, istopfunction, validate_sparams, + specialize_method, invoke_rewrite + +const Linfo = Union{MethodInstance,InferenceResult} +struct CallInfo + linfos::Vector{Linfo} + nothrow::Bool +end + +function resolve_call(ir::IRCode, stmt::Expr, @nospecialize(info)) + sig = call_sig(ir, stmt) + if sig === nothing + return missing + end + # TODO handle _apply_iterate + if is_builtin(sig) && sig.f !== invoke + return false + end + # handling corresponding to late_inline_special_case! + (; f, argtypes) = sig + if length(argtypes) == 3 && istopfunction(f, :!==) + return true + elseif length(argtypes) == 3 && istopfunction(f, :(>:)) + return true + elseif f === TypeVar && 2 ≤ length(argtypes) ≤ 4 && (argtypes[2] ⊑ Symbol) + return true + elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar) + return true + elseif is_return_type(f) + return true + end + if info isa MethodResultPure + return true + elseif info === false + return missing + end + # TODO handle OpaqueClosureCallInfo + if sig.f === invoke + isa(info, InvokeCallInfo) || return missing + return analyze_invoke_call(sig, info) + elseif isa(info, ConstCallInfo) + return analyze_const_call(sig, info) + elseif isa(info, MethodMatchInfo) + infos = MethodMatchInfo[info] + elseif isa(info, UnionSplitInfo) + infos = info.matches + else # isa(info, ReturnTypeCallInfo), etc. + return missing + end + return analyze_call(sig, infos) +end + +function analyze_invoke_call(sig::Signature, info::InvokeCallInfo) + match = info.match + if !match.fully_covers + # TODO: We could union split out the signature check and continue on + return missing + end + result = info.result + if isa(result, InferenceResult) + return CallInfo(Linfo[result], true) + else + argtypes = invoke_rewrite(sig.argtypes) + mi = analyze_match(match, length(argtypes)) + mi === nothing && return missing + return CallInfo(Linfo[mi], true) + end +end + +function analyze_const_call(sig::Signature, cinfo::ConstCallInfo) + linfos = Linfo[] + (; call, results) = cinfo + infos = isa(call, MethodMatchInfo) ? 
MethodMatchInfo[call] : call.matches
+    local nothrow = true # required to account for potential escape via MethodError
+    local j = 0
+    for i in 1:length(infos)
+        meth = infos[i].results
+        nothrow &= !meth.ambig
+        nmatch = Core.Compiler.length(meth)
+        if nmatch == 0 # No applicable methods
+            # mark that this call may potentially throw, and try the next union split
+            nothrow = false
+            continue
+        end
+        for i = 1:nmatch
+            j += 1
+            result = results[j]
+            match = Core.Compiler.getindex(meth, i)
+            if result === nothing
+                mi = analyze_match(match, length(sig.argtypes))
+                mi === nothing && return missing
+                push!(linfos, mi)
+            elseif isa(result, ConstResult)
+                # TODO we may want to feed back the information that this call always throws if !isdefined(result, :result)
+                push!(linfos, result.mi)
+            else
+                push!(linfos, result)
+            end
+            nothrow &= match.fully_covers
+        end
+    end
+    return CallInfo(linfos, nothrow)
+end
+
+function analyze_call(sig::Signature, infos::Vector{MethodMatchInfo})
+    linfos = Linfo[]
+    local nothrow = true # required to account for potential escape via MethodError
+    for i in 1:length(infos)
+        meth = infos[i].results
+        nothrow &= !meth.ambig
+        nmatch = Core.Compiler.length(meth)
+        if nmatch == 0 # No applicable methods
+            # mark that this call may potentially throw, and try the next union split
+            nothrow = false
+            continue
+        end
+        for i = 1:nmatch
+            match = Core.Compiler.getindex(meth, i)
+            mi = analyze_match(match, length(sig.argtypes))
+            mi === nothing && return missing
+            push!(linfos, mi)
+            nothrow &= match.fully_covers
+        end
+    end
+    return CallInfo(linfos, nothrow)
+end
+
+function analyze_match(match::MethodMatch, npassedargs::Int)
+    method = match.method
+    na = Int(method.nargs)
+    if na != npassedargs && !(na > 0 && method.isva)
+        # we have a method match only because an earlier
+        # inference step shortened our call args list, even
+        # though we have too many arguments to actually
+        # call this function
+        return nothing
+    end
+
+    # Bail out if any static parameters are left as TypeVar
+    # COMBAK is this needed for escape analysis?
+    validate_sparams(match.sparams) || return nothing
+
+    # See if there exists a specialization for this method signature
+    mi = specialize_method(match; preexisting=true) # Union{Nothing, MethodInstance}
+    return mi
+end
diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl
index e54a09fe351b3..7329dafcb1121 100644
--- a/base/compiler/ssair/driver.jl
+++ b/base/compiler/ssair/driver.jl
@@ -14,8 +14,10 @@ include("compiler/ssair/basicblock.jl")
 include("compiler/ssair/domtree.jl")
 include("compiler/ssair/ir.jl")
 include("compiler/ssair/slot2ssa.jl")
-include("compiler/ssair/passes.jl")
 include("compiler/ssair/inlining.jl")
 include("compiler/ssair/verify.jl")
 include("compiler/ssair/legacy.jl")
-#@isdefined(Base) && include("compiler/ssair/show.jl")
+function try_compute_field end # imported by EscapeAnalysis
+include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl")
+include("compiler/ssair/passes.jl")
+# @isdefined(Base) && include("compiler/ssair/show.jl")
diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl
index 2238d43d65b27..e67594f196c90 100644
--- a/base/compiler/tfuncs.jl
+++ b/base/compiler/tfuncs.jl
@@ -1288,7 +1288,7 @@ function apply_type_nothrow(argtypes::Array{Any, 1}, @nospecialize(rt))
                 return false
             end
         elseif (isa(ai, Const) && isa(ai.val, Type)) || isconstType(ai)
-            ai = isa(ai, Const) ? ai.val : ai.parameters[1]
+            ai = isa(ai, Const) ? 
ai.val : (ai::DataType).parameters[1] if has_free_typevars(u.var.lb) || has_free_typevars(u.var.ub) return false end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 03ba383de4f61..d600df1dbb0a1 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -313,7 +313,7 @@ function CodeInstance( widenconst(result_type), rettype_const, inferred_result, const_flags, first(valid_worlds), last(valid_worlds), # TODO: Actually do something with non-IPO effects - encode_effects(result.ipo_effects), encode_effects(result.ipo_effects), + encode_effects(result.ipo_effects), encode_effects(result.ipo_effects), result.argescapes, relocatability) end diff --git a/base/compiler/types.jl b/base/compiler/types.jl index cebb560a2010b..c1231691c0fa5 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -120,16 +120,18 @@ mutable struct InferenceResult linfo::MethodInstance argtypes::Vector{Any} overridden_by_const::BitVector - result # ::Type, or InferenceState if WIP - src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available + result # ::Type, or InferenceState if WIP + src # ::Union{CodeInfo, OptimizationState} if inferred copy is available, nothing otherwise valid_worlds::WorldRange # if inference and optimization is finished - ipo_effects::Effects # if inference is finished - effects::Effects # if optimization is finished + ipo_effects::Effects # if inference is finished + effects::Effects # if optimization is finished + argescapes # ::ArgEscapeCache if optimized, nothing otherwise function InferenceResult(linfo::MethodInstance, arginfo#=::Union{Nothing,Tuple{ArgInfo,InferenceState}}=# = nothing, va_override::Bool = false) argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo, va_override) - return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange(), Effects(), Effects()) + return new(linfo, argtypes, overridden_by_const, Any, nothing, + WorldRange(), Effects(), Effects(), nothing) end end diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index e97441495f16b..9b1106e964919 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -19,6 +19,8 @@ function _any(@nospecialize(f), a) end return false end +any(@nospecialize(f), itr) = _any(f, itr) +any(itr) = _any(identity, itr) function _all(@nospecialize(f), a) for x in a @@ -26,6 +28,8 @@ function _all(@nospecialize(f), a) end return true end +all(@nospecialize(f), itr) = _all(f, itr) +all(itr) = _all(identity, itr) function contains_is(itr, @nospecialize(x)) for y in itr diff --git a/doc/make.jl b/doc/make.jl index 8be3b807400d1..bb7ef83048178 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -148,6 +148,7 @@ DevDocs = [ "devdocs/require.md", "devdocs/inference.md", "devdocs/ssair.md", + "devdocs/EscapeAnalysis.md", "devdocs/gc-sa.md", ], "Developing/debugging Julia's C code" => [ diff --git a/doc/src/devdocs/EscapeAnalysis.md b/doc/src/devdocs/EscapeAnalysis.md new file mode 100644 index 0000000000000..de09dfec48c42 --- /dev/null +++ b/doc/src/devdocs/EscapeAnalysis.md @@ -0,0 +1,363 @@ +`Core.Compiler.EscapeAnalysis` is a compiler utility module that aims to analyze +escape information of [Julia's SSA-form IR](@ref Julia-SSA-form-IR) a.k.a. `IRCode`. 
+
+You can try out the escape analysis by loading the `EAUtils.jl` utility script, which
+defines the convenience entries `code_escapes` and `@code_escapes` for testing and debugging purposes:
+```@repl EAUtils
+include(normpath(Sys.BINDIR::String, "..", "share", "julia", "test", "compiler", "EscapeAnalysis", "EAUtils.jl"))
+using EAUtils
+
+mutable struct SafeRef{T}
+    x::T
+end
+Base.getindex(x::SafeRef) = x.x;
+Base.setindex!(x::SafeRef, v) = x.x = v;
+Base.isassigned(x::SafeRef) = true;
+get′(x) = isassigned(x) ? x[] : throw(x);
+
+result = code_escapes((String,String,String,String)) do s1, s2, s3, s4
+    r1 = Ref(s1)
+    r2 = Ref(s2)
+    r3 = SafeRef(s3)
+    try
+        s1 = get′(r1)
+        ret = sizeof(s1)
+    catch err
+        global GV = err # will definitely escape `r1`
+    end
+    s2 = get′(r2)   # still `r2` doesn't escape fully
+    s3 = get′(r3)   # still `r3` doesn't escape fully
+    s4 = sizeof(s4) # the argument `s4` doesn't escape here
+    return s2, s3, s4
+end
+```
+
+The symbols on the side of each call argument and SSA statement represent the following meanings:
+- `◌` (plain): this value is not analyzed because its escape information won't be used anyway (when the object is `isbitstype` for example)
+- `✓` (green or cyan): this value never escapes (`has_no_escape(result.state[x])` holds), colored blue if it also has arg escape (`has_arg_escape(result.state[x])` holds)
+- `↑` (blue or yellow): this value can escape to the caller via return (`has_return_escape(result.state[x])` holds), colored yellow if it also has unhandled thrown escape (`has_thrown_escape(result.state[x])` holds)
+- `X` (red): this value can escape to somewhere the escape analysis can't reason about, e.g. it escapes to global memory (`has_all_escape(result.state[x])` holds)
+- `*` (bold): this value's escape state is between `ReturnEscape` and `AllEscape` in the partial order of [`EscapeInfo`](@ref Core.Compiler.EscapeAnalysis.EscapeInfo), colored yellow if it also has unhandled thrown escape (`has_thrown_escape(result.state[x])` holds)
+- `′`: this value has additional object field / array element information in its `AliasInfo` property
+
+Escape information of each call argument and SSA value can be inspected programmatically as follows:
+```@repl EAUtils
+result.state[Core.Argument(3)] # get EscapeInfo of `s2`
+
+result.state[Core.SSAValue(3)] # get EscapeInfo of `r3`
+```
+
+## Analysis Design
+
+### Lattice Design
+
+`EscapeAnalysis` is implemented as a [data-flow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis)
+that works on a lattice of `x::EscapeInfo`, which is composed of the following properties:
+- `x.Analyzed::Bool`: not formally part of the lattice; only indicates whether `x` has been analyzed or not
+- `x.ReturnEscape::BitSet`: records SSA statements where `x` can escape to the caller via return
+- `x.ThrownEscape::BitSet`: records SSA statements where `x` can be thrown as an exception
+  (used for the [exception handling](@ref EA-Exception-Handling) described below)
+- `x.AliasInfo`: maintains all possible values that can be aliased to fields or array elements of `x`
+  (used for the [alias analysis](@ref EA-Alias-Analysis) described below)
+- `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through
+  `setfield!` on argument(s)
+
+These attributes can be combined to create a partial lattice that has a finite height, given
+the invariant that an input program has a finite number of statements, which is assured by Julia's semantics.
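+
+To make this property-wise lattice design concrete, here is a minimal, self-contained sketch of
+a join over such a lattice. The type name and the reduced set of properties below are
+illustrative only and do not match the actual definitions in `EscapeAnalysis`:
+```julia
+# a simplified stand-in for `EscapeInfo`, carrying only a few of its properties
+struct SimpleEscapeInfo
+    Analyzed::Bool       # whether this value has been analyzed
+    ReturnEscape::BitSet # SSA statements where this value may escape via return
+    ThrownEscape::BitSet # SSA statements where this value may be thrown
+end
+
+# the lattice join (`⊔`) simply handles each property separately
+⊔(x::SimpleEscapeInfo, y::SimpleEscapeInfo) =
+    SimpleEscapeInfo(
+        x.Analyzed | y.Analyzed,
+        union(x.ReturnEscape, y.ReturnEscape),
+        union(x.ThrownEscape, y.ThrownEscape))
+```
+Since each property only ever grows under the join, the finite number of statements bounds
+the lattice height and guarantees termination of the fixed-point iteration.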
+The clever part of this lattice design is that it enables a simpler implementation of
+lattice operations by allowing them to handle each lattice property separately[^LatticeDesign].
+
+### Backward Escape Propagation
+
+This escape analysis implementation is based on the data-flow algorithm described in the paper[^MM02].
+The analysis works on the lattice of `EscapeInfo` and transitions lattice elements from the
+bottom to the top until every lattice element converges to a fixed point, by maintaining
+a (conceptual) working set that contains program counters corresponding to remaining SSA
+statements to be analyzed. The analysis manages a single global state that tracks
+`EscapeInfo` of each argument and SSA statement, but note also that some flow-sensitivity
+is encoded as program counters recorded in `EscapeInfo`'s `ReturnEscape` property,
+which can be combined with domination analysis later to reason about flow-sensitivity if necessary.
+
+One distinctive design choice of this escape analysis is that it is fully _backward_,
+i.e. escape information flows _from usages to definitions_.
+For example, in the code snippet below, EA first analyzes the statement `return %1` and
+imposes `ReturnEscape` on `%1` (corresponding to `obj`), and then it analyzes
+`%1 = %new(Base.RefValue{String}, _2)` and propagates the `ReturnEscape` imposed on `%1`
+to the call argument `_2` (corresponding to `s`):
+```@repl EAUtils
+code_escapes((String,)) do s
+    obj = Ref(s)
+    return obj
+end
+```
+
+The key observation here is that this backward analysis allows escape information to flow
+naturally along the use-def chain rather than control-flow[^BackandForth].
+As a result this scheme enables a simple implementation of escape analysis;
+a `PhiNode` for example can be handled simply by propagating escape information
+imposed on the `PhiNode` to its predecessor values:
+```@repl EAUtils
+code_escapes((Bool, String, String)) do cnd, s, t
+    if cnd
+        obj = Ref(s)
+    else
+        obj = Ref(t)
+    end
+    return obj
+end
+```
+
+### [Alias Analysis](@id EA-Alias-Analysis)
+
+`EscapeAnalysis` implements a backward field analysis in order to reason about escapes
+imposed on object fields with a certain degree of accuracy,
+and `x::EscapeInfo`'s `x.AliasInfo` property exists for this purpose.
+It records all possible values that can be aliased to fields of `x` at "usage" sites,
+and then the escape information of those recorded values is propagated to the actual field values later at "definition" sites.
+More specifically, the analysis records a value that may be aliased to a field of an object by analyzing a `getfield` call,
+and then it propagates its escape information to the field when analyzing a `%new(...)` expression or a `setfield!` call[^Dynamism].
+```@repl EAUtils
+code_escapes((String,)) do s
+    obj = SafeRef("init")
+    obj[] = s
+    v = obj[]
+    return v
+end
+```
+In the example above, `ReturnEscape` imposed on `%3` (corresponding to `v`) is _not_ directly
+propagated to `%1` (corresponding to `obj`); rather, that `ReturnEscape` is only propagated
+to `_2` (corresponding to `s`). Here `%3` is recorded in `%1`'s `AliasInfo` property as
+it can be aliased to the first field of `%1`, and then when analyzing `Base.setfield!(%1, :x, _2)::String`,
+that escape information is propagated to `_2` but not to `%1`.
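+
+The following toy model may help illustrate this "record at usage, propagate at definition"
+scheme, using plain dictionaries in place of the real analysis state (all names here are
+illustrative and unrelated to the actual implementation):
+```julia
+# escape information keyed by a variable name (a stand-in for arguments/SSA values)
+escapes = Dict{Symbol,Set{Symbol}}(:v => Set([:ReturnEscape])) # `v = obj[]` escapes via return
+# values recorded as possibly aliased to a field of `obj` at its usage sites
+aliasinfo = Dict{Symbol,Vector{Symbol}}(:obj => [:v])
+
+# when the backward analysis later reaches the definition site `obj[] = s`, the
+# escape information of the recorded field uses is propagated to `s`, not to `obj`
+s_escape = get!(escapes, :s, Set{Symbol}())
+for use in aliasinfo[:obj]
+    union!(s_escape, escapes[use])
+end
+escapes[:s] # Set([:ReturnEscape])
+```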
+
+So `EscapeAnalysis` tracks which IR elements can be aliased across a `getfield`-`%new`/`setfield!` chain
+in order to analyze escapes of object fields, but this alias analysis actually needs to be
+generalized to handle other IR elements as well. This is because in Julia IR the same
+object is sometimes represented by different IR elements, and so we should make sure that those
+different IR elements that can actually represent the same object share the same escape information.
+IR elements that return the same object as their operand(s), such as `PiNode` and `typeassert`,
+can cause such IR-level aliasing and thus require escape information imposed on any such
+aliased values to be shared between them.
+More interestingly, it is also needed for correctly reasoning about mutations on `PhiNode`.
+Let's consider the following example:
+```@repl EAUtils
+code_escapes((Bool, String,)) do cond, x
+    if cond
+        ϕ2 = ϕ1 = SafeRef("foo")
+    else
+        ϕ2 = ϕ1 = SafeRef("bar")
+    end
+    ϕ2[] = x
+    y = ϕ1[]
+    return y
+end
+```
+`ϕ1 = %5` and `ϕ2 = %6` are aliased and thus `ReturnEscape` imposed on `%8 = Base.getfield(%6, :x)::String` (corresponding to `y = ϕ1[]`)
+needs to be propagated to `Base.setfield!(%5, :x, _3)::String` (corresponding to `ϕ2[] = x`).
+In order for such escape information to be propagated correctly, the analysis should recognize that
+the _predecessors_ of `ϕ1` and `ϕ2` can be aliased as well and equalize their escape information.
+
+One interesting property of such aliasing information is that it is not known at a "usage" site
+but can only be derived at a "definition" site (as aliasing is conceptually equivalent to assignment),
+and thus it doesn't naturally fit in a backward analysis. In order to efficiently propagate escape
+information between related values, `EscapeAnalysis` uses an approach inspired by the escape
+analysis algorithm explained in an old JVM paper[^JVM05]. That is, in addition to managing
+escape lattice elements, the analysis also maintains an "equi"-alias set, a disjoint set of
+aliased arguments and SSA statements. The alias set manages values that can be aliased to
+each other and allows escape information imposed on any such aliased values to be equalized
+between them.
+
+### [Array Analysis](@id EA-Array-Analysis)
+
+The alias analysis for object fields described above can also be generalized to analyze array operations.
+`EscapeAnalysis` implements handling of various primitive array operations so that it can propagate
+escapes via the `arrayref`-`arrayset` use-def chain and does not escape allocated arrays too conservatively:
+```@repl EAUtils
+code_escapes((String,)) do s
+    ary = Any[]
+    push!(ary, SafeRef(s))
+    return ary[1], length(ary)
+end
+```
+In the above example `EscapeAnalysis` understands that `%20` and `%2` (corresponding to the allocated object `SafeRef(s)`)
+are aliased via the `arrayset`-`arrayref` chain and imposes `ReturnEscape` on them,
+but does not impose it on the allocated array `%1` (corresponding to `ary`).
+`EscapeAnalysis` still imposes `ThrownEscape` on `ary` since it also needs to account for
+potential escapes via `BoundsError`, but also note that such unhandled `ThrownEscape` can
+often be ignored when optimizing the `ary` allocation.
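+
+This distinction can be queried programmatically with the predicates mentioned earlier.
+Assuming the allocation of `ary` corresponds to `%1` in the IR shown above (the exact SSA
+numbering depends on how the example is lowered), one might inspect it as follows:
+```julia
+result = code_escapes((String,)) do s
+    ary = Any[]
+    push!(ary, SafeRef(s))
+    return ary[1], length(ary)
+end
+import Core.Compiler.EscapeAnalysis: has_return_escape, has_thrown_escape
+info = result.state[Core.SSAValue(1)] # EscapeInfo of `ary` (index is illustrative)
+has_return_escape(info) # false: `ary` itself never escapes via return
+has_thrown_escape(info) # true: it may escape via a potential `BoundsError`
+```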
+
+Furthermore, in cases where array index information as well as array dimensions can be known _precisely_,
+`EscapeAnalysis` is even able to reason about "per-element" aliasing via the `arrayref`-`arrayset` chain,
+just as it does "per-field" alias analysis for objects:
+```@repl EAUtils
+code_escapes((String,String)) do s, t
+    ary = Vector{Any}(undef, 2)
+    ary[1] = SafeRef(s)
+    ary[2] = SafeRef(t)
+    return ary[1], length(ary)
+end
+```
+Note that `ReturnEscape` is only imposed on `%2` (corresponding to `SafeRef(s)`) but not on `%4` (corresponding to `SafeRef(t)`).
+This is because the allocated array's dimension and the indices involved with all `arrayref`/`arrayset`
+operations are available as constant information, and `EscapeAnalysis` can understand that
+`%6` is aliased to `%2` but can never be aliased to `%4`.
+In this kind of case, the succeeding optimization passes will be able to
+replace `Base.arrayref(true, %1, 1)::Any` with `%2` (a.k.a. "load-forwarding") and
+eventually eliminate the allocation of array `%1` entirely (a.k.a. "scalar-replacement").
+
+Compared to object field analysis, where an access to an object field can be analyzed trivially
+using type information derived by inference, array dimensions aren't encoded as type information,
+and so we need an additional analysis to derive that information. `EscapeAnalysis` at this moment
+first does an additional simple linear scan to analyze dimensions of allocated arrays before
+firing up the main analysis routine so that the succeeding escape analysis can precisely
+analyze operations on those arrays.
+
+However, such precise "per-element" alias analysis is often hard.
+Essentially, the main difficulty inherent to arrays is that array dimensions and indices are often non-constant:
+- loops often produce loop-variant, non-constant array indices
+- (specific to vectors) array resizing changes array dimensions and invalidates their constant-ness
+
+Let's discuss those difficulties with concrete examples.
+
+In the following example, `EscapeAnalysis` fails to do precise alias analysis since the index
+at the `Base.arrayset(false, %4, %8, %6)::Vector{Any}` call is not (trivially) constant.
+In particular, `Any[nothing, nothing]` forms a loop and calls that `arrayset` operation in a loop,
+where `%6` is represented as a ϕ-node value (whose value is control-flow dependent).
+As a result, `ReturnEscape` ends up imposed on both `%23` (corresponding to `SafeRef(s)`) and
+`%25` (corresponding to `SafeRef(t)`), although ideally we want it to be imposed only on `%23` but not on `%25`:
+```@repl EAUtils
+code_escapes((String,String)) do s, t
+    ary = Any[nothing, nothing]
+    ary[1] = SafeRef(s)
+    ary[2] = SafeRef(t)
+    return ary[1], length(ary)
+end
+```
+
+The next example illustrates how vector resizing makes precise alias analysis hard.
+The essential difficulty is that the dimension of the allocated array `%1` is first initialized as `0`,
+but it is then changed by the two `:jl_array_grow_end` calls afterwards.
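+
+(For reference, `push!` on a vector eventually calls down to the `:jl_array_grow_end`
+`foreigncall`, which is what the analysis observes here; one way to see this, assuming the
+call gets fully inlined, is to inspect the optimized IR:)
+```julia
+code_typed(push!, (Vector{Any}, Any)) # look for the `:jl_array_grow_end` foreigncall
+```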
+
+`EscapeAnalysis` currently simply gives up precise alias analysis whenever it encounters any
+array resizing operations, and so `ReturnEscape` is imposed on both `%2` (corresponding to `SafeRef(s)`)
+and `%20` (corresponding to `SafeRef(t)`):
+```@repl EAUtils
+code_escapes((String,String)) do s, t
+    ary = Any[]
+    push!(ary, SafeRef(s))
+    push!(ary, SafeRef(t))
+    ary[1], length(ary)
+end
+```
+
+In order to address these difficulties, we need inference to be aware of array dimensions
+and propagate array dimensions in a flow-sensitive way[^ArrayDimension], as well as to come
+up with a nice representation of loop-variant values.
+
+`EscapeAnalysis` at this moment quickly switches to the more imprecise analysis that doesn't
+track precise index information in cases where array dimensions or indices are trivially
+non-constant. The switch can naturally be implemented as a lattice join operation of the
+`EscapeInfo.AliasInfo` property in the data-flow analysis framework.
+
+### [Exception Handling](@id EA-Exception-Handling)
+
+It is also worth noting how `EscapeAnalysis` handles possible escapes via exceptions.
+Naively it seems enough to propagate the escape information imposed on the `:the_exception` object to
+all values that may be thrown in a corresponding `try` block.
+But there are actually several other ways to access the exception object in Julia,
+such as `Base.current_exceptions` and `rethrow`.
+For example, escape analysis needs to account for the potential escape of `r` in the example below:
+```@repl EAUtils
+const GR = Ref{Any}();
+@noinline function rethrow_escape!()
+    try
+        rethrow()
+    catch err
+        GR[] = err
+    end
+end;
+get′(x) = isassigned(x) ? x[] : throw(x);
+
+code_escapes() do
+    r = Ref{String}()
+    local t
+    try
+        t = get′(r)
+    catch err
+        t = typeof(err)   # `err` (which `r` aliases to) doesn't escape here
+        rethrow_escape!() # but `r` escapes here
+    end
+    return t
+end
+```
+
+It requires a global analysis in order to correctly reason about all possible escapes via
+existing exception interfaces. For now we always propagate the topmost escape information to
+all potentially thrown objects conservatively, since such an additional analysis might not be
+worthwhile to do given that exception handling and error paths usually don't need to be
+very performance-sensitive, and optimizations of error paths might be very ineffective anyway
+since they are often intentionally left unoptimized for latency reasons.
+
+`x::EscapeInfo`'s `x.ThrownEscape` property records SSA statements where `x` can be thrown as an exception.
+Using this information, `EscapeAnalysis` can propagate possible escapes via exceptions to
+only those objects that may be thrown in each `try` region:
+```@repl EAUtils
+result = code_escapes((String,String)) do s1, s2
+    r1 = Ref(s1)
+    r2 = Ref(s2)
+    local ret
+    try
+        s1 = get′(r1)
+        ret = sizeof(s1)
+    catch err
+        global GV = err # will definitely escape `r1`
+    end
+    s2 = get′(r2) # still `r2` doesn't escape fully
+    return s2
+end
+```
+
+## Analysis Usage
+
+When using `EscapeAnalysis` in Julia's high-level compilation pipeline, we can run
+`analyze_escapes(ir::IRCode) -> estate::EscapeState` to analyze escape information of each SSA-IR element in `ir`.
+
+Note that it should be most effective if `analyze_escapes` runs after inlining,
+as `EscapeAnalysis`'s interprocedural escape information handling is limited at this moment.
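+
+As a rough sketch, a caller that already has post-inlining IR in hand (obtaining
+`ir::Core.Compiler.IRCode` is elided here) would invoke the analysis and inspect its
+result along the following lines; this is hypothetical usage following the signature above,
+not a prescribed integration point:
+```julia
+const CC = Core.Compiler
+const EA = CC.EscapeAnalysis
+
+estate = EA.analyze_escapes(ir) # `ir::CC.IRCode` is assumed to be in scope
+estate[Core.Argument(2)]        # EscapeInfo of the first (non-function) argument
+estate[Core.SSAValue(1)]        # EscapeInfo of the first SSA statement
+```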
+
+Since running `analyze_escapes` is not that cheap computationally,
+it is ideal if it runs once and succeeding optimization passes incrementally update
+the escape information upon IR transformation.
+
+```@docs
+Core.Compiler.EscapeAnalysis.analyze_escapes
+Core.Compiler.EscapeAnalysis.EscapeState
+Core.Compiler.EscapeAnalysis.EscapeInfo
+```
+
+[^LatticeDesign]: Our type inference implementation takes the alternative approach,
+    where each lattice property is represented by a special lattice element type object.
+    It turns out that this approach started to complicate implementations of the lattice operations,
+    mainly because it often requires conversion rules between lattice element type objects.
+    And we are working on [overhauling our type inference lattice implementation](https://github.com/JuliaLang/julia/pull/42596)
+    with an `EscapeInfo`-like lattice design.
+
+[^MM02]: _A Graph-Free approach to Data-Flow Analysis_.
+    Markus Mohnen, 2002, April.
+    .
+
+[^BackandForth]: Our type inference algorithm in contrast is implemented as a forward analysis,
+    because type information usually flows from "definition" to "usage" and it is more
+    natural and effective to propagate such information in a forward way.
+
+[^Dynamism]: In some cases, however, object fields can't be analyzed precisely.
+    For example, an object may escape to somewhere `EscapeAnalysis` can't account for possible memory effects on it,
+    or the fields of an object simply can't be known because of a lack of type information.
+    In such cases the `AliasInfo` property is raised to the topmost element within its own lattice order,
+    which causes the succeeding field analysis to be conservative and the escape information imposed on
+    the fields of an unanalyzable object to be propagated to the object itself.
+
+[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_.
+    Thomas Kotzmann and Hanspeter Mössenböck, 2005, June.
+    .
+
+[^ArrayDimension]: Otherwise we will need yet another forward data-flow analysis on top of the escape analysis.
diff --git a/doc/src/devdocs/llvm.md b/doc/src/devdocs/llvm.md
index 1e983949ea0b6..840822f136004 100644
--- a/doc/src/devdocs/llvm.md
+++ b/doc/src/devdocs/llvm.md
@@ -28,7 +28,7 @@ The difference between an intrinsic and a builtin is that a builtin is a first c
 that can be used like any other Julia function. An intrinsic can operate
 only on unboxed data, and therefore its arguments must be statically typed.
 
-### Alias Analysis
+### [Alias Analysis](@id LLVM-Alias-Analysis)
 
 Julia currently uses LLVM's [Type Based Alias Analysis](https://llvm.org/docs/LangRef.html#tbaa-metadata).
To find the comments that document the inclusion relationships, look for `static MDNode*` in diff --git a/src/dump.c b/src/dump.c index 168034d89236d..f2c8629ca9c8b 100644 --- a/src/dump.c +++ b/src/dump.c @@ -524,12 +524,14 @@ static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_ jl_serialize_value(s, codeinst->inferred); jl_serialize_value(s, codeinst->rettype_const); jl_serialize_value(s, codeinst->rettype); + jl_serialize_value(s, codeinst->argescapes); } else { // skip storing useless data jl_serialize_value(s, NULL); jl_serialize_value(s, NULL); jl_serialize_value(s, jl_any_type); + jl_serialize_value(s, jl_nothing); } write_uint8(s->s, codeinst->relocatability); jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque); @@ -1667,6 +1669,8 @@ static jl_value_t *jl_deserialize_value_code_instance(jl_serializer_state *s, jl jl_gc_wb(codeinst, codeinst->rettype_const); codeinst->rettype = jl_deserialize_value(s, &codeinst->rettype); jl_gc_wb(codeinst, codeinst->rettype); + codeinst->argescapes = jl_deserialize_value(s, &codeinst->argescapes); + jl_gc_wb(codeinst, codeinst->argescapes); if (constret) codeinst->invoke = jl_fptr_const_return; if ((flags >> 3) & 1) diff --git a/src/gf.c b/src/gf.c index 7c42a9b802df3..01d03fe77394f 100644 --- a/src/gf.c +++ b/src/gf.c @@ -207,7 +207,8 @@ JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst( jl_method_instance_t *mi, jl_value_t *rettype, jl_value_t *inferred_const, jl_value_t *inferred, int32_t const_flags, size_t min_world, size_t max_world, - uint8_t ipo_effects, uint8_t effects, uint8_t relocatability); + uint8_t ipo_effects, uint8_t effects, jl_value_t *argescapes, + uint8_t relocatability); JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT, jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED); @@ -244,7 +245,7 @@ jl_datatype_t *jl_mk_builtin_func(jl_datatype_t *dt, const char *name, jl_fptr_a jl_code_instance_t *codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, jl_nothing, jl_nothing, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); jl_mi_cache_insert(mi, codeinst); codeinst->specptr.fptr1 = fptr; codeinst->invoke = jl_fptr_args; @@ -367,7 +368,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred( } codeinst = jl_new_codeinst( mi, rettype, NULL, NULL, - 0, min_world, max_world, 0, 0, 0); + 0, min_world, max_world, 0, 0, jl_nothing, 0); jl_mi_cache_insert(mi, codeinst); return codeinst; } @@ -376,7 +377,8 @@ JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst( jl_method_instance_t *mi, jl_value_t *rettype, jl_value_t *inferred_const, jl_value_t *inferred, int32_t const_flags, size_t min_world, size_t max_world, - uint8_t ipo_effects, uint8_t effects, uint8_t relocatability + uint8_t ipo_effects, uint8_t effects, jl_value_t *argescapes, + uint8_t relocatability /*, jl_array_t *edges, int absolute_max*/) { jl_task_t *ct = jl_current_task; @@ -401,9 +403,10 @@ JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst( codeinst->isspecsig = 0; codeinst->precompile = 0; codeinst->next = NULL; - codeinst->relocatability = relocatability; codeinst->ipo_purity_bits = ipo_effects; codeinst->purity_bits = effects; + codeinst->argescapes = argescapes; + codeinst->relocatability = relocatability; return codeinst; } @@ -2013,7 +2016,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t if (unspec && jl_atomic_load_relaxed(&unspec->invoke)) { jl_code_instance_t *codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, 
NULL, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); codeinst->isspecsig = 0; codeinst->specptr = unspec->specptr; codeinst->rettype_const = unspec->rettype_const; @@ -2031,7 +2034,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t if (!jl_code_requires_compiler(src)) { jl_code_instance_t *codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, NULL, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); codeinst->invoke = jl_fptr_interpret_call; jl_mi_cache_insert(mi, codeinst); record_precompile_statement(mi); @@ -2066,7 +2069,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t return ucache; } codeinst = jl_new_codeinst(mi, (jl_value_t*)jl_any_type, NULL, NULL, - 0, 1, ~(size_t)0, 0, 0, 0); + 0, 1, ~(size_t)0, 0, 0, jl_nothing, 0); codeinst->isspecsig = 0; codeinst->specptr = ucache->specptr; codeinst->rettype_const = ucache->rettype_const; diff --git a/src/jltypes.c b/src/jltypes.c index f6f9db0762810..86630ac39c059 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2492,7 +2492,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_code_instance_type = jl_new_datatype(jl_symbol("CodeInstance"), core, jl_any_type, jl_emptysvec, - jl_perm_symsvec(14, + jl_perm_symsvec(15, "def", "next", "min_world", @@ -2502,10 +2502,11 @@ void jl_init_types(void) JL_GC_DISABLED "inferred", //"edges", //"absolute_max", - "ipo_purity_bits", "purity_bits", + "ipo_purity_bits", "purity_bits", + "argescapes", "isspecsig", "precompile", "invoke", "specptr", // function object decls "relocatability"), - jl_svec(14, + jl_svec(15, jl_method_instance_type, jl_any_type, jl_ulong_type, @@ -2515,7 +2516,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_any_type, //jl_any_type, //jl_bool_type, - jl_uint8_type, jl_uint8_type, + jl_uint8_type, jl_uint8_type, + jl_any_type, jl_bool_type, jl_bool_type, jl_any_type, jl_any_type, // fptrs @@ -2668,8 +2670,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_svecset(jl_methtable_type->types, 11, jl_uint8_type); jl_svecset(jl_method_type->types, 12, jl_method_instance_type); jl_svecset(jl_method_instance_type->types, 6, jl_code_instance_type); - jl_svecset(jl_code_instance_type->types, 11, jl_voidpointer_type); jl_svecset(jl_code_instance_type->types, 12, jl_voidpointer_type); + jl_svecset(jl_code_instance_type->types, 13, jl_voidpointer_type); jl_compute_field_offsets(jl_datatype_type); jl_compute_field_offsets(jl_typename_type); diff --git a/src/julia.h b/src/julia.h index 20edd53ad39a7..f3905897a1202 100644 --- a/src/julia.h +++ b/src/julia.h @@ -410,6 +410,7 @@ typedef struct _jl_code_instance_t { uint8_t terminates:2; } purity_flags; }; + jl_value_t *argescapes; // escape information of call arguments // compilation state cache uint8_t isspecsig; // if specptr is a specialized function signature for specTypes->rettype diff --git a/test/choosetests.jl b/test/choosetests.jl index e00aedffdd42e..f86f665bc2217 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -142,7 +142,10 @@ function choosetests(choices = []) filtertests!(tests, "subarray") filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses", "compiler/codegen", - "compiler/inline", "compiler/contextual"]) + "compiler/inline", "compiler/contextual", + "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"]) + filtertests!(tests, "compiler/EscapeAnalysis", [ + "compiler/EscapeAnalysis/local", 
"compiler/EscapeAnalysis/interprocedural"]) filtertests!(tests, "stdlib", STDLIBS) # do ambiguous first to avoid failing if ambiguities are introduced by other tests filtertests!(tests, "ambiguous") diff --git a/test/compiler/EscapeAnalysis/EAUtils.jl b/test/compiler/EscapeAnalysis/EAUtils.jl new file mode 100644 index 0000000000000..3ae9b41a0ddac --- /dev/null +++ b/test/compiler/EscapeAnalysis/EAUtils.jl @@ -0,0 +1,385 @@ +module EAUtils + +export code_escapes, @code_escapes, __clear_cache! + +const CC = Core.Compiler +const EA = CC.EscapeAnalysis + +# entries +# ------- + +import Base: unwrap_unionall, rewrap_unionall +import InteractiveUtils: gen_call_with_extracted_types_and_kwargs + +""" + @code_escapes [options...] f(args...) + +Evaluates the arguments to the function call, determines its types, and then calls +[`code_escapes`](@ref) on the resulting expression. +As with `@code_typed` and its family, any of `code_escapes` keyword arguments can be given +as the optional arguments like `@code_escapes optimize=false myfunc(myargs...)`. +""" +macro code_escapes(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :code_escapes, ex0) +end + +""" + code_escapes(f, argtypes=Tuple{}; [debuginfo::Symbol = :none], [optimize::Bool = true]) -> result::EscapeResult + +Runs the escape analysis on optimized IR of a generic function call with the given type signature. + +# Keyword Arguments + +- `optimize::Bool = true`: + if `true` returns escape information of post-inlining IR (used for local optimization), + otherwise returns escape information of pre-inlining IR (used for interprocedural escape information generation) +- `debuginfo::Symbol = :none`: + controls the amount of code metadata present in the output, possible options are `:none` or `:source`. +""" +function code_escapes(@nospecialize(f), @nospecialize(types=Base.default_tt(f)); + world::UInt = get_world_counter(), + interp::Core.Compiler.AbstractInterpreter = Core.Compiler.NativeInterpreter(world), + debuginfo::Symbol = :none, + optimize::Bool = true) + ft = Core.Typeof(f) + if isa(types, Type) + u = unwrap_unionall(types) + tt = rewrap_unionall(Tuple{ft, u.parameters...}, types) + else + tt = Tuple{ft, types...} + end + interp = EscapeAnalyzer(interp, tt, optimize) + results = Base.code_typed_by_type(tt; optimize=true, world, interp) + isone(length(results)) || throw(ArgumentError("`code_escapes` only supports single analysis result")) + return EscapeResult(interp.ir, interp.state, interp.linfo, debuginfo===:source) +end + +# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.) 
+__clear_cache!() = empty!(GLOBAL_CODE_CACHE) + +# AbstractInterpreter +# ------------------- + +# imports +import .CC: + AbstractInterpreter, NativeInterpreter, WorldView, WorldRange, + InferenceParams, OptimizationParams, get_world_counter, get_inference_cache, code_cache, + lock_mi_inference, unlock_mi_inference, add_remark!, + may_optimize, may_compress, may_discard_trees, verbose_stmt_info +# usings +import Core: + CodeInstance, MethodInstance, CodeInfo +import .CC: + InferenceResult, OptimizationState, IRCode, copy as cccopy, + @timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, sroa_pass!, + adce_pass!, type_lift_pass!, JLOptions, verify_ir, verify_linetable +import .EA: analyze_escapes, ArgEscapeCache, EscapeInfo, EscapeState, is_ipo_profitable + +# when working outside of Core.Compiler, +# cache entire escape state for later inspection and debugging +struct EscapeCache + cache::ArgEscapeCache + state::EscapeState # preserved just for debugging purpose + ir::IRCode # preserved just for debugging purpose +end + +mutable struct EscapeAnalyzer{State} <: AbstractInterpreter + native::NativeInterpreter + cache::IdDict{InferenceResult,EscapeCache} + entry_tt + optimize::Bool + ir::IRCode + state::State + linfo::MethodInstance + EscapeAnalyzer(native::NativeInterpreter, @nospecialize(tt), optimize::Bool) = + new{EscapeState}(native, IdDict{InferenceResult,EscapeCache}(), tt, optimize) +end + +CC.InferenceParams(interp::EscapeAnalyzer) = InferenceParams(interp.native) +CC.OptimizationParams(interp::EscapeAnalyzer) = OptimizationParams(interp.native) +CC.get_world_counter(interp::EscapeAnalyzer) = get_world_counter(interp.native) + +CC.lock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing +CC.unlock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing + +CC.add_remark!(interp::EscapeAnalyzer, sv, s) = add_remark!(interp.native, sv, s) + +CC.may_optimize(interp::EscapeAnalyzer) = may_optimize(interp.native) +CC.may_compress(interp::EscapeAnalyzer) = may_compress(interp.native) +CC.may_discard_trees(interp::EscapeAnalyzer) = may_discard_trees(interp.native) +CC.verbose_stmt_info(interp::EscapeAnalyzer) = verbose_stmt_info(interp.native) + +CC.get_inference_cache(interp::EscapeAnalyzer) = get_inference_cache(interp.native) + +const GLOBAL_CODE_CACHE = IdDict{MethodInstance,CodeInstance}() + +function CC.code_cache(interp::EscapeAnalyzer) + worlds = WorldRange(get_world_counter(interp)) + return WorldView(GlobalCache(), worlds) +end + +struct GlobalCache end + +CC.haskey(wvc::WorldView{GlobalCache}, mi::MethodInstance) = haskey(GLOBAL_CODE_CACHE, mi) + +CC.get(wvc::WorldView{GlobalCache}, mi::MethodInstance, default) = get(GLOBAL_CODE_CACHE, mi, default) + +CC.getindex(wvc::WorldView{GlobalCache}, mi::MethodInstance) = getindex(GLOBAL_CODE_CACHE, mi) + +function CC.setindex!(wvc::WorldView{GlobalCache}, ci::CodeInstance, mi::MethodInstance) + GLOBAL_CODE_CACHE[mi] = ci + add_callback!(mi) # register the callback on invalidation + return nothing +end + +function add_callback!(linfo) + if !isdefined(linfo, :callbacks) + linfo.callbacks = Any[invalidate_cache!] + else + if !any(@nospecialize(cb)->cb===invalidate_cache!, linfo.callbacks) + push!(linfo.callbacks, invalidate_cache!) 
+        end
+    end
+    return nothing
+end
+
+function invalidate_cache!(replaced, max_world, depth = 0)
+    delete!(GLOBAL_CODE_CACHE, replaced)
+
+    if isdefined(replaced, :backedges)
+        for mi in replaced.backedges
+            mi = mi::MethodInstance
+            if !haskey(GLOBAL_CODE_CACHE, mi)
+                continue # otherwise fall into infinite loop
+            end
+            invalidate_cache!(mi, max_world, depth+1)
+        end
+    end
+    return nothing
+end
+
+function CC.optimize(interp::EscapeAnalyzer,
+    opt::OptimizationState, params::OptimizationParams, caller::InferenceResult)
+    ir = run_passes_with_ea(interp, opt.src, opt, caller)
+    return CC.finish(interp, opt, params, ir, caller)
+end
+
+function CC.cache_result!(interp::EscapeAnalyzer, caller::InferenceResult)
+    if haskey(interp.cache, caller)
+        GLOBAL_ESCAPE_CACHE[caller.linfo] = interp.cache[caller]
+    end
+    return Base.@invoke CC.cache_result!(interp::AbstractInterpreter, caller::InferenceResult)
+end
+
+const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,EscapeCache}()
+
+"""
+    cache_escapes!(interp::EscapeAnalyzer, caller::InferenceResult, estate::EscapeState, cacheir::IRCode)
+
+Transforms the escape information on the call arguments of `caller`,
+and then caches it into a global cache for later interprocedural propagation.
+"""
+function cache_escapes!(interp::EscapeAnalyzer,
+    caller::InferenceResult, estate::EscapeState, cacheir::IRCode)
+    cache = ArgEscapeCache(estate)
+    ecache = EscapeCache(cache, estate, cacheir)
+    interp.cache[caller] = ecache
+    return cache
+end
+
+function get_escape_cache(interp::EscapeAnalyzer)
+    return function (linfo::Union{InferenceResult,MethodInstance})
+        if isa(linfo, InferenceResult)
+            ecache = get(interp.cache, linfo, nothing)
+        else
+            ecache = get(GLOBAL_ESCAPE_CACHE, linfo, nothing)
+        end
+        return ecache !== nothing ? ecache.cache : nothing
+    end
+end
+
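+# NOTE (added comment): the pipeline below runs EA twice, mirroring the design
+# described in the commit message. The "[IPO EA]" invocation analyzes
+# pre-inlining IR and caches per-argument escape information for
+# interprocedural propagation, while the "[Local EA]" invocation analyzes
+# post-inlining IR; here its result is returned for inspection via `code_escapes`.
+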
+function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::OptimizationState,
+    caller::InferenceResult)
+    @timeit "convert"   ir = convert_to_ircode(ci, sv)
+    @timeit "slot2reg"  ir = slot2reg(ir, ci, sv)
+    # TODO: Domsorting can produce an updated domtree - no need to recompute here
+    @timeit "compact 1" ir = compact!(ir)
+    nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end
+    local state
+    if is_ipo_profitable(ir, nargs) || caller.linfo.specTypes === interp.entry_tt
+        try
+            @timeit "[IPO EA]" begin
+                state = analyze_escapes(ir, nargs, false, get_escape_cache(interp))
+                cache_escapes!(interp, caller, state, cccopy(ir))
+            end
+        catch err
+            @error "error happened within [IPO EA], inspect `Main.ir` and `Main.nargs`"
+            @eval Main (ir = $ir; nargs = $nargs)
+            rethrow(err)
+        end
+    end
+    if caller.linfo.specTypes === interp.entry_tt && !interp.optimize
+        # return the result
+        interp.ir = cccopy(ir)
+        interp.state = state
+        interp.linfo = sv.linfo
+    end
+    @timeit "Inlining"  ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
+    # @timeit "verify 2" verify_ir(ir)
+    @timeit "compact 2" ir = compact!(ir)
+    if caller.linfo.specTypes === interp.entry_tt && interp.optimize
+        try
+            @timeit "[Local EA]" state = analyze_escapes(ir, nargs, true, get_escape_cache(interp))
+        catch err
+            @error "error happened within [Local EA], inspect `Main.ir` and `Main.nargs`"
+            @eval Main (ir = $ir; nargs = $nargs)
+            rethrow(err)
+        end
+        # return the result
+        interp.ir = cccopy(ir)
+        interp.state = state
+        interp.linfo = sv.linfo
+    end
+    @timeit "SROA"      ir = sroa_pass!(ir)
+    @timeit "ADCE"      ir = adce_pass!(ir)
+    @timeit "type lift" ir = type_lift_pass!(ir)
+    @timeit "compact 3" ir = compact!(ir)
+    if JLOptions().debug_level == 2
+        @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
+    end
+    return ir
+end
+
+# printing
+# --------
+
+import Core: Argument, SSAValue
+import .CC: widenconst, singleton_type
+
+Base.getindex(estate::EscapeState, @nospecialize(x)) = CC.getindex(estate, x)
+
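+# NOTE (added comment): summary of the annotation characters used below,
+# derived from the branches of `get_name_color`:
+#   ◌  not analyzed (`NotAnalyzed`)
+#   ✓  no escape (`NoEscape`, green; cyan `ArgEscape` when it only escapes as an argument)
+#   ↑  escapes to the caller via the return value (`ReturnEscape`)
+#   X  escapes everywhere (`AllEscape`)
+#   *  other escape (yellow when a `ThrownEscape` is also imposed)
+#   ′  the value carries additional field/alias information (`AliasInfo`)
+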
+function get_name_color(x::EscapeInfo, symbol::Bool = false)
+    getname(x) = string(nameof(x))
+    if x === EA.⊥
+        name, color = (getname(EA.NotAnalyzed), "◌"), :plain
+    elseif EA.has_no_escape(EA.ignore_argescape(x))
+        if EA.has_arg_escape(x)
+            name, color = (getname(EA.ArgEscape), "✓"), :cyan
+        else
+            name, color = (getname(EA.NoEscape), "✓"), :green
+        end
+    elseif EA.has_all_escape(x)
+        name, color = (getname(EA.AllEscape), "X"), :red
+    elseif EA.has_return_escape(x)
+        name = (getname(EA.ReturnEscape), "↑")
+        color = EA.has_thrown_escape(x) ? :yellow : :blue
+    else
+        name = (nothing, "*")
+        color = EA.has_thrown_escape(x) ? :yellow : :bold
+    end
+    name = symbol ? last(name) : first(name)
+    if name !== nothing && !isa(x.AliasInfo, Bool)
+        name = string(name, "′")
+    end
+    return name, color
+end
+
+# pcs = sprint(show, collect(x.EscapeSites); context=:limit=>true)
+function Base.show(io::IO, x::EscapeInfo)
+    name, color = get_name_color(x)
+    if isnothing(name)
+        Base.@invoke show(io::IO, x::Any)
+    else
+        printstyled(io, name; color)
+    end
+end
+function Base.show(io::IO, ::MIME"application/prs.juno.inline", x::EscapeInfo)
+    name, color = get_name_color(x)
+    if isnothing(name)
+        return x # use fancy tree-view
+    else
+        printstyled(io, name; color)
+    end
+end
+
+struct EscapeResult
+    ir::IRCode
+    state::EscapeState
+    linfo::Union{Nothing,MethodInstance}
+    source::Bool
+    function EscapeResult(ir::IRCode, state::EscapeState,
+        linfo::Union{Nothing,MethodInstance} = nothing,
+        source::Bool=false)
+        return new(ir, state, linfo, source)
+    end
+end
+Base.show(io::IO, result::EscapeResult) = print_with_info(io, result)
+@eval Base.iterate(res::EscapeResult, state=1) =
+    return state > $(fieldcount(EscapeResult)) ? nothing : (getfield(res, state), state+1)
+
+Base.show(io::IO, cached::EscapeCache) = show(io, EscapeResult(cached.ir, cached.state, nothing))
+
+# adapted from https://github.com/JuliaDebug/LoweredCodeUtils.jl/blob/4612349432447e868cf9285f647108f43bd0a11c/src/codeedges.jl#L881-L897
+function print_with_info(io::IO, (; ir, state, linfo, source)::EscapeResult)
+    # print escape information on the function arguments (the signature line)
+    function preprint(io::IO)
+        ft = ir.argtypes[1]
+        f = singleton_type(ft)
+        if f === nothing
+            f = widenconst(ft)
+        end
+        print(io, f, '(')
+        for i in 1:state.nargs
+            arg = state[Argument(i)]
+            i == 1 && continue
+            c, color = get_name_color(arg, true)
+            printstyled(io, c, ' ', '_', i, "::", ir.argtypes[i]; color)
+            i ≠ state.nargs && print(io, ", ")
+        end
+        print(io, ')')
+        if !isnothing(linfo)
+            def = linfo.def
+            printstyled(io, " in ", (isa(def, Module) ? (def,) : (def.module, " at ", def.file, ':', def.line))...; color=:bold)
+        end
+        println(io)
+    end
+
+    # print escape information on SSA values
+    # nd = ndigits(length(ssavalues))
+    function preprint(io::IO, idx::Int)
+        c, color = get_name_color(state[SSAValue(idx)], true)
+        # printstyled(io, lpad(idx, nd), ' ', c, ' '; color)
+        printstyled(io, rpad(c, 2), ' '; color)
+    end
+
+    print_with_info(preprint, (args...)->nothing, io, ir, source)
+end
+
+function print_with_info(preprint, postprint, io::IO, ir::IRCode, source::Bool)
+    io = IOContext(io, :displaysize=>displaysize(io))
+    used = Base.IRShow.stmts_used(io, ir)
+    if source
+        line_info_preprinter = function (io::IO, indent::String, idx::Int)
+            r = Base.IRShow.inline_linfo_printer(ir)(io, indent, idx)
+            idx ≠ 0 && preprint(io, idx)
+            return r
+        end
+    else
+        line_info_preprinter = Base.IRShow.lineinfo_disabled
+    end
+    line_info_postprinter = Base.IRShow.default_expr_type_printer
+    preprint(io)
+    bb_idx_prev = bb_idx = 1
+    for idx = 1:length(ir.stmts)
+        preprint(io, idx)
+        bb_idx = Base.IRShow.show_ir_stmt(io, ir, idx, line_info_preprinter, line_info_postprinter, used, ir.cfg, bb_idx)
+        postprint(io, idx, bb_idx != bb_idx_prev)
+        bb_idx_prev = bb_idx
+    end
+    max_bb_idx_size = ndigits(length(ir.cfg.blocks))
+    line_info_preprinter(io, " "^(max_bb_idx_size + 2), 0)
+    postprint(io)
+    return nothing
+end
+
+end # module EAUtils
diff --git a/test/compiler/EscapeAnalysis/interprocedural.jl b/test/compiler/EscapeAnalysis/interprocedural.jl
new file mode 100644
index 0000000000000..eccdc710a6c12
--- /dev/null
+++ b/test/compiler/EscapeAnalysis/interprocedural.jl
@@ -0,0 +1,264 @@
+# IPO EA Test
+# ===========
+# EA works on pre-inlining IR
+
+include(normpath(@__DIR__, "setup.jl"))
+
+# callsites
+# ---------
+
+import .EA: ignore_argescape
+
+noescape(a) = nothing
+noescape(a, b) = nothing
+function global_escape!(x)
+    GR[] = x
+    return nothing
+end
+union_escape!(x) = global_escape!(x)
+union_escape!(x::SafeRef) = nothing
+union_escape!(x::SafeRefs) = nothing
+Base.@constprop :aggressive function conditional_escape!(cnd, x)
+    cnd && global_escape!(x)
+    return nothing
+end
+
+# MethodMatchInfo -- global cache
+let result = code_escapes((SafeRef{String},); optimize=false) do x
+        return noescape(x)
+    end
+    @test has_no_escape(ignore_argescape(result.state[Argument(2)]))
+end
+let result = code_escapes((SafeRef{String},); optimize=false) do x
+        identity(x)
+        return nothing
+    end
+    @test has_no_escape(ignore_argescape(result.state[Argument(2)]))
+end
+let result = code_escapes((SafeRef{String},); optimize=false) do x
+        return identity(x)
+    end
+    r = only(findall(isreturn, result.ir.stmts.inst))
+
@test has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + return Ref(x) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + r = Ref{SafeRef{String}}() + r[] = x + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + global_escape!(x) + end + @test has_all_escape(result.state[Argument(2)]) +end +# UnionSplitInfo +let result = code_escapes((Bool,Vector{Any}); optimize=false) do c, s + x = c ? s : SafeRef(s) + union_escape!(x) + end + @test has_all_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Bool,Vector{Any}); optimize=false) do c, s + x = c ? SafeRef(s) : SafeRefs(s, s) + union_escape!(x) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end +# ConstCallInfo -- local cache +let result = code_escapes((SafeRef{String},); optimize=false) do x + return conditional_escape!(false, x) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end +# InvokeCallInfo +let result = code_escapes((SafeRef{String},); optimize=false) do x + return Base.@invoke noescape(x::Any) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + return Base.@invoke conditional_escape!(false::Any, x::Any) + end + @test has_no_escape(ignore_argescape(result.state[Argument(2)])) +end + +# MethodError +# ----------- +# accounts for ThrownEscape via potential MethodError + +# no method error +identity_if_string(x::SafeRef) = nothing +let result = code_escapes((SafeRef{String},); optimize=false) do x + identity_if_string(x) + end + i = only(findall(iscall((result.ir, identity_if_string)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], i) + @test !has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((Union{SafeRef{String},Vector{String}},); optimize=false) do x + identity_if_string(x) + end + i = only(findall(iscall((result.ir, identity_if_string)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], i) + @test !has_return_escape(result.state[Argument(2)], r) +end +let result = code_escapes((SafeRef{String},); optimize=false) do x + try + identity_if_string(x) + catch err + global GV = err + end + return nothing + end + @test !has_all_escape(result.state[Argument(2)]) +end +let result = code_escapes((Union{SafeRef{String},Vector{String}},); optimize=false) do x + try + identity_if_string(x) + catch err + global GV = err + end + return nothing + end + @test has_all_escape(result.state[Argument(2)]) +end +# method ambiguity error +ambig_error_test(a::SafeRef, b) = nothing +ambig_error_test(a, b::SafeRef) = nothing +ambig_error_test(a, b) = nothing +let result = code_escapes((SafeRef{String},Any); optimize=false) do x, y + ambig_error_test(x, y) + end + i = only(findall(iscall((result.ir, ambig_error_test)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], i) # x + @test has_thrown_escape(result.state[Argument(3)], i) # y + @test 
!has_return_escape(result.state[Argument(2)], r) # x + @test !has_return_escape(result.state[Argument(3)], r) # y +end +let result = code_escapes((SafeRef{String},Any); optimize=false) do x, y + try + ambig_error_test(x, y) + catch err + global GV = err + end + end + @test has_all_escape(result.state[Argument(2)]) # x + @test has_all_escape(result.state[Argument(3)]) # y +end + +# Local EA integration +# -------------------- + +# propagate escapes imposed on call arguments + +# FIXME handle _apply_iterate +# FIXME currently we can't prove the effect-freeness of `getfield(RefValue{String}, :x)` +# because of this check https://github.com/JuliaLang/julia/blob/94b9d66b10e8e3ebdb268e4be5f7e1f43079ad4e/base/compiler/tfuncs.jl#L745 +# and thus it leads to the following two broken tests + +@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) +let result = code_escapes() do + broadcast_noescape1(Ref("Hi")) + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + @test_broken !has_thrown_escape(result.state[SSAValue(i)]) +end +@noinline broadcast_noescape2(b) = broadcast(identity, b) +let result = code_escapes() do + broadcast_noescape2(Ref("Hi")) + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + @test_broken !has_thrown_escape(result.state[SSAValue(i)]) +end +@noinline allescape_argument(a) = (global GV = a) # obvious escape +let result = code_escapes() do + allescape_argument(Ref("Hi")) + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) +end +# if we can't determine the matching method statically, we should be conservative +let result = code_escapes((Ref{Any},)) do a + may_exist(a) + end + @test has_all_escape(result.state[Argument(2)]) +end +let result = code_escapes((Ref{Any},)) do a + Base.@invokelatest broadcast_noescape1(a) + end + @test has_all_escape(result.state[Argument(2)]) +end + +# handling of simple union-split (just exploit the inliner's effort) +@noinline unionsplit_noescape(a) = string(nothing) +@noinline unionsplit_noescape(a::Int) = a + 10 +let result = code_escapes((Union{Int,Nothing},)) do x + s = SafeRef{Union{Int,Nothing}}(x) + unionsplit_noescape(s[]) + return nothing + end + inds = findall(isnew, result.ir.stmts.inst) # find allocation statement + @assert !isempty(inds) + for i in inds + @test has_no_escape(result.state[SSAValue(i)]) + end +end + +@noinline function unused_argument(a) + println("prevent inlining") + return Base.inferencebarrier(nothing) +end +let result = code_escapes() do + a = Ref("foo") # shouldn't be "return escape" + b = unused_argument(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + + result = code_escapes() do + a = Ref("foo") # still should be "return escape" + b = unused_argument(a) + return a + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) +end + +# should propagate escape information imposed on return value to the aliased call argument +@noinline returnescape_argument(a) = (println("prevent inlining"); a) +let result = code_escapes() do + obj = Ref("foo") # should be "return escape" + ret = returnescape_argument(obj) + return ret # alias of `obj` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = 
only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) +end +@noinline noreturnescape_argument(a) = (println("prevent inlining"); identity("hi")) +let result = code_escapes() do + obj = Ref("foo") # better to not be "return escape" + ret = noreturnescape_argument(obj) + return ret # must not alias to `obj` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) +end diff --git a/test/compiler/EscapeAnalysis/local.jl b/test/compiler/EscapeAnalysis/local.jl new file mode 100644 index 0000000000000..d5b7bd92e8cfb --- /dev/null +++ b/test/compiler/EscapeAnalysis/local.jl @@ -0,0 +1,2203 @@ +# Local EA Test +# ============= +# EA works on post-inlining IR + +include(normpath(@__DIR__, "setup.jl")) + +@testset "basics" begin + let # arg return + result = code_escapes((Any,)) do a # return to caller + return nothing + end + @test has_arg_escape(result.state[Argument(2)]) + # return + result = code_escapes((Any,)) do a + return a + end + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_arg_escape(result.state[Argument(1)]) # self + @test !has_return_escape(result.state[Argument(1)], i) # self + @test has_arg_escape(result.state[Argument(2)]) # a + @test has_return_escape(result.state[Argument(2)], i) # a + end + let # global store + result = code_escapes((Any,)) do a + global GV = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + let # global load + result = code_escapes() do + global GV + return GV + end + i = only(findall(has_return_escape, map(i->result.state[SSAValue(i)], 1:length(result.ir.stmts)))) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # global store / load (https://github.com/aviatesk/EscapeAnalysis.jl/issues/56) + result = code_escapes((Any,)) do s + global GV + GV = s + return GV + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + end + let # :gc_preserve_begin / :gc_preserve_end + result = code_escapes((String,)) do s + m = SafeRef(s) + GC.@preserve m begin + return nothing + end + end + i = findfirst(isT(SafeRef{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # :isdefined + result = code_escapes((String, Bool, )) do a, b + if b + s = Ref(a) + end + return @isdefined(s) + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # ϕ-node + result = code_escapes((Bool,Any,Any)) do cond, a, b + c = cond ? 
a : b # ϕ(a, b) + return c + end + @assert any(@nospecialize(x)->isa(x, Core.PhiNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], i) # a + @test has_return_escape(result.state[Argument(4)], i) # b + end + let # π-node + result = code_escapes((Any,)) do a + if isa(a, Regex) # a::π(Regex) + return a + end + return nothing + end + @assert any(@nospecialize(x)->isa(x, Core.PiNode), result.ir.stmts.inst) + @test any(findall(isreturn, result.ir.stmts.inst)) do i + has_return_escape(result.state[Argument(2)], i) + end + end + let # φᶜ-node / ϒ-node + result = code_escapes((Any,String)) do a, b + local x::String + try + x = a + catch err + x = b + end + return x + end + @assert any(@nospecialize(x)->isa(x, Core.PhiCNode), result.ir.stmts.inst) + @assert any(@nospecialize(x)->isa(x, Core.UpsilonNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], i) + @test has_return_escape(result.state[Argument(3)], i) + end + let # branching + result = code_escapes((Any,Bool,)) do a, c + if c + return nothing # a doesn't escape in this branch + else + return a # a escapes to a caller + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # loop + result = code_escapes((Int,)) do n + c = SafeRef{Bool}(false) + while n > 0 + rand(Bool) && return c + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + let # try/catch + result = code_escapes((Any,)) do a + try + nothing + catch err + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + try + nothing + finally + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # :foreigncall + result = code_escapes((Any,)) do x + ccall(:some_ccall, Any, (Any,), x) + end + @test has_all_escape(result.state[Argument(2)]) + end +end + +let # simple allocation + result = code_escapes((Bool,)) do c + mm = SafeRef{Bool}(c) # just allocated, never escapes + return mm[] ? nothing : 1 + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) +end + +@testset "builtins" begin + let # throw + r = code_escapes((Any,)) do a + throw(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # implicit throws + r = code_escapes((Any,)) do a + getfield(a, :may_not_field) + end + @test has_thrown_escape(r.state[Argument(2)]) + + r = code_escapes((Any,)) do a + sizeof(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # :=== + result = code_escapes((Bool, String)) do cond, s + m = cond ? 
SafeRef(s) : nothing + c = m === nothing + return c + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # sizeof + ary = [0,1,2] + result = @eval code_escapes() do + ary = $(QuoteNode(ary)) + sizeof(ary) + end + i = only(findall(isT(Core.Const(ary)), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # ifelse + result = code_escapes((Bool,)) do c + r = ifelse(c, Ref("yes"), Ref("no")) + return r + end + inds = findall(isnew, result.ir.stmts.inst) + @assert !isempty(inds) + for i in inds + @test has_return_escape(result.state[SSAValue(i)]) + end + end + let # ifelse (with constant condition) + result = code_escapes() do + r = ifelse(true, Ref("yes"), Ref(nothing)) + return r + end + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)]) + elseif isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{Nothing})(result.ir.stmts.type[i]) + @test has_no_escape(result.state[SSAValue(i)]) + end + end + end + + let # typeassert + result = code_escapes((Any,)) do x + y = x::String + return y + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + end + + let # isdefined + result = code_escapes((Any,)) do x + isdefined(x, :foo) ? x : throw("undefined") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + + result = code_escapes((Module,)) do m + isdefined(m, 10) # throws + end + @test has_thrown_escape(result.state[Argument(2)]) + end +end + +@testset "flow-sensitivity" begin + # ReturnEscape + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + if cond + return cond + end + return r + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) + @assert length(rts) == 2 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 1 + end + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + cnt = 0 + while rand(Bool) + cnt += 1 + rand(Bool) && return r + end + rand(Bool) && return r + return cnt + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) # return statement + @assert length(rts) == 3 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 2 + end +end + +@testset "escape through exceptions" begin + M = @eval Module() begin + unsafeget(x) = isassigned(x) ? 
x[] : throw(x) + @noinline function escape_rethrow!() + try + rethrow() + catch err + GR[] = err + end + end + @noinline function escape_current_exceptions!() + excs = Base.current_exceptions() + GR[] = excs + end + const GR = Ref{Any}() + @__MODULE__ + end + + let # simple: return escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err + ret = err + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + + let # simple: global escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret # prevent DCE + try + s = unsafeget(r) + ret = sizeof(s) + catch err + global GV = err + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # account for possible escapes via nested throws + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + throw(err1) + end + catch err2 + GR[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + rethrow(err1) + end + catch err2 + GR[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + escape_rethrow!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + local t + try + r = Ref{String}() + t = unsafeget(r) + catch err + t = typeof(err) + escape_rethrow!() + end + return t + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + GR[] = Base.current_exceptions() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + escape_current_exceptions!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # contextual: escape information imposed on `err` shouldn't propagate to `r2`, but only to `r1` + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err + global GV = err + end + s2 = unsafeget(r2) + return s2, r2 + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test !has_all_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + end + + # XXX test cases below are currently broken because of the technical reason described in `escape_exception!` + + let # limited propagation: exception is caught within a frame => doesn't escape to a caller + result 
= @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            catch
+                ret = nothing
+            end
+            return ret
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test_broken !has_return_escape(result.state[SSAValue(i)], r)
+    end
+    let # sequential: escape information imposed on `err1` and `err2` should propagate separately
+        result = @eval M $code_escapes() do
+            r1 = Ref{String}()
+            r2 = Ref{String}()
+            local ret
+            try
+                s1 = unsafeget(r1)
+                ret = sizeof(s1)
+            catch err1
+                global GV = err1
+            end
+            try
+                s2 = unsafeget(r2)
+                ret = sizeof(s2)
+            catch err2
+                ret = err2
+            end
+            return ret
+        end
+        is = findall(isnew, result.ir.stmts.inst)
+        @test length(is) == 2
+        i1, i2 = is
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test has_all_escape(result.state[SSAValue(i1)])
+        @test has_return_escape(result.state[SSAValue(i2)], r)
+        @test_broken !has_all_escape(result.state[SSAValue(i2)])
+    end
+    let # nested: escape information imposed on `inner` shouldn't propagate to `s`
+        result = @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                try
+                    ret = sizeof(s)
+                catch inner
+                    return inner
+                end
+            catch outer
+                ret = nothing
+            end
+            return ret
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        @test_broken !has_return_escape(result.state[SSAValue(i)])
+    end
+    let # merge: escape information imposed on `err1` and `err2` should be merged
+        result = @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            catch err1
+                return err1
+            end
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            catch err2
+                return err2
+            end
+            nothing
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        rs = findall(isreturn, result.ir.stmts.inst)
+        @test_broken !has_all_escape(result.state[SSAValue(i)])
+        for r in rs
+            @test has_return_escape(result.state[SSAValue(i)], r)
+        end
+    end
+    let # no exception handling: should keep propagating the escape
+        result = @eval M $code_escapes() do
+            r = Ref{String}()
+            local ret
+            try
+                s = unsafeget(r)
+                ret = sizeof(s)
+            finally
+                if !@isdefined(ret)
+                    ret = 42
+                end
+            end
+            return ret
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test_broken !has_return_escape(result.state[SSAValue(i)], r)
+    end
+end
+
+@testset "field analysis / alias analysis" begin
+    # escaped allocations
+    # -------------------
+
+    # escaped object should escape its fields as well
+    let result = code_escapes((Any,)) do a
+            global GV = SafeRef{Any}(a)
+            nothing
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        @test has_all_escape(result.state[SSAValue(i)])
+        @test has_all_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do a
+            global GV = (a,)
+            nothing
+        end
+        i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst))
+        @test has_all_escape(result.state[SSAValue(i)])
+        @test has_all_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do a
+            o0 = SafeRef{Any}(a)
+            global GV = SafeRef(o0)
+            nothing
+        end
+        is = findall(isnew, result.ir.stmts.inst)
+        @test length(is) == 2
+        i0, i1 = is
+        @test has_all_escape(result.state[SSAValue(i0)])
+        @test has_all_escape(result.state[SSAValue(i1)])
+        @test has_all_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do a
+            t0 = (a,)
+            global GV = (t0,)
+            nothing
+        end
+        inds = findall(iscall((result.ir, tuple)), result.ir.stmts.inst)
+        @assert
length(inds) == 2 + for i in inds; @test has_all_escape(result.state[SSAValue(i)]); end + @test has_all_escape(result.state[Argument(2)]) + end + # global escape through `setfield!` + let result = code_escapes((Any,)) do a + r = SafeRef{Any}(:init) + global GV = r + r[] = a + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + global GV = r + r[] = b + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) # a + @test has_all_escape(result.state[Argument(3)]) # b + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + Rx[] = s + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + setfield!(Rx, :x, s) + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let M = EATModule() + @eval M module ___xxx___ + import ..SafeRef + const Rx = SafeRef("Rx") + end + result = @eval M begin + $code_escapes((String,)) do s + rx = getfield(___xxx___, :Rx) + rx[] = s + nothing + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # field escape + # ------------ + + # field escape should propagate to :new arguments + let result = code_escapes((String,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String,)) do a + t = (a,) + f = t[1] + return f + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String, String)) do a, b + obj = SafeRefs(a, b) + fld1 = obj[1] + fld2 = obj[2] + return (fld1, fld2) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # field escape should propagate to `setfield!` argument + let result = code_escapes((String,)) do a + o = SafeRef("foo") + o[] = a + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # propagate escape information imposed on return value of `setfield!` call + let result = code_escapes((String,)) do a + obj = SafeRef("foo") + return (obj[] = a) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # nested allocations + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + return o2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) 
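+        # (added comment) `o2[]` loads the inner `SafeRef` holding `a`, so both `a`
+        # and the inner allocation escape via the return value, while the outer
+        # `SafeRef` is only loaded from and can be load-forwarded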
+ @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = (a,) + o2 = (o1,) + return o2[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + o1′ = o2[] + a′ = o1′[] + return a′ + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2 = SafeRef(o1) + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isnew, result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2′ = SafeRef(nothing) + o2 = SafeRef{SafeRef}(o2′) + o2[] = o1 + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + findall(1:length(result.ir.stmts)) do i + if isnew(result.ir.stmts[i][:inst]) + t = result.ir.stmts[i][:type] + return t === SafeRef{String} || # o1 + t === SafeRef{SafeRef} # o2 + end + return false + end |> x->foreach(x) do i + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes((String,)) do x + broadcast(identity, Ref(x)) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # ϕ-node allocations + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ = SafeRef{Any}(x) + else + ϕ = SafeRef{Any}(y) + end + return ϕ[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + i = only(findall(isϕ, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = SafeRef{Any}(x) + else + ϕ2 = ϕ1 = SafeRef{Any}(y) + end + return ϕ1[], ϕ2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # when ϕ-node merges values with different types + let result = 
code_escapes((Bool,String,String,String)) do cond, x, y, z + local out + if cond + ϕ = SafeRef(x) + out = ϕ[] + else + ϕ = SafeRefs(z, y) + end + return @isdefined(out) ? out : throw(ϕ) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + ϕ = only(findall(isT(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test !has_return_escape(result.state[Argument(4)], r) # y + @test has_return_escape(result.state[Argument(5)], r) # z + @test has_thrown_escape(result.state[SSAValue(ϕ)], t) + end + + # alias analysis + # -------------- + + # alias via getfield & Expr(:new) + let result = code_escapes((String,)) do s + r = SafeRef(s) + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((String,)) do s + r = SafeRef(Rx) + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(2)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via getfield & setfield! 
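+    # (added comment) the tests below check that a value loaded back via `getfield`
+    # is aliased to the value previously stored into the same object via `setfield!`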
+ let result = code_escapes((String,)) do s + r = Ref{String}() + r[] = s + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref(s) + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + r1[] = s + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r1[] = s + r2[] = r1 + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((SafeRef{String}, String,)) do _rx, s + r = SafeRef(_rx) + r[] = Rx + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(3)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via typeassert + let result = code_escapes((Any,)) do a + r = a::String + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # a + @test isaliased(Argument(2), val, result.state) # a <-> r + end + let result = code_escapes((Any,)) do a + global GV + (g::SafeRef{Any})[] = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + # alias via ifelse + let result = code_escapes((Bool,Any,Any)) do c, a, b + r = ifelse(c, a, b) + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # a + @test has_return_escape(result.state[Argument(4)], r) # b + @test !isaliased(Argument(2), val, result.state) # c r + @test isaliased(Argument(3), val, result.state) # a <-> r + @test isaliased(Argument(4), val, result.state) # b <-> r + end + let result = @eval EATModule() begin + const Lx, Rx = SafeRef("Lx"), SafeRef("Rx") + $code_escapes((Bool,String,)) do c, a + r = ifelse(c, Lx, Rx) + r[] = a + nothing + end + end + @test has_all_escape(result.state[Argument(3)]) # a + end + # alias via ϕ-node + let result = code_escapes((Bool,String)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = 
(result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # x + @test isaliased(Argument(3), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x + if cond1 + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + cond2 && (ϕ2[] = x) + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(4)], r) # x + @test isaliased(Argument(4), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # alias via π-node + let result = code_escapes((Any,)) do x + if isa(x, String) + return x + end + throw("error!") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + rval = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # x + @test isaliased(Argument(2), rval, result.state) + end + let result = code_escapes((String,)) do x + global GV + l = g + if isa(l, SafeRef{String}) + l[] = x + end + nothing + end + @test has_all_escape(result.state[Argument(2)]) # x + end + # circular reference + let result = code_escapes() do + x = Ref{Any}() + x[] = x + return x[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + const Rx = Ref{Any}() + Rx[] = Rx + $code_escapes() do + r = Rx[]::Base.RefValue{Any} + return r[] + end + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(iscall((result.ir, getfield)), result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = @eval Module() begin + @noinline function genr() + r = Ref{Any}() + r[] = r + return r + end + $code_escapes() do + x = genr() + return x[] + end + end + i = only(findall(isinvoke(:genr), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + + # dynamic semantics + # ----------------- + + # conservatively handle untyped objects + let result = @eval code_escapes((Any,Any,)) do T, x + obj = $(Expr(:new, :T, :x)) + end + t = only(findall(isnew, result.ir.stmts.inst)) + @test #=T=# has_thrown_escape(result.state[Argument(2)], t) # T + @test #=x=# has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x, :y)) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x)) + setfield!(obj, :x, y) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, 
result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + + # conservatively handle unknown field: + # all fields should be escaped, but the allocation itself doesn't need to be escaped + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRef(a) + return getfield(obj, fld) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs(a, b) + return getfield(obj, fld) # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs(a, b) + return obj[idx] # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[1] # this should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs("a", "b") + obj[idx] = a + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + + # interprocedural + # --------------- + + let result = @eval EATModule() begin + @noinline getx(obj) = obj[] + $code_escapes((String,)) do a + obj = SafeRef(a) + fld = getx(obj) + return fld + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + end + + # TODO interprocedural alias analysis + let result = code_escapes((SafeRef{String},)) do s + s[] = "bar" + global GV = s[] + nothing + end + @test_broken !has_all_escape(result.state[Argument(2)]) + 
end + + # aliasing between arguments + let result = @eval EATModule() begin + @noinline setxy!(x, y) = x[] = y + $code_escapes((String,)) do y + x = SafeRef("init") + setxy!(x, y) + return x + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test has_return_escape(result.state[Argument(2)], r) # y + end + let result = @eval EATModule() begin + @noinline setxy!(x, y) = x[] = y + $code_escapes((String,)) do y + x1 = SafeRef("init") + x2 = SafeRef(y) + setxy!(x1, x2[]) + return x1 + end + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i1)], r) + @test !has_return_escape(result.state[SSAValue(i2)], r) + @test has_return_escape(result.state[Argument(2)], r) # y + end + let result = @eval EATModule() begin + @noinline mysetindex!(x, a) = x[1] = a + const Ax = Vector{Any}(undef, 1) + $code_escapes((String,)) do s + mysetindex!(Ax, s) + end + end + @test has_all_escape(result.state[Argument(2)]) # s + end + + # TODO flow-sensitivity? + # ---------------------- + + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(:init) + r[] = a + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any,Bool)) do a, b, cond + r = SafeRef{Any}(:init) + if cond + r[] = a + return r[] + else + r[] = b + return nothing + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + r = only(findall(result.ir.stmts.inst) do @nospecialize x + isreturn(x) && isa(x.val, Core.SSAValue) + end) + @test has_return_escape(result.state[Argument(2)], r) # a + @test_broken !has_return_escape(result.state[Argument(3)], r) # b + end + + # handle conflicting field information correctly + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRef("foo") + else + o = SafeRefs("bar", baz) + r = getfield(o, 2) + end + if cnd + o = o::SafeRef + setfield!(o, 1, qux) + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + for new in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(new)]) + end + end + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRefs("foo", "bar") + r = setfield!(o, 2, baz) + else + o = SafeRef(qux) + end + if !cnd + o = o::SafeRef + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + end + + # foreigncall should disable field analysis + let 
result = code_escapes((Any,Nothing,Int,UInt)) do t, mt, lim, world + ambig = false + min = Ref{UInt}(typemin(UInt)) + max = Ref{UInt}(typemax(UInt)) + has_ambig = Ref{Int32}(0) + mt = ccall(:jl_matching_methods, Any, + (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ref{Int32}), + t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} + return mt, has_ambig[] + end + for i in findall(isnew, result.ir.stmts.inst) + @test !is_load_forwardable(result.state[SSAValue(i)]) + end + end +end + +# demonstrate the power of our field / alias analysis with a realistic end to end example +abstract type AbstractPoint{T} end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute(T, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(100000000-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute(a, b) + for i in 0:(100000000-1) + a = add(add(a, b), b) # unreplaceable, since it can be the call argument + end + a.x, a.y +end +function compute!(a, b) + for i in 0:(100000000-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end +let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(1:length(result.ir.stmts)) do i + isϕ(result.ir.stmts[i][:inst]) && isT(MPoint{ComplexF64})(result.ir.stmts[i][:type]) + end + @test !is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + # FIXME with a proper interprocedural alias analysis + # for i in findall(isnew, result.ir.stmts.inst) + # @test is_load_forwardable(result.state[SSAValue(i)]) + # end + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end + +@testset "array primitives" begin + inbounds = Base.JLOptions().check_bounds == 0 + + # arrayref + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(true, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(false, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + inbounds && let result = code_escapes((Vector{String},Int)) do xs, i + s = @inbounds xs[i] + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Bool)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], 
t) # xs + end + let result = code_escapes((String,Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((AbstractVector{String},Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{String},Any)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # arrayset + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(false, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + inbounds && let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + @inbounds xs[i] = x + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((String,String,String,)) do s, t, u + xs = Vector{String}(undef, 3) + Base.arrayset(true, xs, s, 1) + Base.arrayset(true, xs, t, 2) + Base.arrayset(true, xs, u, 3) + return xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + for i in 2:result.state.nargs + @test has_return_escape(result.state[Argument(i)], r) + end + end + let result = code_escapes((Vector{String},String,Bool,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((String,String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs::String + @test has_thrown_escape(result.state[Argument(3)], t) # x::String + end + let result = code_escapes((AbstractVector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = 
code_escapes((Vector{String},AbstractString,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + + # arrayref and arrayset + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test !has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes((Vector{Any},String,Int,Int)) do xs, s, i, j + x = SafeRef(s) + xs[i] = x + xs[j] # potential error + end + i = only(findall(isnew, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(3)], t) # s + @test has_thrown_escape(result.state[SSAValue(i)], t) # x + end + + # arraysize + let result = code_escapes((Vector{Any},)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Vector{Any},Int,)) do xs, dim + Core.arraysize(xs, dim) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Any,)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) + end + + # arraylen + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs, 1) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # array resizing + # without BoundsErrors + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_beg(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = 
code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_end(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + # with possible BoundsErrors + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[3] = x + @ccall jl_array_del_beg(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[1] = x + @ccall jl_array_del_end(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_grow_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + inbounds && let result = code_escapes((String,)) do x + xs = @inbounds Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + + # array copy + let result = code_escapes((Vector{Any},)) do xs + return copy(xs) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test_broken !has_return_escape(result.state[Argument(2)], r) + end + let result = code_escapes((String,)) do s + xs = String[s] + xs′ = copy(xs) + return xs′[1] + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i1)]) + @test !has_return_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Vector{Any},)) do xs + xs′ = copy(xs) + return xs′[1] # may potentially throw BoundsError, should escape `xs` conservatively (i.e. 
escape its elements) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + ref = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + ret = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i)], ref) + @test_broken !has_return_escape(result.state[SSAValue(i)], ret) + @test has_thrown_escape(result.state[Argument(2)], ref) + @test has_return_escape(result.state[Argument(2)], ret) + end + let result = code_escapes((String,)) do s + xs = Vector{String}(undef, 1) + xs[1] = s + xs′ = copy(xs) + length(xs′) > 2 && throw(xs′) + return xs′ + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i1)], t) + @test_broken !has_return_escape(result.state[SSAValue(i1)], r) + @test has_thrown_escape(result.state[SSAValue(i2)], t) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test has_thrown_escape(result.state[Argument(2)], t) + @test has_return_escape(result.state[Argument(2)], r) + end + + # isassigned + let result = code_escapes((Vector{Any},Int)) do xs, i + return isassigned(xs, i) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test !has_thrown_escape(result.state[Argument(2)]) + end + + # indexing analysis + # ----------------- + + # safe case + let result = code_escapes((String,String)) do s, t + a = Vector{Any}(undef, 2) + a[1] = s + a[2] = t + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test !has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((String,String)) do s, t + a = Matrix{Any}(undef, 1, 2) + a[1, 1] = s + a[1, 2] = t + return a[1, 1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test !has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((Bool,String,String,String)) do c, s, t, u + a = Vector{Any}(undef, 2) + if c + a[1] = s + a[2] = u + else + a[1] = t + a[2] = u + end + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_return_escape(result.state[Argument(4)], r) # t + @test !has_return_escape(result.state[Argument(5)], r) # u + end + let result = code_escapes((Bool,String,String,String)) do c, s, t, u + a = Any[nothing, nothing] # TODO how to deal with loop indexing? 
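+            # (note: `Any[nothing, nothing]` is initialized through `arrayset` calls in a
+            #  loop after inlining, whose non-constant index the current index analysis
+            #  presumably cannot resolve; hence the TODO above)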
+ if c + a[1] = s + a[2] = u + else + a[1] = t + a[2] = u + end + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_return_escape(result.state[Argument(4)], r) # t + @test_broken !has_return_escape(result.state[Argument(5)], r) # u + end + let result = code_escapes((String,)) do s + a = Vector{Vector{Any}}(undef, 1) + b = Any[s] + a[1] = b + return a[1][1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + is = findall(isarrayalloc, result.ir.stmts.inst) + @assert length(is) == 2 + ia, ib = is + @test !has_return_escape(result.state[SSAValue(ia)], r) + @test is_load_forwardable(result.state[SSAValue(ia)]) + @test !has_return_escape(result.state[SSAValue(ib)], r) + @test_broken is_load_forwardable(result.state[SSAValue(ib)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Bool,String,String,Regex,Regex,)) do c, s1, s2, t1, t2 + if c + a = Vector{String}(undef, 2) + a[1] = s1 + a[2] = s2 + else + a = Vector{Regex}(undef, 2) + a[1] = t1 + a[2] = t2 + end + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isarrayalloc, result.ir.stmts.inst) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + @test has_return_escape(result.state[Argument(3)], r) # s1 + @test !has_return_escape(result.state[Argument(4)], r) # s2 + @test has_return_escape(result.state[Argument(5)], r) # t1 + @test !has_return_escape(result.state[Argument(6)], r) # t2 + end + let result = code_escapes((String,String,Int)) do s, t, i + a = Any[s] + push!(a, t) + return a[2] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + @test_broken !has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + end + # unsafe cases + let result = code_escapes((String,String,Int)) do s, t, i + a = Vector{Any}(undef, 2) + a[1] = s + a[2] = t + return a[i] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test !is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((String,String,Int)) do s, t, i + a = Vector{Any}(undef, 2) + a[1] = s + a[i] = t + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + @test !is_load_forwardable(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + end + let result = code_escapes((String,String,Int,Int,Int)) do s, t, i, j, k + a = Vector{Any}(undef, 2) + a[3] = s # BoundsError + a[1] = t + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + 
@test !is_load_forwardable(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline some_resize!(a) = pushfirst!(a, nothing) + $code_escapes((String,String,Int)) do s, t, i + a = Vector{Any}(undef, 2) + a[1] = s + some_resize!(a) + return a[2] + end + end + r = only(findall(isreturn, result.ir.stmts.inst)) + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + @test !is_load_forwardable(result.state[SSAValue(i)]) + end + + # circular reference + let result = code_escapes() do + xs = Vector{Any}(undef, 1) + xs[1] = xs + return xs[1] + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + const Ax = Vector{Any}(undef, 1) + Ax[1] = Ax + $code_escapes() do + xs = Ax[1]::Vector{Any} + return xs[1] + end + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(iscall((result.ir, Core.arrayref)), result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = @eval Module() begin + @noinline function genxs() + xs = Vector{Any}(undef, 1) + xs[1] = xs + return xs + end + $code_escapes() do + xs = genxs() + return xs[1] + end + end + i = only(findall(isinvoke(:genxs), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end +end + +# demonstrate array primitive support with a realistic end to end example +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + push!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + pushfirst!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) # xs + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((String,String,String)) do s, t, u + xs = String[] + resize!(xs, 3) + xs[1] = s + xs[1] = t + xs[1] = u + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + @test has_return_escape(result.state[Argument(4)], r) # u +end + +@static if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +@testset "ImmutableArray" begin + # arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + arrayfreeze(xs) + end + @test 
!has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do xs
+            arrayfreeze(xs)
+        end
+        @test has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((ImmutableArray{Any,1},)) do xs
+            arrayfreeze(xs)
+        end
+        @test has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes() do
+            xs = Any[]
+            arrayfreeze(xs)
+        end
+        i = only(findall(isarrayalloc, result.ir.stmts.inst))
+        @test has_no_escape(result.state[SSAValue(i)])
+    end
+
+    # mutating_arrayfreeze
+    let result = code_escapes((Vector{Any},)) do xs
+            mutating_arrayfreeze(xs)
+        end
+        @test !has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Vector,)) do xs
+            mutating_arrayfreeze(xs)
+        end
+        @test !has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do xs
+            mutating_arrayfreeze(xs)
+        end
+        @test has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((ImmutableArray{Any,1},)) do xs
+            mutating_arrayfreeze(xs)
+        end
+        @test has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes() do
+            xs = Any[]
+            mutating_arrayfreeze(xs)
+        end
+        i = only(findall(isarrayalloc, result.ir.stmts.inst))
+        @test has_no_escape(result.state[SSAValue(i)])
+    end
+
+    # arraythaw
+    let result = code_escapes((ImmutableArray{Any,1},)) do xs
+            arraythaw(xs)
+        end
+        @test !has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((ImmutableArray,)) do xs
+            arraythaw(xs)
+        end
+        @test !has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Any,)) do xs
+            arraythaw(xs)
+        end
+        @test has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes((Vector{Any},)) do xs
+            arraythaw(xs)
+        end
+        @test has_thrown_escape(result.state[Argument(2)])
+    end
+    let result = code_escapes() do
+            xs = ImmutableArray(Any[])
+            arraythaw(xs)
+        end
+        i = only(findall(isarrayalloc, result.ir.stmts.inst))
+        @test has_no_escape(result.state[SSAValue(i)])
+    end
+end
+
+# demonstrate some arrayfreeze optimizations
+# !has_return_escape(ary) means ary is eligible for the arrayfreeze to mutating_arrayfreeze optimization
+let result = code_escapes((Int,)) do n
+        xs = collect(1:n)
+        ImmutableArray(xs)
+    end
+    i = only(findall(isarrayalloc, result.ir.stmts.inst))
+    @test !has_return_escape(result.state[SSAValue(i)])
+end
+let result = code_escapes((Vector{Float64},)) do xs
+        ys = sin.(xs)
+        ImmutableArray(ys)
+    end
+    i = only(findall(isarrayalloc, result.ir.stmts.inst))
+    @test !has_return_escape(result.state[SSAValue(i)])
+end
+let result = code_escapes((Vector{Pair{Int,String}},)) do xs
+        n = maximum(first, xs)
+        ys = Vector{String}(undef, n)
+        for (i, s) in xs
+            ys[i] = s
+        end
+        ImmutableArray(ys)
+    end
+    i = only(findall(isarrayalloc, result.ir.stmts.inst))
+    @test !has_return_escape(result.state[SSAValue(i)])
+end
+
+end # @static if isdefined(Core, :ImmutableArray)
+
+# demonstrate that a simple type-level analysis can sometimes improve the analysis accuracy
+# by compensating for the lack of yet-unimplemented analyses
+@testset "special-casing bitstype" begin
+    let result = code_escapes((Nothing,)) do a
+            global GV = a
+        end
+        @test !(has_all_escape(result.state[Argument(2)]))
+    end
+
+    let result = code_escapes((Int,)) do a
+            o = SafeRef(a)
+            f = o[]
+            return f
+        end
+        i = only(findall(isnew, result.ir.stmts.inst))
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test !has_return_escape(result.state[SSAValue(i)], r)
+    end
+
+    # an escaped tuple stmt will not propagate escape to its `Int` argument (since `Int` is a bitstype)
+    let result = code_escapes((Int,Any,)) do a, b
+            t = tuple(a, b)
+            return t
+        end
+        i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst))
+        r = only(findall(isreturn, result.ir.stmts.inst))
+        @test !has_return_escape(result.state[Argument(2)], r)
+        @test has_return_escape(result.state[Argument(3)], r)
+    end
+end
+
+# # TODO implement a finalizer elision pass
+# mutable struct WithFinalizer
+#     v
+#     function WithFinalizer(v)
+#         x = new(v)
+#         f(t) = @async println("Finalizing $t.")
+#         return finalizer(f, x)
+#     end
+# end
+# make_m(v = 10) = WithFinalizer(v)
+# function simple(cond)
+#     m = make_m()
+#     if cond
+#         # println(m.v)
+#         return nothing # <= insert `finalize` call here
+#     end
+#     return m
+# end
diff --git a/test/compiler/EscapeAnalysis/setup.jl b/test/compiler/EscapeAnalysis/setup.jl
new file mode 100644
index 0000000000000..620b0ec4f0b16
--- /dev/null
+++ b/test/compiler/EscapeAnalysis/setup.jl
@@ -0,0 +1,72 @@
+include(normpath(@__DIR__, "EAUtils.jl"))
+using Test, Core.Compiler.EscapeAnalysis, .EAUtils
+import Core: Argument, SSAValue, ReturnNode
+const EA = Core.Compiler.EscapeAnalysis
+
+isT(T) = (@nospecialize x) -> x === T
+isreturn(@nospecialize x) = isa(x, Core.ReturnNode) && isdefined(x, :val)
+isthrow(@nospecialize x) = Meta.isexpr(x, :call) && Core.Compiler.is_throw_call(x)
+isnew(@nospecialize x) = Meta.isexpr(x, :new)
+isϕ(@nospecialize x) = isa(x, Core.PhiNode)
+function with_normalized_name(@nospecialize(f), @nospecialize(x))
+    if Meta.isexpr(x, :foreigncall)
+        name = x.args[1]
+        nn = EA.normalize(name)
+        return isa(nn, Symbol) && f(nn)
+    end
+    return false
+end
+isarrayalloc(@nospecialize x) = with_normalized_name(nn->!isnothing(Core.Compiler.alloc_array_ndims(nn)), x)
+isarrayresize(@nospecialize x) = with_normalized_name(nn->!isnothing(EA.array_resize_info(nn)), x)
+isarraycopy(@nospecialize x) = with_normalized_name(nn->EA.is_array_copy(nn), x)
+import Core.Compiler: argextype, singleton_type
+iscall(y) = @nospecialize(x) -> iscall(y, x)
+function iscall((ir, f), @nospecialize(x))
+    return iscall(x) do @nospecialize x
+        singleton_type(Core.Compiler.argextype(x, ir, Any[])) === f
+    end
+end
+iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
+
+# check if `x` is a statically-resolved call of a function whose name is `sym`
+isinvoke(y) = @nospecialize(x) -> isinvoke(y, x)
+isinvoke(sym::Symbol, @nospecialize(x)) = isinvoke(mi->mi.def.name===sym, x)
+isinvoke(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :invoke) && pred(x.args[1]::Core.MethodInstance)
+
+"""
+    is_load_forwardable(x::EscapeInfo) -> Bool
+
+Queries if `x` is eligible for store-to-load forwarding optimization.
+""" +function is_load_forwardable(x::EA.EscapeInfo) + AliasInfo = x.AliasInfo + AliasInfo === false && return true # allows this query to work for immutables since we don't impose escape on them + # NOTE technically we also need to check `!has_thrown_escape(x)` here as well, + # but we can also do equivalent check during forwarding + return isa(AliasInfo, EA.IndexableFields) || isa(AliasInfo, EA.IndexableElements) +end + +let setup_ex = quote + mutable struct SafeRef{T} + x::T + end + Base.getindex(s::SafeRef) = getfield(s, 1) + Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + + mutable struct SafeRefs{S,T} + x1::S + x2::T + end + Base.getindex(s::SafeRefs, idx::Int) = getfield(s, idx) + Base.setindex!(s::SafeRefs, x, idx::Int) = setfield!(s, idx, x) + + global GV::Any + const global GR = Ref{Any}() + end + global function EATModule(setup_ex = setup_ex) + M = Module() + Core.eval(M, setup_ex) + return M + end + Core.eval(@__MODULE__, setup_ex) +end From 04f2d4828d4d7f2b119aa7b383e53ba505f70f05 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sat, 22 Jan 2022 03:12:25 +0900 Subject: [PATCH 2/3] optimizer: alias-aware SROA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enhances SROA of mutables using the novel Julia-level escape analysis (on top of #43800): 1. alias-aware SROA, mutable ϕ-node elimination 2. `isdefined` check elimination 3. load-forwarding for non-eliminable but analyzable mutables --- 1. alias-aware SROA, mutable ϕ-node elimination EA's alias analysis allows this new SROA to handle nested mutables allocations pretty well. Now we can eliminate the heap allocations completely from this insanely nested examples by the single analysis/optimization pass: ```julia julia> function refs(x) (Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][] end refs (generic function with 1 method) julia> refs("julia"); @allocated refs("julia") 0 ``` EA can also analyze escape of ϕ-node as well as its aliasing. 
Mutable ϕ-nodes would be eliminated even for a very tricky case as like: ```julia julia> code_typed((Bool,String,)) do cond, x # these allocation form multiple ϕ-nodes if cond ϕ2 = ϕ1 = Ref{Any}("foo") else ϕ2 = ϕ1 = Ref{Any}("bar") end ϕ2[] = x y = ϕ1[] # => x return y end 1-element Vector{Any}: CodeInfo( 1 ─ goto #3 if not cond 2 ─ goto #4 3 ─ nothing::Nothing 4 ┄ return x ) => Any ``` Combined with the alias analysis and ϕ-node handling above, allocations in the following "realistic" examples will be optimized: ```julia julia> # demonstrate the power of our field / alias analysis with realistic end to end examples # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B abstract type AbstractPoint{T} end julia> struct Point{T} <: AbstractPoint{T} x::T y::T end julia> mutable struct MPoint{T} <: AbstractPoint{T} x::T y::T end julia> add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y); julia> function compute_point(T, n, ax, ay, bx, by) a = T(ax, ay) b = T(bx, by) for i in 0:(n-1) a = add(add(a, b), b) end a.x, a.y end; julia> function compute_point(n, a, b) for i in 0:(n-1) a = add(add(a, b), b) end a.x, a.y end; julia> function compute_point!(n, a, b) for i in 0:(n-1) a′ = add(add(a, b), b) a.x = a′.x a.y = a′.y end end; julia> compute_point(MPoint, 10, 1+.5, 2+.5, 2+.25, 4+.75); julia> compute_point(MPoint, 10, 1+.5im, 2+.5im, 2+.25im, 4+.75im); julia> @allocated compute_point(MPoint, 10000, 1+.5, 2+.5, 2+.25, 4+.75) 0 julia> @allocated compute_point(MPoint, 10000, 1+.5im, 2+.5im, 2+.25im, 4+.75im) 0 julia> compute_point(10, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)); julia> compute_point(10, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)); julia> @allocated compute_point(10000, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) 0 julia> @allocated compute_point(10000, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) 0 julia> af, bf = MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75); julia> ac, bc = MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im); julia> compute_point!(10, af, bf); julia> compute_point!(10, ac, bc); julia> @allocated compute_point!(10000, af, bf) 0 julia> @allocated compute_point!(10000, ac, bc) 0 ``` 2. `isdefined` check elimination This commit also implements a simple optimization to eliminate `isdefined` call by checking load-fowardability. This optimization may be especially useful to eliminate extra allocation involved with a capturing closure, e.g.: ```julia julia> callit(f, args...) = f(args...); julia> function isdefined_elim() local arr::Vector{Any} callit() do arr = Any[] end return arr end; julia> code_typed(isdefined_elim) 1-element Vector{Any}: CodeInfo( 1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Any}, svec(Any, Int64), 0, :(:ccall), Vector{Any}, 0, 0))::Vector{Any} └── goto #3 if not true 2 ─ goto #4 3 ─ $(Expr(:throw_undef_if_not, :arr, false))::Any 4 ┄ return %1 ) => Vector{Any} ``` 3. load-forwarding for non-eliminable but analyzable mutables EA also allows us to forward loads even when the mutable allocation can't be eliminated but still its fields are known precisely. The load forwarding might be useful since it may derive new type information that succeeding optimization passes can use (or just because it allows simpler code transformations down the load): ```julia julia> code_typed((Bool,String,)) do c, s r = Ref{Any}(s) if c return r[]::String # adce_pass! 
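To make the array direction concrete, the following is a hand-written sketch
distilled from the newly added `test/compiler/EscapeAnalysis` cases (the name
`array_load` is illustrative only; this snippet is not output generated by
this PR). When an array is allocated locally, all stores use constant
indices, and nothing can resize it, EA's index analysis already proves that
`a[1]` below aliases only `s`, which is exactly what a generalized array SROA
would need for load forwarding:
```julia
function array_load(s::String, t::String)
    a = Vector{Any}(undef, 2)
    a[1] = s
    a[2] = t
    return a[1] # a candidate for forwarding to `s` once array SROA is generalized
end
```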
---
 base/compiler/bootstrap.jl                 |   6 +-
 base/compiler/optimize.jl                  |  26 +-
 .../ssair/EscapeAnalysis/EscapeAnalysis.jl |  23 +-
 base/compiler/ssair/passes.jl              | 814 ++++++++++--------
 test/compiler/EscapeAnalysis/EAUtils.jl    |   4 +-
 test/compiler/irpasses.jl                  | 703 +++++++++++++--
 6 files changed, 1144 insertions(+), 432 deletions(-)

diff --git a/base/compiler/bootstrap.jl b/base/compiler/bootstrap.jl
index 1989d8aa57393..487ddf2ccdd1b 100644
--- a/base/compiler/bootstrap.jl
+++ b/base/compiler/bootstrap.jl
@@ -11,7 +11,11 @@ let
     world = get_world_counter()
     interp = NativeInterpreter(world)
-    analyze_escapes_tt = Tuple{typeof(analyze_escapes), IRCode, Int, Bool, typeof(get_escape_cache(code_cache(interp)))}
+    analyze_escapes_tt = Any[typeof(analyze_escapes), IRCode, Int, Bool,
+        # typeof(get_escape_cache(code_cache(interp))) # once we enable IPO EA
+        typeof(null_escape_cache)
+        ]
+    analyze_escapes_tt = Tuple{analyze_escapes_tt...}
     fs = Any[
         # we first create caches for the optimizer, because they contain many loop constructions
         # and they're better to not run in interpreter even during bootstrapping
diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl
index 635e53a9e1f1d..e84f77ae1ea48 100644
--- a/base/compiler/optimize.jl
+++ b/base/compiler/optimize.jl
@@ -98,7 +98,7 @@ and then caches it into a global cache for later interprocedural propagation.
 cache_escapes!(caller::InferenceResult, estate::EscapeState) =
     caller.argescapes = ArgEscapeCache(estate)
 
-function get_escape_cache(mi_cache::MICache) where MICache
+function ipo_escape_cache(mi_cache::MICache) where MICache
     return function (linfo::Union{InferenceResult,MethodInstance})
         if isa(linfo, InferenceResult)
             argescapes = linfo.argescapes
@@ -110,6 +110,7 @@ function get_escape_cache(mi_cache::MICache) where MICache
         return argescapes !== nothing ? argescapes::ArgEscapeCache : nothing
     end
 end
+null_escape_cache(linfo::Union{InferenceResult,MethodInstance}) = nothing
 
 mutable struct OptimizationState
     linfo::MethodInstance
@@ -540,17 +541,24 @@ function run_passes(ci::CodeInfo, sv::OptimizationState, caller::InferenceResult
     # TODO: Domsorting can produce an updated domtree - no need to recompute here
     @timeit "compact 1" ir = compact!(ir)
     nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end
-    get_escape_cache = (@__MODULE__).get_escape_cache(sv.inlining.mi_cache)
-    if is_ipo_profitable(ir, nargs)
-        @timeit "IPO EA" begin
-            state = analyze_escapes(ir, nargs, false, get_escape_cache)
-            cache_escapes!(caller, state)
-        end
-    end
+    # if is_ipo_profitable(ir, nargs)
+    #     @timeit "IPO EA" begin
+    #         state = analyze_escapes(ir,
+    #             nargs, #=call_resolved=#false, ipo_escape_cache(sv.inlining.mi_cache))
+    #         cache_escapes!(caller, state)
+    #     end
+    # end
     @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
     # @timeit "verify 2" verify_ir(ir)
     @timeit "compact 2" ir = compact!(ir)
-    @timeit "SROA" ir = sroa_pass!(ir)
+    @timeit "SROA" ir, memory_opt = linear_pass!(ir)
+    if memory_opt
+        @timeit "memory_opt_pass!"
begin + @timeit "Local EA" estate = analyze_escapes(ir, + nargs, #=call_resolved=#true, null_escape_cache) + @timeit "memory_opt_pass!" ir = memory_opt_pass!(ir, estate) + end + end @timeit "ADCE" ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl index 218afaefa431f..3fc9f5abb0559 100644 --- a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -1383,17 +1383,18 @@ function escape_new!(astate::AnalysisState, pc::Int, args::Vector{Any}) AliasInfo = objinfo.AliasInfo nargs = length(args) if isa(AliasInfo, Bool) - AliasInfo && @goto conservative_propagation - # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now - typ = widenconst(argextype(obj, astate.ir)) - nfields = fieldcount_noerror(typ) - if nfields === nothing - AliasInfo = Unindexable() - @goto escape_unindexable_def - else - AliasInfo = IndexableFields(nfields) - @goto escape_indexable_def - end + @goto conservative_propagation + # AliasInfo && @goto conservative_propagation + # # AliasInfo of this object hasn't been analyzed yet: set AliasInfo now + # typ = widenconst(argextype(obj, astate.ir)) + # nfields = fieldcount_noerror(typ) + # if nfields === nothing + # AliasInfo = Unindexable() + # @goto escape_unindexable_def + # else + # AliasInfo = IndexableFields(nfields) + # @goto escape_indexable_def + # end elseif isa(AliasInfo, IndexableFields) @label escape_indexable_def # fields are known precisely: propagate escape information imposed on recorded possibilities to the exact field values diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 67610f0c1df60..c8037fac648fa 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -6,29 +6,6 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I return singleton_type(ft) === func end -""" - du::SSADefUse - -This struct keeps track of all uses of some mutable struct allocated in the current function: -- `du.uses::Vector{Int}` are all instances of `getfield` on the struct -- `du.defs::Vector{Int}` are all instances of `setfield!` on the struct -The terminology refers to the uses/defs of the "slot bundle" that the mutable struct represents. - -In addition we keep track of all instances of a `:foreigncall` that preserves of this mutable -struct in `du.ccall_preserve_uses`. Somewhat counterintuitively, we don't actually need to -make sure that the struct itself is live (or even allocated) at a `ccall` site. -If there are no other places where the struct escapes (and thus e.g. where its address is taken), -it need not be allocated. We do however, need to make sure to preserve any elements of this struct. 
-""" -struct SSADefUse - uses::Vector{Int} - defs::Vector{Int} - ccall_preserve_uses::Vector{Int} -end -SSADefUse() = SSADefUse(Int[], Int[], Int[]) - -compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses) - # assume `stmt == getfield(obj, field, ...)` or `stmt == setfield!(obj, field, val, ...)` try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) = try_compute_field(ir, stmt.args[3]) @@ -55,112 +32,6 @@ function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::E return try_compute_fieldidx(typ, field) end -function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) - # TODO: This can be much faster by looking at current level and only - # searching for those blocks in a sorted order - while !(curblock in allblocks) - curblock = domtree.idoms_bb[curblock] - end - return curblock -end - -function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) - ex = ir[SSAValue(def)][:inst] - if isexpr(ex, :new) - return ex.args[1+fidx] - else - @assert isa(ex, Expr) - # The use is whatever the setfield was - return ex.args[4] - end -end - -function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) - curblock = find_curblock(domtree, allblocks, curblock) - def = 0 - for stmt in du.defs - if block_for_inst(ir.cfg, stmt) == curblock - def = max(def, stmt) - end - end - def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx) -end - -function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) - def, useblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use) - if def == 0 - if !haskey(phinodes, curblock) - # If this happens, we need to search the predecessors for defs. 
Which - # one doesn't matter - if it did, we'd have had a phinode - return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) - end - # The use is the phinode - return phinodes[curblock] - else - return val_for_def_expr(ir, def, fidx) - end -end - -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field -function has_safe_def( - ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, - newidx::Int, idx::Int) - def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) - # will throw since we already checked this `:new` site doesn't define this field - def == newidx && return false - # found a "safe" definition - def ≠ 0 && return true - # we may still be able to replace this load with `PhiNode` - # examine if all predecessors of `block` have any "safe" definition - block = block_for_inst(ir, idx) - seen = BitSet(block) - worklist = BitSet(ir.cfg.blocks[block].preds) - isempty(worklist) && return false - while !isempty(worklist) - pred = pop!(worklist) - # if this block has already been examined, bail out to avoid infinite cycles - pred in seen && return false - idx = last(ir.cfg.blocks[pred].stmts) - # NOTE `idx` isn't a load, thus we can use inclusive coondition within the `find_def_for_use` - def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx, true) - # will throw since we already checked this `:new` site doesn't define this field - def == newidx && return false - push!(seen, pred) - # found a "safe" definition for this predecessor - def ≠ 0 && continue - # check for the predecessors of this predecessor - for newpred in ir.cfg.blocks[pred].preds - push!(worklist, newpred) - end - end - return true -end - -# find the first dominating def for the given use -function find_def_for_use( - ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use::Int, inclusive::Bool=false) - useblock = block_for_inst(ir.cfg, use) - curblock = find_curblock(domtree, allblocks, useblock) - local def = 0 - for idx in du.defs - if block_for_inst(ir.cfg, idx) == curblock - if curblock != useblock - # Find the last def in this block - def = max(def, idx) - else - # Find the last def before our use - if inclusive - def = max(def, idx ≤ use ? idx : 0) - else - def = max(def, idx < use ? idx : 0) - end - end - end - end - return def, useblock, curblock -end - function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint)) if isa(val, Union{OldSSAValue, SSAValue}) val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) @@ -657,38 +528,35 @@ end const SPCSet = IdSet{Int} """ - sroa_pass!(ir::IRCode) -> newir::IRCode - -`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization. - -This pass is based on a local field analysis by def-use chain walking. -It looks for struct allocation sites ("definitions"), and `getfield` calls as well as -`:foreigncall`s that preserve the structs ("usages"). If "definitions" have enough information, -then this pass will replace corresponding usages with forwarded values. -`mutable struct`s require additional cares and need to be handled separately from immutables. -For `mutable struct`s, `setfield!` calls account for "definitions" also, and the pass should -give up the lifting conservatively when there are any "intermediate usages" that may escape -the mutable struct (e.g. 
non-inlined generic function call that takes the mutable struct as
-its argument).
-
-In a case when all usages are fully eliminated, `struct` allocation may also be erased as
-a result of succeeding dead code elimination.
+    linear_pass!(ir::IRCode) -> (newir::IRCode, memory_opt::Bool)
+
+This pass consists of the following optimizations that can be performed by
+a single linear traversal over IR statements:
+- load forwarding of immutables (`getfield` elimination): immutable allocations whose
+  loads are all eliminated by this pass may be erased entirely as a result of succeeding
+  dead code elimination (this allocation elimination is called "SROA", Scalar Replacements of Aggregates)
+- lifting of builtin comparisons: see [`lift_comparison!`](@ref)
+- canonicalization of `typeassert` calls: see [`canonicalize_typeassert!`](@ref)
+
+In addition to performing the optimizations above, the linear traversal also examines each
+statement and checks whether it is profitable to run the [`memory_opt_pass!`](@ref) pass.
+In such cases the `memory_opt` flag is set, indicating that `ir` may be further optimized by
+running `memory_opt_pass!(ir, estate::EscapeState)`.
 """
-function sroa_pass!(ir::IRCode)
+function linear_pass!(ir::IRCode)
     compact = IncrementalCompact(ir)
-    defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
     lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
+    local memory_opt = false # whether or not to run the memory_opt_pass! pass later
     for ((_, idx), stmt) in compact
-        # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
         isa(stmt, Expr) || continue
-        is_setfield = false
         field_ordering = :unspecified
-        if is_known_call(stmt, setfield!, compact)
-            4 <= length(stmt.args) <= 5 || continue
-            is_setfield = true
-            if length(stmt.args) == 5
-                field_ordering = argextype(stmt.args[5], compact)
+        if isexpr(stmt, :new)
+            typ = unwrap_unionall(widenconst(argextype(SSAValue(idx), compact)))
+            if ismutabletype(typ)
+                # mutable SROA may eliminate this allocation, mark it now
+                memory_opt = true
             end
+            continue
         elseif is_known_call(stmt, getfield, compact)
             3 <= length(stmt.args) <= 5 || continue
             if length(stmt.args) == 5
@@ -704,40 +572,21 @@ function sroa_pass!(ir::IRCode)
         for pidx in (6+nccallargs):length(stmt.args)
             preserved_arg = stmt.args[pidx]
             isa(preserved_arg, SSAValue) || continue
-            let intermediaries = SPCSet()
-                callback = function (@nospecialize(pi), @nospecialize(ssa))
-                    push!(intermediaries, ssa.id)
-                    return false
-                end
-                def = simple_walk(compact, preserved_arg, callback)
-                isa(def, SSAValue) || continue
-                defidx = def.id
-                def = compact[defidx]
-                if is_known_call(def, tuple, compact)
+            def = simple_walk(compact, preserved_arg)
+            isa(def, SSAValue) || continue
+            defidx = def.id
+            def = compact[defidx]
+            if is_known_call(def, tuple, compact)
+                record_immutable_preserve!(new_preserves, def, compact)
+                push!(preserved, preserved_arg.id)
+            elseif isexpr(def, :new)
+                typ = unwrap_unionall(widenconst(argextype(SSAValue(defidx), compact)))
+                if typ isa DataType
+                    ismutabletype(typ) && continue # mutable SROA is performed later
                     record_immutable_preserve!(new_preserves, def, compact)
                     push!(preserved, preserved_arg.id)
-                    continue
-                elseif isexpr(def, :new)
-                    typ = widenconst(argextype(SSAValue(defidx), compact))
-                    if isa(typ, UnionAll)
-                        typ = unwrap_unionall(typ)
-                    end
-                    if typ isa DataType && !ismutabletype(typ)
-                        record_immutable_preserve!(new_preserves, def, compact)
- push!(preserved, preserved_arg.id) - continue - end - else - continue end - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() - end - mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse())) - push!(defuse.ccall_preserve_uses, idx) - union!(mid, intermediaries) end - continue end if !isempty(new_preserves) compact[idx] = form_new_preserves(stmt, preserved, new_preserves) @@ -756,7 +605,7 @@ function sroa_pass!(ir::IRCode) continue end - # analyze this `getfield` / `setfield!` call + # analyze this `getfield` call field = try_compute_field_stmt(compact, stmt) field === nothing && continue @@ -774,32 +623,7 @@ function sroa_pass!(ir::IRCode) continue end - # analyze this mutable struct here for the later pass - if ismutabletype(struct_typ) - isa(val, SSAValue) || continue - let intermediaries = SPCSet() - callback = function (@nospecialize(pi), @nospecialize(ssa)) - push!(intermediaries, ssa.id) - return false - end - def = simple_walk(compact, val, callback) - # Mutable stuff here - isa(def, SSAValue) || continue - if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() - end - mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) - if is_setfield - push!(defuse.defs, idx) - else - push!(defuse.uses, idx) - end - union!(mid, intermediaries) - end - continue - elseif is_setfield - continue # invalid `setfield!` call, but just ignore here - end + ismutabletype(struct_typ) && continue # mutable SROA is performed later # perform SROA on immutable structs here on @@ -837,177 +661,459 @@ function sroa_pass!(ir::IRCode) end non_dce_finish!(compact) - if defuses !== nothing - # now go through analyzed mutable structs and see which ones we can eliminate - # NOTE copy the use count here, because `simple_dce!` may modify it and we need it - # consistent with the state of the IR here (after tracking `PhiNode` arguments, - # but before the DCE) for our predicate within `sroa_mutables!`, but we also - # try an extra effort using a callback so that reference counts are updated - used_ssas = copy(compact.used_ssas) - simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) - ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) - return ir - else - simple_dce!(compact) - return complete(compact) - end + simple_dce!(compact) + ir = complete(compact) + return ir, memory_opt end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) - # initialization of domtree is delayed to avoid the expensive computation in many cases - local domtree = nothing - for (idx, (intermediaries, defuse)) in defuses - intermediaries = collect(intermediaries) - # Check if there are any uses we did not account for. If so, the variable - # escapes and we cannot eliminate the allocation. This works, because we're guaranteed - # not to include any intermediaries that have dead uses. As a result, missing uses will only ever - # show up in the nuses_total count. 
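To make the load forwarding described in the `linear_pass!` docstring concrete, here is a
minimal source-level sketch (the struct and function names are illustrative and do not
appear in this patch):

```julia
# What `getfield` elimination does for immutables, in source terms.
struct PointXY
    x::Float64
    y::Float64
end

# Before `linear_pass!`, the IR for this function contains a `:new` allocation
# and two `getfield` loads:
#   %1 = %new(PointXY, a, b); %2 = getfield(%1, :x); %3 = getfield(%1, :y)
norm2(a, b) = (p = PointXY(a, b); p.x^2 + p.y^2)

# After load forwarding, both loads are rewritten to `a` and `b` directly; the
# allocation is then unused and the succeeding DCE erases it (immutable SROA).
```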
-        nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
-        nuses = 0
-        for idx in intermediaries
-            nuses += used_ssas[idx]
+function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
+    newex = Expr(:foreigncall)
+    nccallargs = length(origex.args[3]::SimpleVector)
+    for i in 1:(6+nccallargs-1)
+        push!(newex.args, origex.args[i])
+    end
+    for i in (6+nccallargs):length(origex.args)
+        x = origex.args[i]
+        # don't need to preserve intermediaries
+        if isa(x, SSAValue) && x.id in intermediates
+            continue
         end
-        nuses_total = used_ssas[idx] + nuses - length(intermediaries)
-        nleaves == nuses_total || continue
-        # Find the type for this allocation
-        defexpr = ir[SSAValue(idx)][:inst]
-        isexpr(defexpr, :new) || continue
-        newidx = idx
-        typ = ir.stmts[newidx][:type]
-        if isa(typ, UnionAll)
-            typ = unwrap_unionall(typ)
+        push!(newex.args, x)
+    end
+    for i in 1:length(new_preserves)
+        push!(newex.args, new_preserves[i])
+    end
+    return newex
+end
+
+import .EscapeAnalysis:
+    EscapeState, EscapeInfo, IndexableFields, LivenessSet, getaliases, LocalUse, LocalDef
+
+"""
+    memory_opt_pass!(ir::IRCode, estate::EscapeState) -> newir::IRCode
+
+Performs memory optimizations using the escape information analyzed by `EscapeAnalysis`.
+Specifically, this optimization pass performs SROA of mutable allocations.
+
+`estate::EscapeState` is expected to be a result of `analyze_escapes(ir, ...)`.
+Since running `analyze_escapes` can be relatively expensive, it is recommended
+to run this pass "selectively", i.e. only when the memory optimizations seem
+likely to be profitable.
+"""
+function memory_opt_pass!(ir::IRCode, estate::EscapeState)
+    # Compute domtree now, needed below, now that we have finished compacting the IR.
+    # This needs to be after we iterate through the IR with `IncrementalCompact`
+    # because removing dead blocks can invalidate the domtree.
+    # TODO initialization of the domtree can be delayed to avoid the expensive computation
+    # in cases when there are no loads to be forwarded
+    @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
+    wset = BitSet(1:length(ir.stmts)+length(ir.new_nodes.stmts))
+    eliminated = BitSet()
+    revisit = Tuple{#=related=#Vector{SSAValue}, #=Liveness=#LivenessSet}[]
+    all_preserved = true
+    newpreserves = nothing
+    while !isempty(wset)
+        idx = pop!(wset)
+        ssa = SSAValue(idx)
+        stmt = ir[ssa][:inst]
+        isexpr(stmt, :new) || continue
+        einfo = estate[ssa]
+        is_load_forwardable(einfo) || continue
+        aliases = getaliases(ssa, estate)
+        if aliases === nothing
+            related = SSAValue[ssa]
+        else
+            related = SSAValue[]
+            for alias in aliases
+                @assert isa(alias, SSAValue) "invalid escape analysis"
+                push!(related, alias)
+                delete!(wset, alias.id)
+            end
         end
-        # Could still end up here if we tried to setfield! on an immutable, which would
-        # error at runtime, but is not illegal to have in the IR.
-        ismutabletype(typ) || continue
-        typ = typ::DataType
+        finfos = (einfo.AliasInfo::IndexableFields).infos
+        nfields = length(finfos)
+
         # Partition defuses by field
-        fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
-        all_forwarded = true
-        for use in defuse.uses
-            stmt = ir[SSAValue(use)][:inst] # == `getfield` call
-            # We may have discovered above that this use is dead
-            # after the getfield elim of immutables. In that case,
-            # it would have been deleted. That's fine, just ignore
-            # the use in that case.
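Given the docstring above, the intended call pattern can be sketched as follows (a
hypothetical driver, not the actual pipeline wiring; `analyze_escapes` takes additional
arguments that are elided here):

```julia
# Minimal sketch of the "selective" invocation recommended above.
function run_memory_opts!(ir::IRCode)
    ir, memory_opt = linear_pass!(ir)  # always cheap: a single linear traversal
    if memory_opt
        # pay for escape analysis only when the traversal saw an allocation
        # that `memory_opt_pass!` might be able to eliminate
        estate = analyze_escapes(ir)   # assumption: extra arguments elided
        ir = memory_opt_pass!(ir, estate)
    end
    return ir
end
```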
- if stmt === nothing - all_forwarded = false - continue + fdefuses = Vector{FieldDefUse}(undef, nfields) + for i = 1:nfields + finfo = finfos[i] + fdu = FieldDefUse() + for fx in finfo + if isa(fx, LocalUse) + push!(fdu.uses, GetfieldLoad(fx.idx)) # use (getfield call) + else + @assert isa(fx, LocalDef) + push!(fdu.defs, fx.idx) # def (setfield! call or :new expression) + end end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) + fdefuses[i] = fdu end - for def in defuse.defs - stmt = ir[SSAValue(def)][:inst]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + + Liveness = einfo.Liveness + for livepc in Liveness + livestmt = ir[SSAValue(livepc)][:inst] + if is_known_call(livestmt, Core.ifelse, ir) + # the succeeding domination analysis doesn't account for conditional branching + # by ifelse branching at this moment + @goto next_itr + elseif is_known_call(livestmt, isdefined, ir) + args = livestmt.args + length(args) ≥ 3 || continue + obj = args[2] + isa(obj, SSAValue) || continue + obj in related || continue + fld = args[3] + fldval = try_compute_field(ir, fld) + fldval === nothing && continue + typ = unwrap_unionall(widenconst(argextype(obj, ir))) + isa(typ, DataType) || continue + fldidx = try_compute_fieldidx(typ, fldval) + fldidx === nothing && continue + push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) + elseif isexpr(livestmt, :foreigncall) + # we shouldn't eliminate this use if it's used as a direct argument + args = livestmt.args + nccallargs = length(args[3]::SimpleVector) + for i = 6:(5+nccallargs) + arg = args[i] + isa(arg, SSAValue) && arg in related && @goto next_liveness + end + # this use is preserve, and may be eliminable + for fidx in 1:nfields + push!(fdefuses[fidx].uses, PreserveUse(livepc)) + end + end + @label next_liveness end - # Check that the defexpr has defined values for all the fields - # we're accessing. In the future, we may want to relax this, - # but we should come up with semantics for well defined semantics - # for uninitialized fields first. - ndefuse = length(fielddefuse) - blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse) - for fidx in 1:ndefuse - du = fielddefuse[fidx] - isempty(du.uses) && continue - push!(du.defs, newidx) - ldu = compute_live_ins(ir.cfg, du) + + for fidx in 1:nfields + fdu = fdefuses[fidx] + isempty(fdu.uses) && @goto next_use + # check if all uses have safe definitions first, otherwise we should bail out + # since then we may fail to form new ϕ-nodes + ldu = compute_live_ins(ir.cfg, fdu) if isempty(ldu.live_in_bbs) phiblocks = Int[] else - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) end - allblocks = sort(vcat(phiblocks, ldu.def_bbs)) - blocks[fidx] = phiblocks, allblocks - if fidx + 1 > length(defexpr.args) - for use in du.uses - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip - end - end - end - # Everything accounted for. Go field by field and perform idf: - # Compute domtree now, needed below, now that we have finished compacting the IR. 
- # This needs to be after we iterate through the IR with `IncrementalCompact` - # because removing dead blocks can invalidate the domtree. - domtree === nothing && (@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)) - preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing : - IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses))) - for fidx in 1:ndefuse - du = fielddefuse[fidx] - ftyp = fieldtype(typ, fidx) - if !isempty(du.uses) - phiblocks, allblocks = blocks[fidx] - phinodes = IdDict{Int, SSAValue}() - for b in phiblocks - phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), - NewInstruction(PhiNode(), ftyp)) + allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) + for use in fdu.uses + isa(use, IsdefinedUse) && continue + if isa(use, PreserveUse) && isempty(fdu.defs) + # nothing to preserve, just ignore this use (may happen when there are unintialized fields) + continue end - # Now go through all uses and rewrite them - for stmt in du.uses - ir[SSAValue(stmt)][:inst] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) + if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use)) + all_preserved = false + @goto next_use end - if !isbitstype(ftyp) - if preserve_uses !== nothing - for (use, list) in preserve_uses - push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use)) - end + end + phinodes = IdDict{Int, SSAValue}() + for b in phiblocks + phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), + NewInstruction(PhiNode(), Any)) + end + # Now go through all uses and rewrite them + for use in fdu.uses + if isa(use, GetfieldLoad) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = compute_value_for_use( + ir, domtree, allblocks, fdu, phinodes, fidx, use) + push!(eliminated, use) + elseif all_preserved && isa(use, PreserveUse) + if newpreserves === nothing + newpreserves = IdDict{Int,Vector{Any}}() end - end - for b in phiblocks - n = ir[phinodes[b]][:inst]::PhiNode - for p in ir.cfg.blocks[b].preds - push!(n.edges, p) - push!(n.values, compute_value_for_block(ir, domtree, - allblocks, du, phinodes, fidx, p)) + # record this `use` as replaceable no matter if we preserve new value or not + use = getuseidx(use) + newvalues = get!(()->Any[], newpreserves, use) + isempty(fdu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) + newval = compute_value_for_use( + ir, domtree, allblocks, fdu, phinodes, fidx, use) + if !isbitstype(widenconst(argextype(newval, ir))) + push!(newvalues, newval) + end + elseif isa(use, IsdefinedUse) + use = getuseidx(use) + if has_safe_def(ir, domtree, allblocks, fdu, use) + ir[SSAValue(use)][:inst] = true + push!(eliminated, use) end + else + throw("unexpected use") end end - for stmt in du.defs - stmt == newidx && continue - ir[SSAValue(stmt)][:inst] = nothing + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t end + @label next_use end - preserve_uses === nothing && continue - if all_forwarded - # this means all ccall preserves have been replaced with forwarded loads - # so we can potentially eliminate the allocation, otherwise we must preserve - # the whole allocation. 
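In source terms, the per-field rewrite above turns a conditionally defined field into an
SSA ϕ-value; a representative example (mirroring the ϕ-allocation tests added later in
this patch):

```julia
# The load `r[]` below has two reaching definitions of the single field:
# `x` from the `:new` expression and `y` from the `setfield!` on the taken branch.
function phi_field(cond::Bool, x, y)
    r = Ref{Any}(x)
    if cond
        r[] = y
    end
    return r[]
end
# The rewrite inserts a PhiNode over the incoming field values at the join block,
# types it with the `tmerge` of those values, and replaces the load with the ϕ;
# the `Ref` allocation itself is then dead and can be removed.
```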
-        push!(intermediaries, newidx)
+        push!(revisit, (related, Liveness))
+        @label next_itr
+    end
+
+    # remove dead setfield! and :new allocs
+    deadssas = IdSet{SSAValue}()
+    if all_preserved && newpreserves !== nothing
+        preserved = keys(newpreserves)
+    else
+        preserved = EMPTY_PRESERVED_SSAS
+    end
+    mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved)
+    for ssa in deadssas
+        ir[ssa][:inst] = nothing
+    end
+    if all_preserved && newpreserves !== nothing
+        deadssas = Int[ssa.id for ssa in deadssas]
+        for (idx, newuses) in newpreserves
+            ir[SSAValue(idx)][:inst] = form_new_preserves(
+                ir[SSAValue(idx)][:inst]::Expr, deadssas, newuses)
         end
-        # Insert the new preserves
-        for (use, new_preserves) in preserve_uses
-            ir[SSAValue(use)][:inst] = form_new_preserves(ir[SSAValue(use)][:inst]::Expr, intermediaries, new_preserves)
+    end
+
+    return ir
+end
+
+const EMPTY_PRESERVED_SSAS = keys(IdDict{Int,Vector{Any}}())
+const PreservedSets = typeof(EMPTY_PRESERVED_SSAS)
+
+function is_load_forwardable(x::EscapeInfo)
+    AliasInfo = x.AliasInfo
+    return isa(AliasInfo, IndexableFields)
+end
+
+struct FieldDefUse
+    uses::Vector{Any}
+    defs::Vector{Int}
+end
+FieldDefUse() = FieldDefUse(Any[], Int[])
+struct GetfieldLoad
+    idx::Int
+end
+struct PreserveUse
+    idx::Int
+end
+struct IsdefinedUse
+    idx::Int
+end
+function getuseidx(@nospecialize use)
+    if isa(use, GetfieldLoad)
+        return use.idx
+    elseif isa(use, PreserveUse)
+        return use.idx
+    elseif isa(use, IsdefinedUse)
+        return use.idx
+    end
+    throw("getuseidx: unexpected use")
+end
+
+function compute_live_ins(cfg::CFG, fdu::FieldDefUse)
+    uses = Int[]
+    for use in fdu.uses
+        isa(use, IsdefinedUse) && continue
+        push!(uses, getuseidx(use))
+    end
+    return compute_live_ins(cfg, fdu.defs, uses)
+end
+
+# even when the allocation contains an uninitialized field, we make an extra effort to
+# check whether the load at `use` has any "safe" `setfield!` calls that define the field
+function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
+    fdu::FieldDefUse, use::Int)
+    dfu = find_def_for_use(ir, domtree, allblocks, fdu, use)
+    dfu === nothing && return false
+    def = dfu[1]
+    def ≠ 0 && return true # found a "safe" definition
+    # we may still be able to replace this load with `PhiNode` -- examine if all predecessors of
+    # this `block` have any "safe" definition
+    block = block_for_inst(ir, use)
+    seen = BitSet(block)
+    worklist = BitSet(ir.cfg.blocks[block].preds)
+    isempty(worklist) && return false
+    while !isempty(worklist)
+        pred = pop!(worklist)
+        # if this block has already been examined, bail out to avoid infinite cycles
+        pred in seen && return false
+        use = last(ir.cfg.blocks[pred].stmts)
+        # NOTE this `use` isn't a load, and so the inclusive condition can be used
+        dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true)
+        dfu === nothing && return false
+        def = dfu[1]
+        push!(seen, pred)
+        def ≠ 0 && continue # found a "safe" definition for this predecessor
+        # if not, check for the predecessors of this predecessor
+        for newpred in ir.cfg.blocks[pred].preds
+            push!(worklist, newpred)
         end
+    end
+    return true
+end
-        @label skip
+# find the first dominating def for the given use
+function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
+    fdu::FieldDefUse, use::Int, inclusive::Bool=false)
+    useblock = block_for_inst(ir.cfg, use)
+    curblock = find_curblock(domtree, allblocks, useblock)
+    curblock === nothing && return nothing
+    local def = 0
+    for idx in fdu.defs
+        if block_for_inst(ir.cfg, idx) == 
curblock + if curblock != useblock + # Find the last def in this block + def = max(def, idx) + else + # Find the last def before our use + if inclusive + def = max(def, idx ≤ use ? idx : 0) + else + def = max(def, idx < use ? idx : 0) + end + end + end end + return def, useblock, curblock end -function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) - newex = Expr(:foreigncall) - nccallargs = length(origex.args[3]::SimpleVector) - for i in 1:(6+nccallargs-1) - push!(newex.args, origex.args[i]) +function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) + # TODO: This can be much faster by looking at current level and only + # searching for those blocks in a sorted order + while !(curblock in allblocks) + curblock = domtree.idoms_bb[curblock] + curblock == 0 && return nothing end - for i in (6+nccallargs):length(origex.args) - x = origex.args[i] - # don't need to preserve intermediaries - if isa(x, SSAValue) && x.id in intermediates - continue + return curblock +end + +function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + @assert dfu !== nothing "has_safe_def condition unsatisfied" + def, useblock, curblock = dfu + if def == 0 + if !haskey(phinodes, curblock) + # If this happens, we need to search the predecessors for defs. Which + # one doesn't matter - if it did, we'd have had a phinode + return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) end - push!(newex.args, x) + # The use is the phinode + return phinodes[curblock] + else + return val_for_def_expr(ir, def, fidx) end - for i in 1:length(new_preserves) - push!(newex.args, new_preserves[i]) +end + +function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, + fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) + curblock = find_curblock(domtree, allblocks, curblock) + @assert curblock !== nothing "has_safe_def condition unsatisfied" + def = 0 + for stmt in fdu.defs + if block_for_inst(ir.cfg, stmt) == curblock + def = max(def, stmt) + end end - return newex + return def == 0 ? 
phinodes[curblock] : val_for_def_expr(ir, def, fidx) +end + +function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) + ex = ir[SSAValue(def)][:inst] + if isexpr(ex, :new) + return ex.args[1+fidx] + else + @assert is_known_call(ex, setfield!, ir) "invalid load forwarding" + return ex.args[4] + end +end + +function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, + revisit::Vector{Tuple{Vector{SSAValue},LivenessSet}}, eliminated::BitSet, + preserved::PreservedSets) + wset = BitSet(1:length(revisit)) + while !isempty(wset) + revisit_idx = pop!(wset) + mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved, wset, revisit_idx) + end +end + +function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, + revisit::Vector{Tuple{Vector{SSAValue},LivenessSet}}, eliminated::BitSet, + preserved::PreservedSets, wset::BitSet, revisit_idx::Int) + related, Liveness = revisit[revisit_idx] + eliminable = SSAValue[] + for livepc in Liveness + livepc in eliminated && @goto next_live + ssa = SSAValue(livepc) + stmt = ir[ssa][:inst] + if isexpr(stmt, :new) + ssa in deadssas && @goto next_live + for new_revisit_idx in wset + if ssa in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, + revisit, eliminated, + preserved, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + return false + elseif is_known_call(stmt, setfield!, ir) + @assert length(stmt.args) ≥ 4 "invalid escape analysis" + obj = stmt.args[2] + val = stmt.args[4] + if isa(obj, SSAValue) + if obj in related + push!(eliminable, ssa) + @goto next_live + end + if isa(val, SSAValue) && val in related + if obj in deadssas + push!(eliminable, ssa) + @goto next_live + end + for new_revisit_idx in wset + if obj in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, + revisit, eliminated, + preserved, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + end + end + return false + elseif isexpr(stmt, :foreigncall) + livepc in preserved && @goto next_live + return false + else + return false + end + @label next_live + end + for ssa in related; push!(deadssas, ssa); end + for ssa in eliminable; push!(deadssas, ssa); end + return true end """ @@ -1084,15 +1190,15 @@ In addition to a simple DCE for unused values and allocations, this pass also nullifies `typeassert` calls that can be proved to be no-op, in order to allow LLVM to emit simpler code down the road. -Note that this pass is more effective after SROA optimization (i.e. `sroa_pass!`), +Note that this pass is more effective after SROA optimization (i.e. `linear_pass!`), since SROA often allows this pass to: - eliminate allocation of object whose field references are all replaced with scalar values, and - nullify `typeassert` call whose first operand has been replaced with a scalar value (, which may have introduced new type information that inference did not understand) -Also note that currently this pass _needs_ to run after `sroa_pass!`, because +Also note that currently this pass _needs_ to run after `linear_pass!`, because the `typeassert` elimination depends on the transformation by `canonicalize_typeassert!` done -within `sroa_pass!` which redirects references of `typeassert`ed value to the corresponding `PiNode`. +within `linear_pass!` which redirects references of `typeassert`ed value to the corresponding `PiNode`. 
""" function adce_pass!(ir::IRCode) phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) diff --git a/test/compiler/EscapeAnalysis/EAUtils.jl b/test/compiler/EscapeAnalysis/EAUtils.jl index 3ae9b41a0ddac..a76d87938344f 100644 --- a/test/compiler/EscapeAnalysis/EAUtils.jl +++ b/test/compiler/EscapeAnalysis/EAUtils.jl @@ -71,7 +71,7 @@ import Core: CodeInstance, MethodInstance, CodeInfo import .CC: InferenceResult, OptimizationState, IRCode, copy as cccopy, - @timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, sroa_pass!, + @timeit, convert_to_ircode, slot2reg, compact!, ssa_inlining_pass!, linear_pass!, adce_pass!, type_lift_pass!, JLOptions, verify_ir, verify_linetable import .EA: analyze_escapes, ArgEscapeCache, EscapeInfo, EscapeState, is_ipo_profitable @@ -240,7 +240,7 @@ function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::Optimizati interp.state = state interp.linfo = sv.linfo end - @timeit "SROA" ir = sroa_pass!(ir) + @timeit "SROA" ir, _ = linear_pass!(ir) @timeit "ADCE" ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 128fd6cc84b7b..7793489d0fc2b 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -2,7 +2,9 @@ using Test using Base.Meta -using Core: PhiNode, SSAValue, GotoNode, PiNode, QuoteNode, ReturnNode, GotoIfNot +import Core: + CodeInfo, Argument, SSAValue, GotoNode, GotoIfNot, PiNode, PhiNode, + QuoteNode, ReturnNode include(normpath(@__DIR__, "irutils.jl")) @@ -12,7 +14,7 @@ include(normpath(@__DIR__, "irutils.jl")) ## Test that domsort doesn't mangle single-argument phis (#29262) let m = Meta.@lower 1 + 1 @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ # block 1 Expr(:call, :opaque), @@ -47,7 +49,7 @@ end # test that we don't stack-overflow in SNCA with large functions. 
let m = Meta.@lower 1 + 1 @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo code = Any[] N = 2^15 for i in 1:2:N @@ -73,30 +75,87 @@ end # SROA # ==== +import Core.Compiler: widenconst + +is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code) +is_scalar_replaced(src::CodeInfo) = + is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code) + +function is_load_forwarded(@nospecialize(T), src::CodeInfo) + for i in 1:length(src.code) + x = src.code[i] + if iscall((src, getfield), x) + widenconst(argextype(x.args[1], src)) <: T && return false + end + end + return true +end +function is_scalar_replaced(@nospecialize(T), src::CodeInfo) + is_load_forwarded(T, src) || return false + for i in 1:length(src.code) + x = src.code[i] + if iscall((src, setfield!), x) + widenconst(argextype(x.args[1], src)) <: T && return false + elseif isnew(x) + widenconst(argextype(SSAValue(i), src)) <: T && return false + end + end + return true +end + struct ImmutableXYZ; x; y; z; end mutable struct MutableXYZ; x; y; z; end +struct ImmutableOuter{T}; x::T; y::T; z::T; end +mutable struct MutableOuter{T}; x::T; y::T; z::T; end +struct ImmutableRef{T}; x::T; end +Base.getindex(r::ImmutableRef) = r.x +mutable struct SafeRef{T}; x::T; end +Base.getindex(s::SafeRef) = getfield(s, 1) +Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + +# simple immutability +# ------------------- -# should optimize away very basic cases let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = (x, y, z) + xyz[1], xyz[2], xyz[3] + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end + +# simple mutability +# ----------------- + let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end - -# should handle simple mutabilities let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) xyz.y = 42 xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=x=# Core.Argument(2), 42, #=x=# Core.Argument(4)] @@ -107,19 +166,23 @@ let src = code_typed1((Any,Any,Any)) do x, y, z xyz.x, xyz.z = xyz.z, xyz.x xyz.x, xyz.y, xyz.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] end end -# circumvent uninitialized fields as far as there is a solid `setfield!` definition + +# uninitialized fields +# -------------------- + +# safe cases let src = code_typed1() do r = Ref{Any}() r[] = 42 return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,)) do cond r 
= Ref{Any}() @@ -131,7 +194,7 @@ let src = code_typed1((Bool,)) do cond return r[] end end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -142,7 +205,7 @@ let src = code_typed1((Bool,)) do cond end return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z r = Ref{Any}() @@ -157,7 +220,16 @@ let src = code_typed1((Bool,Bool,Any,Any,Any)) do c1, c2, x, y, z end return r[] end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) +end + +# unsafe cases +let src = code_typed1() do + r = Ref{Any}() + return r[] + end + @test count(isnew, src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end let src = code_typed1((Bool,)) do cond r = Ref{Any}() @@ -167,7 +239,9 @@ let src = code_typed1((Bool,)) do cond return r[] end # N.B. `r` should be allocated since `cond` might be `false` and then it will be thrown - @test any(isnew, src.code) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 end let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y r = Ref{Any}() @@ -181,12 +255,95 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y return r[] end # N.B. `r` should be allocated since `c2` might be `false` and then it will be thrown - @test any(isnew, src.code) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 end -# should include a simple alias analysis -struct ImmutableOuter{T}; x::T; y::T; z::T; end -mutable struct MutableOuter{T}; x::T; y::T; z::T; end +# load forwarding +# --------------- +# even if allocation can't be eliminated + +# safe cases +for T in (ImmutableRef{Any}, Ref{Any}) + let src = @eval code_typed1((Bool,Any,)) do c, a + r = $T(a) + if c + return r[] + else + return r + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + end + let src = @eval code_typed1((Bool,String,)) do c, a + r = $T(a) + if c + return r[]::String # adce_pass! 
will further eliminate this type assert call also + else + return r + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + @test !any(iscall((src, typeassert)), src.code) + end + let src = @eval code_typed1((Bool,Any,)) do c, a + r = $T(a) + if c + return r[] + else + throw(r) + end + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + end +end +let src = code_typed1((Bool,Any,Any)) do c, a, b + r = Ref{Any}(a) + if c + return r[] + end + r[] = b + return r + end + @test is_load_forwarded(src) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 + @test count(src.code) do @nospecialize x + isreturn(x) && x.val === Argument(3) # a + end == 1 +end + +# unsafe case +let src = code_typed1((Bool,Any,Any)) do c, a, b + r = Ref{Any}(a) + r[] = b + @noinline some_escape!(r) + return r[] + end + @test !is_load_forwarded(src) + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 +end +let src = code_typed1((Bool,String,Regex)) do c, a, b + r1 = Ref{Any}(a) + r2 = Ref{Any}(b) + return ifelse(c, r1, r2)[] + end + r = only(findall(isreturn, src.code)) + v = (src.code[r]::Core.ReturnNode).val + @test v !== Argument(3) # a + @test v !== Argument(4) # b + @test_broken is_load_forwarded(src) # ideally +end + +# aliased load forwarding +# ----------------------- + +# OK: immutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) @@ -214,7 +371,6 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end -# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well # OK: mutable(immutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) @@ -222,14 +378,14 @@ let src = code_typed1((Any,Any,Any)) do x, y, z v = t[1].x v, v, v end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) end let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) outer = ImmutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test !any(isnew, src.code) + @test is_scalar_replaced(src) @test any(src.code) do @nospecialize x iscall((src, tuple), x) && x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)] @@ -240,32 +396,489 @@ let # this is a simple end to end test case, which demonstrates allocation elimi # NOTE this test case isn't so robust and might be subject to future changes of the broadcasting implementation, # in that case you don't really need to stick to keeping this test case around simple_sroa(s) = broadcast(identity, Ref(s)) + let src = code_typed1(simple_sroa, (String,)) + @test is_scalar_replaced(src) + end s = Base.inferencebarrier("julia")::String simple_sroa(s) # NOTE don't hard-code `"julia"` in `@allocated` clause and make sure to execute the # compiled code for `simple_sroa`, otherwise everything can be folded even without SROA @test @allocated(simple_sroa(s)) == 0 end -# FIXME: immutable(mutable(...)) case +let # some insanely nested example + src = code_typed1((Int,)) do x + (Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][] + end + @test is_scalar_replaced(src) +end + +# OK: immutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = ImmutableXYZ(x, y, z) outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test is_scalar_replaced(src) + @test 
any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end + +# OK: mutable(mutable(...)) case +# new chain +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = MutableOuter(xyz, xyz, xyz) + outer.x.x, outer.y.y, outer.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end end -# FIXME: mutable(mutable(...)) case let src = code_typed1((Any,Any,Any)) do x, y, z xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = z, y, x outer = MutableOuter(xyz, xyz, xyz) outer.x.x, outer.y.y, outer.z.z end - @test_broken !any(isnew, src.code) + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = xyz.z, xyz.y, xyz.x + outer = MutableOuter(xyz, xyz, xyz) + outer.x.x, outer.y.y, outer.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + xyz.x, xyz.y, xyz.z = z, y, x + inner = MutableOuter(xyz, xyz, xyz) + outer = MutableOuter(inner, inner, inner) + outer.x.x.x, outer.y.y.y, outer.z.z.z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=z=# Core.Argument(4), #=y=# Core.Argument(3), #=x=# Core.Argument(2)] + end +end +# setfield! 
chain +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = Ref{MutableXYZ}() + outer[] = xyz + return outer[].x, outer[].y, outer[].z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=z=# Core.Argument(4)] + end +end +let src = code_typed1((Any,Any,Any)) do x, y, z + xyz = MutableXYZ(x, y, z) + outer = Ref{MutableXYZ}() + outer[] = xyz + xyz.z = 42 + return outer[].x, outer[].y, outer[].z + end + @test is_scalar_replaced(src) + @test any(src.code) do @nospecialize x + iscall((src, tuple), x) && + x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), 42] + end end -let # should work with constant globals - # immutable case - # -------------- +# ϕ-allocation elimination +# ------------------------ + +# safe cases +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 +end +let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = Ref{Any}(x) + elseif cond2 + ϕ = Ref{Any}(y) + else + ϕ = Ref{Any}(z) + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + ϕ[] = z + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.ReturnNode) && + #=z=# Core.Argument(5) === x.val + end == 1 +end +let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + out1 = ϕ[] + else + ϕ = Ref{Any}(y) + out1 = ϕ[] + end + out2 = ϕ[] + out1, out2 + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 2 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + ϕ[] = z + end + ϕ[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = Ref{Any}(x) + out1 = ϕ[] + else + ϕ = Ref{Any}(y) + out1 = ϕ[] + ϕ[] = z + end + out2 = ϕ[] + out1, out2 + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 +end +let src = code_typed1((Bool,Any,Any)) do cond, x, y + # these allocation form multiple ϕ-nodes + if cond + ϕ2 = ϕ1 = Ref{Any}(x) + else + ϕ2 = ϕ1 = Ref{Any}(y) + end + ϕ1[], ϕ2[] + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 +end +let src 
= code_typed1((Bool,String,)) do cond, x + # these allocation form multiple ϕ-nodes + if cond + ϕ2 = ϕ1 = Ref{Any}("foo") + else + ϕ2 = ϕ1 = Ref{Any}("bar") + end + ϕ2[] = x + y = ϕ1[] # => x + return y + end + @test is_scalar_replaced(src) + @test count(src.code) do @nospecialize x + isa(x, Core.ReturnNode) && + #=x=# x.val === Core.Argument(3) + end == 1 +end + +# unsafe cases +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}(y) + end + some_escape!(ϕ) + ϕ[] + end + @test count(isnew, src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 +end +let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = Ref{Any}(x) + some_escape!(ϕ) + else + ϕ = Ref{Any}(y) + end + ϕ[] + end + @test count(isnew, src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 +end +let src = code_typed1((Bool,Any,)) do cond, x + if cond + ϕ = Ref{Any}(x) + else + ϕ = Ref{Any}() + end + ϕ[] + end + @test count(isnew, src.code) == 2 + @test count(iscall((src, getfield)), src.code) == 1 +end +let src = code_typed1((Bool,Any)) do c, a + local r + if c + r = Ref{Any}(a) + end + (r::Base.RefValue{Any})[] + end + @test count(isnew, src.code) == 1 + @test count(iscall((src, getfield)), src.code) == 1 +end + +function mutable_ϕ_elim(x, xs) + r = Ref(x) + for x in xs + r = Ref(x) + end + return r[] +end +let src = code_typed1(mutable_ϕ_elim, (String, Vector{String})) + @test is_scalar_replaced(src) + + xs = String[string(gensym()) for _ in 1:100] + mutable_ϕ_elim("init", xs) + @test @allocated(mutable_ϕ_elim("init", xs)) == 0 +end + +# demonstrate the power of our field / alias analysis with realistic end to end examples +# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B +abstract type AbstractPoint{T} end +struct Point{T} <: AbstractPoint{T} + x::T + y::T +end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute_point(T, n, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(n-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute_point(n, a, b) + for i in 0:(n-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute_point!(n, a, b) + for i in 0:(n-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end + +let # immutable case + src = code_typed1((Int,)) do n + compute_point(Point, n, 1+.5, 2+.5, 2+.25, 4+.75) + end + @test is_scalar_replaced(Point, src) + src = code_typed1((Int,)) do n + compute_point(Point, n, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + end + @test is_scalar_replaced(Point, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) + + # mutable case + src = code_typed1((Int,)) do n + compute_point(MPoint, n, 1+.5, 2+.5, 2+.25, 4+.75) + end + @test is_scalar_replaced(MPoint, src) + src = code_typed1((Int,)) do n + compute_point(MPoint, n, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + end + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +compute_point(MPoint, 10, 1+.5, 2+.5, 2+.25, 4+.75) +compute_point(MPoint, 10, 1+.5im, 2+.5im, 2+.25im, 4+.75im) +@test @allocated(compute_point(MPoint, 10000, 1+.5, 2+.5, 2+.25, 4+.75)) == 0 +@test @allocated(compute_point(MPoint, 10000, 1+.5im, 2+.5im, 2+.25im, 4+.75im)) == 0 + +let # immutable case + src = code_typed1((Int,)) do n + compute_point(n, Point(1+.5, 2+.5), Point(2+.25, 4+.75)) + end + @test 
is_scalar_replaced(Point, src) + src = code_typed1((Int,)) do n + compute_point(n, Point(1+.5im, 2+.5im), Point(2+.25im, 4+.75im)) + end + @test is_scalar_replaced(Point, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) + + # mutable case + src = code_typed1((Int,)) do n + compute_point(n, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) + end + @test is_scalar_replaced(MPoint, src) + src = code_typed1((Int,)) do n + compute_point(n, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + end + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +compute_point(10, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75)) +compute_point(10, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) +@test @allocated(compute_point(10000, MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75))) == 0 +@test @allocated(compute_point(10000, MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im))) == 0 + +let # mutable case + src = code_typed1(compute_point!, (Int,MPoint{Float64},MPoint{Float64})) + @test is_scalar_replaced(MPoint, src) + src = code_typed1(compute_point!, (Int,MPoint{ComplexF64},MPoint{ComplexF64})) + @test is_scalar_replaced(MPoint, src) + @test is_load_forwarded(ComplexF64, src) + @test !is_scalar_replaced(ComplexF64, src) +end +let + af, bf = MPoint(1+.5, 2+.5), MPoint(2+.25, 4+.75) + ac, bc = MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im) + compute_point!(10, af, bf) + compute_point!(10, ac, bc) + @test @allocated(compute_point!(10000, af, bf)) == 0 + @test @allocated(compute_point!(10000, ac, bc)) == 0 +end + +# isdefined elimination +# --------------------- + +let src = code_typed1((Any,)) do a + r = Ref{Any}() + r[] = a + if isassigned(r) + return r[] + end + return nothing + end + @test is_scalar_replaced(src) +end + +callit(f, args...) = f(args...) 
+function isdefined_elim() + local arr::Vector{Any} + callit() do + arr = Any[] + end + return arr +end +let src = code_typed1(isdefined_elim) + @test is_scalar_replaced(src) +end +@test isdefined_elim() == Any[] + +# preserve elimination +# -------------------- + +let src = code_typed1((String,)) do s + ccall(:some_ccall, Cint, (Ptr{String},), Ref(s)) + end + @test count(isnew, src.code) == 0 +end + +# if the mutable struct is directly used, we shouldn't eliminate it +let src = code_typed1() do + a = MutableXYZ(-512275808,882558299,-2133022131) + b = Int32(42) + ccall(:some_ccall, Cvoid, (MutableXYZ, Int32), a, b) + return a.x + end + @test count(isnew, src.code) == 1 +end + +# constant globals +# ---------------- + +let # immutable case src = @eval Module() begin const REF_FLD = :x struct ImmutableRef{T} @@ -282,7 +895,6 @@ let # should work with constant globals @test count(isnew, src.code) == 0 # mutable case - # ------------ src = @eval Module() begin const REF_FLD = :x code_typed() do @@ -295,25 +907,6 @@ let # should work with constant globals @test count(isnew, src.code) == 0 end -# should work nicely with inlining to optimize away a complicated case -# adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B -struct Point - x::Float64 - y::Float64 -end -#=@inline=# add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) -function compute_points() - a = Point(1.5, 2.5) - b = Point(2.25, 4.75) - for i in 0:(100000000-1) - a = add(add(a, b), b) - end - a.x, a.y -end -let src = code_typed1(compute_points) - @test !any(isnew, src.code) -end - # comparison lifting # ================== @@ -454,7 +1047,7 @@ end # A SSAValue after the compaction line let m = Meta.@lower 1 + 1 @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ # block 1 nothing, @@ -492,7 +1085,7 @@ let m = Meta.@lower 1 + 1 src.ssaflags = fill(Int32(0), nstmts) ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any]) @test Core.Compiler.verify_ir(ir) === nothing - ir = @test_nowarn Core.Compiler.sroa_pass!(ir) + ir, = @test_nowarn Core.Compiler.linear_pass!(ir) @test Core.Compiler.verify_ir(ir) === nothing end @@ -517,7 +1110,7 @@ end let m = Meta.@lower 1 + 1 # Test that CFG simplify combines redundant basic blocks @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ Core.Compiler.GotoNode(2), Core.Compiler.GotoNode(3), @@ -542,7 +1135,7 @@ end let m = Meta.@lower 1 + 1 # Test that CFG simplify doesn't mess up when chaining past return blocks @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ Core.Compiler.GotoIfNot(Core.Compiler.Argument(2), 3), Core.Compiler.GotoNode(4), @@ -572,7 +1165,7 @@ let m = Meta.@lower 1 + 1 # Test that CFG simplify doesn't try to merge every block in a loop into # its predecessor @assert Meta.isexpr(m, :thunk) - src = m.args[1]::Core.CodeInfo + src = m.args[1]::CodeInfo src.code = Any[ # Block 1 Core.Compiler.GotoNode(2), From 7b05508092b7fa3dcfb0a4629a82906235877616 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 24 Jan 2022 15:57:58 +0900 Subject: [PATCH 3/3] optimizer: simple array SROA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements a simple Julia-level array allocation elimination on top of #43888. 
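The allocations targeted are the `:foreigncall`s that `Vector{T}(undef, n)` and friends
lower to. The recognized pattern can be reconstructed as a plain `Expr` for illustration
(the actual recognizer, `is_array_alloc`, is added to `optimize.jl` below):

```julia
# Lowered form of `Vector{T}(undef, 2)` in typed IR, rebuilt as a plain Expr
# (compare the %1 statement in the example below):
T = Base.RefValue{String}
ex = Expr(:foreigncall, QuoteNode(:jl_alloc_array_1d), Vector{T},
          Core.svec(Any, Int64), 0, QuoteNode(:ccall), Vector{T}, 2, 2)
# `is_array_alloc` checks for a :foreigncall whose normalized callee name has a
# known dimensionality according to `alloc_array_ndims`.
```

The end-to-end effect on typed IR: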
```julia
julia> code_typed((String,String)) do s, t
           a = Vector{Base.RefValue{String}}(undef, 2)
           a[1] = Ref(s)
           a[2] = Ref(t)
           return a[1][]
       end
```
```diff
diff --git a/master b/pr
index 9c8da14380..5b63d08190 100644
--- a/master
+++ b/pr
@@ -1,11 +1,4 @@
 1-element Vector{Any}:
  CodeInfo(
-1 ─ %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Base.RefValue{String}}, svec(Any, Int64), 0, :(:ccall), Vector{Base.RefValue{String}}, 2, 2))::Vector{Base.RefValue{String}}
-│   %2 = %new(Base.RefValue{String}, s)::Base.RefValue{String}
-│        Base.arrayset(true, %1, %2, 1)::Vector{Base.RefValue{String}}
-│   %4 = %new(Base.RefValue{String}, t)::Base.RefValue{String}
-│        Base.arrayset(true, %1, %4, 2)::Vector{Base.RefValue{String}}
-│   %6 = Base.arrayref(true, %1, 1)::Base.RefValue{String}
-│   %7 = Base.getfield(%6, :x)::String
-└──      return %7
+1 ─     return s
) => String
```

Still, this array SROA is very limited and only able to handle trivial examples (though I
confirmed that this version already eliminates a few array allocations during the sysimg
build).

For those who are interested, I added some discussion of array optimization
[here](https://aviatesk.github.io/EscapeAnalysis.jl/dev/#EA-Array-Analysis).
---
 base/compiler/optimize.jl     |  14 +-
 base/compiler/ssair/passes.jl | 512 ++++++++++++++++++++++++----------
 test/compiler/codegen.jl      |   8 +-
 test/compiler/irpasses.jl     | 142 +++++++++-
 4 files changed, 508 insertions(+), 168 deletions(-)

diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl
index e84f77ae1ea48..cdf1e0cf40f33 100644
--- a/base/compiler/optimize.jl
+++ b/base/compiler/optimize.jl
@@ -288,8 +288,7 @@ end
 function foreigncall_effect_free(stmt::Expr, src::Union{IRCode,IncrementalCompact})
     args = stmt.args
-    name = args[1]
-    isa(name, QuoteNode) && (name = name.value)
+    name = normalize(args[1])
     isa(name, Symbol) || return false
     ndims = alloc_array_ndims(name)
     if ndims !== nothing
@@ -315,6 +314,17 @@ function alloc_array_ndims(name::Symbol)
     return nothing
 end
 
+normalize(@nospecialize x) = isa(x, QuoteNode) ? 
x.value : x + +function is_array_alloc(@nospecialize stmt) + isa(stmt, Expr) || return false + if isexpr(stmt, :foreigncall) + name = normalize(stmt.args[1]) + return isa(name, Symbol) && alloc_array_ndims(name) !== nothing + end + return false +end + const FOREIGNCALL_ARG_START = 6 function alloc_array_no_throw(args::Vector{Any}, ndims::Int, src::Union{IRCode,IncrementalCompact}) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index c8037fac648fa..af1bc63bfd7a3 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -557,6 +557,9 @@ function linear_pass!(ir::IRCode) memory_opt = true end continue + elseif is_array_alloc(stmt) + memory_opt = true + continue elseif is_known_call(stmt, getfield, compact) 3 <= length(stmt.args) <= 5 || continue if length(stmt.args) == 5 @@ -687,7 +690,8 @@ function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preser end import .EscapeAnalysis: - EscapeState, EscapeInfo, IndexableFields, LivenessSet, getaliases, LocalUse, LocalDef + EscapeState, EscapeInfo, IndexableFields, IndexableElements, LivenessSet, getaliases, + LocalUse, LocalDef, ArrayInfo """ memory_opt_pass!(ir::IRCode, estate::EscapeState) -> newir::IRCode @@ -711,12 +715,12 @@ function memory_opt_pass!(ir::IRCode, estate::EscapeState) eliminated = BitSet() revisit = Tuple{#=related=#Vector{SSAValue}, #=Liveness=#LivenessSet}[] all_preserved = true - newpreserves = nothing + newpreserves = IdDict{Int,Vector{Any}}() while !isempty(wset) idx = pop!(wset) ssa = SSAValue(idx) stmt = ir[ssa][:inst] - isexpr(stmt, :new) || continue + isexpr(stmt, :new) || is_array_alloc(stmt) || continue einfo = estate[ssa] is_load_forwardable(einfo) || continue aliases = getaliases(ssa, estate) @@ -730,152 +734,48 @@ function memory_opt_pass!(ir::IRCode, estate::EscapeState) delete!(wset, alias.id) end end - finfos = (einfo.AliasInfo::IndexableFields).infos - nfields = length(finfos) - - # Partition defuses by field - fdefuses = Vector{FieldDefUse}(undef, nfields) - for i = 1:nfields - finfo = finfos[i] - fdu = FieldDefUse() - for fx in finfo - if isa(fx, LocalUse) - push!(fdu.uses, GetfieldLoad(fx.idx)) # use (getfield call) - else - @assert isa(fx, LocalDef) - push!(fdu.defs, fx.idx) # def (setfield! 
call or :new expression) - end - end - fdefuses[i] = fdu - end - - Liveness = einfo.Liveness - for livepc in Liveness - livestmt = ir[SSAValue(livepc)][:inst] - if is_known_call(livestmt, Core.ifelse, ir) - # the succeeding domination analysis doesn't account for conditional branching - # by ifelse branching at this moment - @goto next_itr - elseif is_known_call(livestmt, isdefined, ir) - args = livestmt.args - length(args) ≥ 3 || continue - obj = args[2] - isa(obj, SSAValue) || continue - obj in related || continue - fld = args[3] - fldval = try_compute_field(ir, fld) - fldval === nothing && continue - typ = unwrap_unionall(widenconst(argextype(obj, ir))) - isa(typ, DataType) || continue - fldidx = try_compute_fieldidx(typ, fldval) - fldidx === nothing && continue - push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) - elseif isexpr(livestmt, :foreigncall) - # we shouldn't eliminate this use if it's used as a direct argument - args = livestmt.args - nccallargs = length(args[3]::SimpleVector) - for i = 6:(5+nccallargs) - arg = args[i] - isa(arg, SSAValue) && arg in related && @goto next_liveness - end - # this use is preserve, and may be eliminable - for fidx in 1:nfields - push!(fdefuses[fidx].uses, PreserveUse(livepc)) - end - end - @label next_liveness - end - for fidx in 1:nfields - fdu = fdefuses[fidx] - isempty(fdu.uses) && @goto next_use - # check if all uses have safe definitions first, otherwise we should bail out - # since then we may fail to form new ϕ-nodes - ldu = compute_live_ins(ir.cfg, fdu) - if isempty(ldu.live_in_bbs) - phiblocks = Int[] - else - phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) - end - allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) - for use in fdu.uses - isa(use, IsdefinedUse) && continue - if isa(use, PreserveUse) && isempty(fdu.defs) - # nothing to preserve, just ignore this use (may happen when there are unintialized fields) - continue - end - if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use)) - all_preserved = false - @goto next_use - end - end - phinodes = IdDict{Int, SSAValue}() - for b in phiblocks - phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), - NewInstruction(PhiNode(), Any)) - end - # Now go through all uses and rewrite them - for use in fdu.uses - if isa(use, GetfieldLoad) - use = getuseidx(use) - ir[SSAValue(use)][:inst] = compute_value_for_use( - ir, domtree, allblocks, fdu, phinodes, fidx, use) - push!(eliminated, use) - elseif all_preserved && isa(use, PreserveUse) - if newpreserves === nothing - newpreserves = IdDict{Int,Vector{Any}}() - end - # record this `use` as replaceable no matter if we preserve new value or not - use = getuseidx(use) - newvalues = get!(()->Any[], newpreserves, use) - isempty(fdu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) - newval = compute_value_for_use( - ir, domtree, allblocks, fdu, phinodes, fidx, use) - if !isbitstype(widenconst(argextype(newval, ir))) - push!(newvalues, newval) - end - elseif isa(use, IsdefinedUse) - use = getuseidx(use) - if has_safe_def(ir, domtree, allblocks, fdu, use) - ir[SSAValue(use)][:inst] = true - push!(eliminated, use) - end - else - throw("unexpected use") - end - end - for b in phiblocks - ϕssa = phinodes[b] - n = ir[ϕssa][:inst]::PhiNode - t = Bottom - for p in ir.cfg.blocks[b].preds - push!(n.edges, p) - v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p) - push!(n.values, v) - if t !== Any - t = tmerge(t, argextype(v, ir)) - end - end - ir[ϕssa][:type] = t - end - 
@label next_use + AliasInfo = einfo.AliasInfo + if isa(AliasInfo, IndexableFields) + @assert isexpr(stmt, :new) "invalid escape analysis" + all_preserved &= load_forward_object!(ir, domtree, + eliminated, revisit, + newpreserves, related, + AliasInfo, einfo.Liveness) + else + @assert is_array_alloc(stmt) "invalid escape analysis" + arrayinfo = estate.arrayinfo + @assert isa(arrayinfo, ArrayInfo) && haskey(arrayinfo, idx) "invalid escape analysis" + dims = arrayinfo[idx] + all_preserved &= load_forward_array!(ir, domtree, + eliminated, revisit, + newpreserves, related, + AliasInfo::IndexableElements, einfo.Liveness, dims) end - push!(revisit, (related, Liveness)) - @label next_itr end # remove dead setfield! and :new allocs deadssas = IdSet{SSAValue}() - if all_preserved && newpreserves !== nothing + if all_preserved preserved = keys(newpreserves) else preserved = EMPTY_PRESERVED_SSAS end mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved) for ssa in deadssas + # stmt = ir[ssa][:inst] + # if is_known_call(stmt, setfield!, ir) + # println("[SROA] eliminated setfield!: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # elseif isexpr(stmt, :new) + # println("[SROA] eliminated object alloc: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # elseif is_known_call(stmt, arrayset, ir) + # println("[SROA] eliminated arrayset: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # elseif is_array_alloc(stmt) + # println("[SROA] eliminated array alloc: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # end ir[ssa][:inst] = nothing end - if all_preserved && newpreserves !== nothing + if all_preserved deadssas = Int[ssa.id for ssa in deadssas] for (idx, newuses) in newpreserves ir[SSAValue(idx)][:inst] = form_new_preserves( @@ -886,20 +786,291 @@ function memory_opt_pass!(ir::IRCode, estate::EscapeState) return ir end +function load_forward_object!(ir::IRCode, domtree::DomTree, + eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}}, + newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue}, + AliasInfo::IndexableFields, Liveness::LivenessSet) + finfos = AliasInfo.infos + nfields = length(finfos) + + # Partition defuses by field + all_preserved = true + fdefuses = Vector{IndexedDefUse}(undef, nfields) + for i = 1:nfields + finfo = finfos[i] + idu = IndexedDefUse() + for fx in finfo + if isa(fx, LocalUse) + push!(idu.uses, LoadUse(fx.idx)) # use (getfield call) + else + @assert isa(fx, LocalDef) + push!(idu.defs, fx.idx) # def (setfield! 
call or :new expression)
+            end
+        end
+        fdefuses[i] = idu
+    end
+
+    for livepc in Liveness
+        livestmt = ir[SSAValue(livepc)][:inst]
+        if is_known_call(livestmt, Core.ifelse, ir)
+            # the subsequent domination analysis doesn't currently account for the
+            # conditional branching introduced by ifelse
+            return false
+        elseif is_known_call(livestmt, isdefined, ir)
+            args = livestmt.args
+            length(args) ≥ 3 || continue
+            obj = args[2]
+            isa(obj, SSAValue) || continue
+            obj in related || continue
+            fld = args[3]
+            fldval = try_compute_field(ir, fld)
+            fldval === nothing && continue
+            typ = unwrap_unionall(widenconst(argextype(obj, ir)))
+            isa(typ, DataType) || continue
+            fldidx = try_compute_fieldidx(typ, fldval)
+            fldidx === nothing && continue
+            push!(fdefuses[fldidx].uses, IsdefinedUse(livepc))
+        elseif isexpr(livestmt, :foreigncall)
+            # we shouldn't eliminate this use if it's used as a direct argument
+            args = livestmt.args
+            nccallargs = length(args[3]::SimpleVector)
+            for i = 6:(5+nccallargs)
+                arg = args[i]
+                isa(arg, SSAValue) && arg in related && @goto next_liveness
+            end
+            # this use is a preserve and may be eliminable
+            for fidx in 1:nfields
+                push!(fdefuses[fidx].uses, PreserveUse(livepc))
+            end
+        end
+        @label next_liveness
+    end
+
+    for fidx in 1:nfields
+        idu = fdefuses[fidx]
+        isempty(idu.uses) && @goto next_use
+        # check if all uses have safe definitions first; otherwise we should bail out,
+        # since we may then fail to form new ϕ-nodes
+        ldu = compute_live_ins(ir.cfg, idu)
+        if isempty(ldu.live_in_bbs)
+            phiblocks = Int[]
+        else
+            phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree)
+        end
+        allblocks = sort!(vcat(phiblocks, ldu.def_bbs))
+        for use in idu.uses
+            isa(use, IsdefinedUse) && continue
+            if isa(use, PreserveUse) && isempty(idu.defs)
+                # nothing to preserve, just ignore this use (may happen when there are uninitialized fields)
+                continue
+            end
+            if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use))
+                all_preserved = false
+                @goto next_use
+            end
+        end
+        phinodes = IdDict{Int, SSAValue}()
+        for b in phiblocks
+            phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
+                NewInstruction(PhiNode(), Any))
+        end
+        # Now go through all uses and rewrite them
+        for use in idu.uses
+            if isa(use, LoadUse)
+                use = getuseidx(use)
+                ir[SSAValue(use)][:inst] = compute_value_for_use(
+                    ir, domtree, allblocks, idu, phinodes, fidx, use)
+                push!(eliminated, use)
+            elseif isa(use, PreserveUse)
+                all_preserved || continue
+                # record this `use` as replaceable, regardless of whether we end up preserving a new value
+                use = getuseidx(use)
+                newvalues = get!(()->Any[], newpreserves, use)
+                isempty(idu.defs) && continue # nothing to preserve (may happen when there are uninitialized fields)
+                newval = compute_value_for_use(
+                    ir, domtree, allblocks, idu, phinodes, fidx, use)
+                if !isbitstype(widenconst(argextype(newval, ir)))
+                    push!(newvalues, newval)
+                end
+            elseif isa(use, IsdefinedUse)
+                use = getuseidx(use)
+                if has_safe_def(ir, domtree, allblocks, idu, use)
+                    ir[SSAValue(use)][:inst] = true
+                    push!(eliminated, use)
+                end
+            else
+                throw("load_forward_object!: unexpected use")
+            end
+        end
+        for b in phiblocks
+            ϕssa = phinodes[b]
+            n = ir[ϕssa][:inst]::PhiNode
+            t = Bottom
+            for p in ir.cfg.blocks[b].preds
+                push!(n.edges, p)
+                v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, p)
+                push!(n.values, v)
+                if t !== Any
+                    t = tmerge(t, argextype(v, ir))
+                end
+            end
+            ir[ϕssa][:type] = t
+        end
+        @label next_use
+    end
+    push!(revisit, (related, Liveness))
+
+    return all_preserved
+end
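+
+# NOTE (editor's illustrative sketch, not part of the pass): `load_forward_object!`
+# rewrites loads from a non-escaping allocation into the dominating stored values,
+# inserting ϕ-nodes at the iterated dominance frontier where several defs can reach
+# a single use. With a hypothetical example type
+#
+#     mutable struct A; x::Int; end
+#     function ex(c::Bool)
+#         a = A(1)        # :new defines field 1
+#         c && (a.x = 2)  # setfield! def on one branch
+#         return a.x      # getfield use reached by both defs
+#     end
+#
+# the load is effectively rewritten into a fresh ϕ-node merging `1` and `2`, after
+# which the allocation, the store and the load are all dead and can be removed by
+# `mark_dead_ssas!`.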
+
+# TODO is_array_isassigned folding?
+function load_forward_array!(ir::IRCode, domtree::DomTree,
+    eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}},
+    newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue},
+    AliasInfo::IndexableElements, Liveness::LivenessSet, dims::Vector{Int})
+    elminfos = AliasInfo.infos
+    elmkeys = keys(elminfos)
+
+    # Partition defuses by index
+    all_preserved = true
+    edefuses = IdDict{Int,IndexedDefUse}()
+    for eidx in elmkeys
+        einfo = elminfos[eidx]
+        idu = IndexedDefUse()
+        for ex in einfo
+            if isa(ex, LocalUse)
+                push!(idu.uses, LoadUse(ex.idx)) # use (arrayref call)
+            else
+                @assert isa(ex, LocalDef)
+                push!(idu.defs, ex.idx) # def (arrayset call)
+            end
+        end
+        edefuses[eidx] = idu
+    end
+
+    for livepc in Liveness
+        ssa = SSAValue(livepc)
+        livestmt = ir[ssa][:inst]
+        if is_known_call(livestmt, Core.ifelse, ir)
+            # the subsequent domination analysis doesn't currently account for the
+            # conditional branching introduced by ifelse
+            return false
+        elseif is_known_call(livestmt, arraylen, ir)
+            len = 1
+            for dim in dims
+                len *= dim
+            end
+            ir[ssa][:inst] = len
+            push!(eliminated, livepc)
+        elseif is_known_call(livestmt, arraysize, ir)
+            length(livestmt.args) ≥ 3 || continue
+            dim = argextype(livestmt.args[3], ir)
+            isa(dim, Const) || continue
+            dim = dim.val
+            isa(dim, Int) || continue
+            checkbounds(Bool, dims, dim) || continue
+            ir[ssa][:inst] = dims[dim]
+            push!(eliminated, livepc)
+        elseif isexpr(livestmt, :foreigncall)
+            # we shouldn't eliminate this use if it's used as a direct argument
+            args = livestmt.args
+            nccallargs = length(args[3]::SimpleVector)
+            for i = 6:(5+nccallargs)
+                arg = args[i]
+                isa(arg, SSAValue) && arg in related && @goto next_liveness
+            end
+            # this use is a preserve and may be eliminable
+            for eidx in elmkeys
+                push!(edefuses[eidx].uses, PreserveUse(livepc))
+            end
+        end
+        @label next_liveness
+    end
+
+    for eidx in elmkeys
+        idu = edefuses[eidx]
+        isempty(idu.uses) && @goto next_use
+        # check if all uses have safe definitions first; otherwise we should bail out,
+        # since we may then fail to form new ϕ-nodes
+        ldu = compute_live_ins(ir.cfg, idu)
+        if isempty(ldu.live_in_bbs)
+            phiblocks = Int[]
+        else
+            phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree)
+        end
+        allblocks = sort!(vcat(phiblocks, ldu.def_bbs))
+        for use in idu.uses
+            if isa(use, PreserveUse) && isempty(idu.defs)
+                # nothing to preserve, just ignore this use (may happen when there are unassigned elements)
+                continue
+            end
+            if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use))
+                all_preserved = false
+                @goto next_use
+            end
+        end
+        phinodes = IdDict{Int, SSAValue}()
+        for b in phiblocks
+            phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
+                NewInstruction(PhiNode(), Any))
+        end
+        # Now go through all uses and rewrite them
+        for use in idu.uses
+            if isa(use, LoadUse)
+                use = getuseidx(use)
+                ir[SSAValue(use)][:inst] = compute_value_for_use(
+                    ir, domtree, allblocks, idu, phinodes, eidx, use)
+                push!(eliminated, use)
+            elseif isa(use, PreserveUse)
+                all_preserved || continue
+                # record this `use` as replaceable, regardless of whether we end up preserving a new value
+                use = getuseidx(use)
+                newvalues = get!(()->Any[], newpreserves, use)
+                isempty(idu.defs) && continue # nothing to preserve (may happen when there are unassigned elements)
+                newval = compute_value_for_use(
+                    ir, domtree, allblocks, idu, phinodes, eidx, use)
+                if !isbitstype(widenconst(argextype(newval, ir)))
+                    push!(newvalues, newval)
+                end
+            
else + throw("load_forward_array!: unexpected use") + end + end + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, eidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t + end + @label next_use + end + push!(revisit, (related, Liveness)) + + return all_preserved +end + const EMPTY_PRESERVED_SSAS = keys(IdDict{Int,Vector{Any}}()) const PreservedSets = typeof(EMPTY_PRESERVED_SSAS) function is_load_forwardable(x::EscapeInfo) AliasInfo = x.AliasInfo - return isa(AliasInfo, IndexableFields) + return isa(AliasInfo, IndexableFields) || isa(AliasInfo, IndexableElements) end -struct FieldDefUse +struct IndexedDefUse uses::Vector{Any} defs::Vector{Int} end -FieldDefUse() = FieldDefUse(Any[], Int[]) -struct GetfieldLoad +IndexedDefUse() = IndexedDefUse(Any[], Int[]) +struct LoadUse idx::Int end struct PreserveUse @@ -909,7 +1080,7 @@ struct IsdefinedUse idx::Int end function getuseidx(@nospecialize use) - if isa(use, GetfieldLoad) + if isa(use, LoadUse) return use.idx elseif isa(use, PreserveUse) return use.idx @@ -919,21 +1090,21 @@ function getuseidx(@nospecialize use) throw("getuseidx: unexpected use") end -function compute_live_ins(cfg::CFG, fdu::FieldDefUse) +function compute_live_ins(cfg::CFG, idu::IndexedDefUse) uses = Int[] - for use in fdu.uses + for use in idu.uses isa(use, IsdefinedUse) && continue push!(uses, getuseidx(use)) end - return compute_live_ins(cfg, fdu.defs, uses) + return compute_live_ins(cfg, idu.defs, uses) end # even when the allocation contains an uninitialized field, we try an extra effort to check # if this load at `idx` have any "safe" `setfield!` calls that define the field # try to find function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, use::Int) - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + idu::IndexedDefUse, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use) dfu === nothing && return false def = dfu[1] def ≠ 0 && return true # found a "safe" definition @@ -949,7 +1120,7 @@ function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, pred in seen && return false use = last(ir.cfg.blocks[pred].stmts) # NOTE this `use` isn't a load, and so the inclusive condition can be used - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use, true) dfu === nothing && return false def = dfu[1] push!(seen, pred) @@ -964,12 +1135,12 @@ end # find the first dominating def for the given use function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, use::Int, inclusive::Bool=false) + idu::IndexedDefUse, use::Int, inclusive::Bool=false) useblock = block_for_inst(ir.cfg, use) curblock = find_curblock(domtree, allblocks, useblock) curblock === nothing && return nothing local def = 0 - for idx in fdu.defs + for idx in idu.defs if block_for_inst(ir.cfg, idx) == curblock if curblock != useblock # Find the last def in this block @@ -998,15 +1169,15 @@ function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) end function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + idu::IndexedDefUse, 
phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use) @assert dfu !== nothing "has_safe_def condition unsatisfied" def, useblock, curblock = dfu if def == 0 if !haskey(phinodes, curblock) # If this happens, we need to search the predecessors for defs. Which # one doesn't matter - if it did, we'd have had a phinode - return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) + return compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) end # The use is the phinode return phinodes[curblock] @@ -1016,11 +1187,11 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) + idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) curblock = find_curblock(domtree, allblocks, curblock) @assert curblock !== nothing "has_safe_def condition unsatisfied" def = 0 - for stmt in fdu.defs + for stmt in idu.defs if block_for_inst(ir.cfg, stmt) == curblock def = max(def, stmt) end @@ -1032,9 +1203,12 @@ function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) ex = ir[SSAValue(def)][:inst] if isexpr(ex, :new) return ex.args[1+fidx] - else - @assert is_known_call(ex, setfield!, ir) "invalid load forwarding" + elseif is_known_call(ex, setfield!, ir) return ex.args[4] + elseif is_known_call(ex, arrayset, ir) + return ex.args[4] + else + throw("invalid load forwarding") end end @@ -1103,6 +1277,34 @@ function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, end end return false + elseif is_known_call(stmt, arrayset, ir) + @assert length(stmt.args) ≥ 4 "invalid escape analysis" + ary = stmt.args[3] + val = stmt.args[4] + if isa(ary, SSAValue) + if ary in related + push!(eliminable, ssa) + @goto next_live + end + if isa(val, SSAValue) && val in related + if ary in deadssas + push!(eliminable, ssa) + @goto next_live + end + for new_revisit_idx in wset + if ary in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + end + end + return false elseif isexpr(stmt, :foreigncall) livepc in preserved && @goto next_live return false diff --git a/test/compiler/codegen.jl b/test/compiler/codegen.jl index ec89ac9cd72a4..d21765180a4b9 100644 --- a/test/compiler/codegen.jl +++ b/test/compiler/codegen.jl @@ -548,27 +548,27 @@ end # main use case function f1(cond) val = [1] - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f1, Tuple{Bool}, true, false, false)) # stack allocated objects (JuliaLang/julia#34241) function f3(cond) val = ([1],) - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f3, Tuple{Bool}, true, false, false)) # unions of immutables (JuliaLang/julia#39501) function f2(cond) val = cond ? 1 : 1f0 - GC.@preserve val begin end + GC.@preserve val begin val end end @test !occursin("llvm.julia.gc_preserve_begin", get_llvm(f2, Tuple{Bool}, true, false, false)) # make sure the fix for the above doesn't regress #34241 function f4(cond) val = cond ? 
([1],) : ([1f0],) - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f4, Tuple{Bool}, true, false, false)) end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 7793489d0fc2b..0933bd3b49ca9 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -75,17 +75,26 @@ end # SROA # ==== -import Core.Compiler: widenconst - -is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code) -is_scalar_replaced(src::CodeInfo) = - is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code) +import Core.Compiler: widenconst, is_array_alloc + +is_load_forwarded(src::CodeInfo) = + !any(iscall((src, getfield)), src.code) && !any(iscall((src, Core.arrayref)), src.code) +function is_scalar_replaced(src::CodeInfo) + is_load_forwarded(src) || return false + any(iscall((src, setfield!)), src.code) && return false + any(isnew, src.code) && return false + any(iscall((src, Core.arrayset)), src.code) && return false + any(is_array_alloc, src.code) && return false + return true +end function is_load_forwarded(@nospecialize(T), src::CodeInfo) for i in 1:length(src.code) x = src.code[i] if iscall((src, getfield), x) widenconst(argextype(x.args[1], src)) <: T && return false + elseif iscall((src, Core.arrayref), x) + widenconst(argextype(x.args[1], src)) <: T && return false end end return true @@ -98,6 +107,10 @@ function is_scalar_replaced(@nospecialize(T), src::CodeInfo) widenconst(argextype(x.args[1], src)) <: T && return false elseif isnew(x) widenconst(argextype(SSAValue(i), src)) <: T && return false + elseif iscall((src, Core.arrayset), x) + widenconst(argextype(x.args[1], src)) <: T && return false + elseif is_array_alloc(x) + widenconst(argextype(SSAValue(i), src)) <: T && return false end end return true @@ -713,7 +726,7 @@ function mutable_ϕ_elim(x, xs) return r[] end let src = code_typed1(mutable_ϕ_elim, (String, Vector{String})) - @test is_scalar_replaced(src) + @test is_scalar_replaced(Ref{String}, src) xs = String[string(gensym()) for _ in 1:100] mutable_ϕ_elim("init", xs) @@ -852,7 +865,7 @@ function isdefined_elim() return arr end let src = code_typed1(isdefined_elim) - @test is_scalar_replaced(src) + @test count(isnew, src.code) == 0 # eliminates closure constructs end @test isdefined_elim() == Any[] @@ -907,6 +920,121 @@ let # immutable case @test count(isnew, src.code) == 0 end +# array SROA +# ---------- + +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Any[nothing] + a[1] = s + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((String,String)) do s, t + a = Vector{Any}(undef, 2) + a[1] = Ref(s) + a[2] = Ref(t) + return a[1] + end + @test count(isnew, src.code) == 1 +end +let src = code_typed1((String,)) do s + a = Vector{Base.RefValue{String}}(undef, 1) + a[1] = Ref(s) + return a[1][] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((String,String)) do s, t + a = Vector{Base.RefValue{String}}(undef, 2) + a[1] = Ref(s) + a[2] = Ref(t) + return a[1][] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Vector{Any}[Any[nothing]] + a[1][1] = s + return a[1][1] + end + @test_broken is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any)) do c, s, t + a = Any[nothing] + if c + a[1] = s + else + a[1] = t + end + return a[1] + end + @test 
is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any,Any,Any,)) do c, s1, s2, t1, t2 + if c + a = Vector{Any}(undef, 2) + a[1] = s1 + a[2] = s2 + else + a = Vector{Any}(undef, 2) + a[1] = t1 + a[2] = t2 + end + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any)) do c, s, t + # XXX this implicitly forms tuple to getfield chains + # and SROA on it produces complicated control flow + if c + a = Any[s] + else + a = Any[t] + end + return a[1] + end + @test_broken is_scalar_replaced(src) +end + +# arraylen / arraysize elimination +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1], length(a) + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Matrix{Any}(undef, 2, 2) + a[1, 1] = s + return a[1, 1], length(a) + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1], size(a, 1) + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Matrix{Any}(undef, 2, 2) + a[1, 1] = s + return a[1, 1], size(a) + end + @test is_scalar_replaced(src) +end + # comparison lifting # ==================
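+
+# Editor's illustrative sketch (not part of the test suite; `code_typed1` and
+# `is_scalar_replaced` are the helpers defined above): with the dimensions that
+# escape analysis records for a non-escaping fixed-size allocation,
+# `load_forward_array!` folds `arraylen` to `prod(dims)` and `arraysize(a, d)`
+# to `dims[d]`, so a test in the spirit of the ones above would presumably be:
+#
+#     let src = code_typed1((Any,)) do s
+#             a = Matrix{Any}(undef, 2, 3)
+#             a[1, 1] = s
+#             return a[1, 1], length(a), size(a, 2)
+#         end
+#         # length(a) folds to the constant 6, size(a, 2) folds to 3,
+#         # and the allocation itself is scalar-replaced
+#         @test is_scalar_replaced(src)
+#     end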
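+
+# Editor's illustrative sketch of the preserve rewriting exercised by the
+# codegen.jl tests above: when a preserved allocation is scalar-replaced,
+# `form_new_preserves` swaps the arguments of the corresponding
+# `gc_preserve_begin` for the forwarded values, keeping only non-isbits ones
+# (isbits values are not GC-tracked and need no preservation). Conceptually:
+#
+#     r = Ref(s)                    # s::String, not isbits
+#     GC.@preserve r begin ... end  # after SROA: GC.@preserve s begin ... end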