Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ function dae_result_for_inst(interp, inst::CC.Instruction)
info = inst[:info]
stmt = inst[:stmt]
mi = stmt.args[1]
if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt === Const(nothing)
info = info.info
end
if isa(info, CC.ConstCallInfo)
if length(info.results) != 1
# TODO: When does this happen? Union split?
Expand Down
48 changes: 18 additions & 30 deletions src/analysis/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using Core: CodeInfo, MethodInstance, CodeInstance, SimpleVector, MethodMatch, MethodTable
using .CC: AbstractInterpreter, NativeInterpreter, InferenceParams, OptimizationParams,
InferenceResult, InferenceState, OptimizationState, WorldRange, WorldView, ArgInfo,
StmtInfo, MethodCallResult, ConstCallResults, ConstPropResult, MethodTableView,
StmtInfo, MethodCallResult, ConstCallResult, ConstPropResult, MethodTableView,
CachedMethodTable, InternalMethodTable, OverlayMethodTable, CallMeta, CallInfo,
IRCode, LazyDomtree, IRInterpretationState, set_inlineable!, block_for_inst,
BitSetBoundedMinPrioritySet, AbsIntState, Future
Expand Down Expand Up @@ -117,7 +117,7 @@ struct DAEInterpreter <: AbstractInterpreter
ipo_analysis_mode::Bool = false,
in_analysis::Bool = false)
if code_cache === nothing
code_cache = get_code_cache(method_table, ipo_analysis_mode)
code_cache = get_code_cache(world, method_table, ipo_analysis_mode)
end
if method_table !== nothing
method_table = CachedMethodTable(OverlayMethodTable(world, method_table))
Expand Down Expand Up @@ -315,13 +315,10 @@ end
return Future{MethodCallResult}(mret, interp, sv) do ret, interp, sv
edge = ret.edge
if edge !== nothing
cache = CC.get(CC.code_cache(interp), edge, nothing)
if cache !== nothing
src = @atomic :monotonic cache.inferred
if isa(src, DAECache)
info = src.info
merge_daeinfo!(interp, sv.result, info)
end
src = @atomic :monotonic edge.inferred
if isa(src, DAECache)
info = src.info
merge_daeinfo!(interp, sv.result, info)
end
end
return ret
Expand All @@ -330,11 +327,11 @@ end

@override function CC.const_prop_call(interp::DAEInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo,
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResults})
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResult})
ret = @invoke CC.const_prop_call(interp::AbstractInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo,
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResults})
if isa(ret, ConstCallResults)
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResult})
if isa(ret, ConstCallResult)
const_result = ret.const_result::ConstPropResult
info = interp.dae_cache[const_result.result]
merge_daeinfo!(interp, sv.result, info)
Expand All @@ -353,26 +350,20 @@ struct DAECache
new(inferred, ir, info)
end

@override CC.transform_result_for_cache(interp::DAEInterpreter,
mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool) =
_transform_result_for_cache(interp, mi, valid_worlds, result, cond)

function _transform_result_for_cache(interp::DAEInterpreter,
mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool=false)
function CC.transform_result_for_cache(interp::DAEInterpreter, result::InferenceResult)
src = result.src
if isa(src, DAECache)
return src
end
inferred = @invoke CC.transform_result_for_cache(interp::AbstractInterpreter,
mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool)
inferred = @invoke CC.transform_result_for_cache(interp::AbstractInterpreter, result)
return DAECache(inferred, nothing, interp.dae_cache[result])
end

# inlining
# --------

function dae_inlining_policy(@nospecialize(src), @nospecialize(info::CallInfo), raise::Bool=true)
if isa(info, Diffractor.FRuleCallInfo)
if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt !== Const(nothing)
return nothing
end
osrc = src
Expand Down Expand Up @@ -502,13 +493,10 @@ end
result::MethodCallResult, si::StmtInfo, sv::InferenceState, force::Bool)
edge = result.edge
if edge !== nothing
cache = CC.get(CC.code_cache(interp), edge, nothing)
if cache !== nothing
src = @atomic :monotonic cache.inferred
if isa(src, DAECache)
src.info.has_dae_intrinsics && return true
src.info.has_scoperead && return true
end
src = @atomic :monotonic edge.inferred
if isa(src, DAECache)
src.info.has_dae_intrinsics && return true
src.info.has_scoperead && return true
end
end
return @invoke CC.const_prop_rettype_heuristic(interp::AbstractInterpreter,
Expand Down Expand Up @@ -1052,7 +1040,7 @@ end
using Cthulhu

function Cthulhu.get_optimized_codeinst(interp::DAEInterpreter, curs::Cthulhu.CthulhuCursor)
interp.code_cache.cache[curs.mi]
CC.getindex(CC.code_cache(interp), curs.mi)
end

function Cthulhu.lookup(interp::DAEInterpreter, curs::Cthulhu.CthulhuCursor, optimize::Bool)
Expand Down Expand Up @@ -1109,7 +1097,7 @@ function lookup_optimized(interp::DAEInterpreter, mi::MethodInstance, allow_no_s
end

Cthulhu.can_descend(interp::DAEInterpreter, @nospecialize(key), optimize::Bool) =
haskey(optimize ? interp.code_cache.cache : interp.unopt, key)
optimize ? CC.haskey(CC.code_cache(interp), key) : haskey(interp.unopt, key)

# TODO: Why does Cthulhu have this separately from the lookup logic, which already
# returns effects
Expand Down
10 changes: 6 additions & 4 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ end
function remap_info(remap_ir!, info)
# TODO: This is pretty aweful, but it works for now.
# It'll go away when we switch to IPO.
if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt === Const(nothing)
info = info.info
end
isa(info, CC.ConstCallInfo) || return info
results = map(info.results) do result
result === nothing && return result
if isa(result, CC.SemiConcreteResult)
let ir = copy(result.ir)
remap_ir!(ir)
CC.SemiConcreteResult(result.mi, ir, result.effects, result.spec_info)
CC.SemiConcreteResult(result.edge, ir, result.effects, result.spec_info)
end
elseif isa(result, CC.ConstPropResult)
if isa(result.result.src, DAECache)
Expand All @@ -76,10 +79,9 @@ function widen_extra_info!(ir)
for i = 1:length(ir.stmts)
info = ir.stmts[i][:info]
if isa(info, Diffractor.FRuleCallInfo)
ir.stmts[i][:info] = info.info
else
ir.stmts[i][:info] = remap_info(widen_extra_info!, info)
info = info.info
end
ir.stmts[i][:info] = remap_info(widen_extra_info!, info)
inst = ir.stmts[i][:inst]
if isa(inst, PiNode)
ir.stmts[i][:inst] = PiNode(inst.val, widenconst(inst.typ))
Expand Down
1 change: 1 addition & 0 deletions test/ipo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ipo
using Test
using DAECompiler
using DAECompiler.Intrinsics
using DAECompiler.Intrinsics: state_ddt
using SciMLBase, OrdinaryDiffEq, Sundials

include(joinpath(Base.pkgdir(DAECompiler), "test", "testutils.jl"))
Expand Down
Loading