diff --git a/Manifest.toml b/Manifest.toml index 3d16ae7..f066b1d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -259,7 +259,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" +version = "1.2.0+0" [[deps.CompositeTypes]] git-tree-sha1 = "bce26c3dab336582805503bed209faab1c279768" @@ -317,9 +317,9 @@ version = "4.1.1" [[deps.Cthulhu]] deps = ["CodeTracking", "FoldingTrees", "InteractiveUtils", "JuliaSyntax", "PrecompileTools", "Preferences", "REPL", "TypedSyntax", "UUIDs", "Unicode", "WidthLimitedIO"] -git-tree-sha1 = "6dd420e944a3be328f91088d6a1af02576ccba4b" +git-tree-sha1 = "638d5b786059bade99139f6b5af59932e424099c" uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f" -version = "2.15.0" +version = "2.15.3" [[deps.DAECompiler]] deps = ["Accessors", "CentralizedCaches", "ChainRules", "ChainRulesCore", "Cthulhu", "DiffEqBase", "DiffEqCallbacks", "Diffractor", "Distributions", "ExprTools", "ForwardDiff", "Graphs", "LinearAlgebra", "NonlinearSolve", "OrderedCollections", "OrdinaryDiffEq", "PrecompileTools", "Preferences", "REPL", "Random", "SciMLBase", "SciMLSensitivity", "SparseArrays", "StateSelection", "StaticArraysCore", "Sundials", "SymbolicIndexingInterface", "TimerOutputs", "Tracy"] @@ -444,11 +444,11 @@ version = "1.15.1" [[deps.Diffractor]] deps = ["AbstractDifferentiation", "ChainRules", "ChainRulesCore", "Combinatorics", "Cthulhu", "InteractiveUtils", "OffsetArrays", "PrecompileTools", "StaticArrays", "StructArrays"] -git-tree-sha1 = "2a9b827fce47e27ef32471df18a96dc4ff1123bd" -repo-rev = "kf/compileradjust" +git-tree-sha1 = "a9a4b544e40f91c770760b4624284188be1fe9bd" +repo-rev = "kf/edgesadjust" repo-url = "https://github.com/JuliaDiff/Diffractor.jl.git" uuid = "9f5e2b26-1114-432f-b630-d3fe2085c51c" -version = "0.2.8" +version = "0.2.10" [[deps.Distances]] deps = ["LinearAlgebra", "Statistics", "StatsAPI"] @@ -1094,7 +1094,7 @@ weakdeps = ["SparseArrays"] [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.6+0" +version = "2.28.6+1" [[deps.Missings]] deps = ["DataAPI"] @@ -1237,12 +1237,12 @@ weakdeps = ["Adapt"] [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.28+2" +version = "0.3.28+3" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" +version = "0.8.1+3" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1832,7 +1832,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.8.0+0" +version = "7.8.0+1" [[deps.Sundials]] deps = ["CEnum", "DataStructures", "DiffEqBase", "Libdl", "LinearAlgebra", "Logging", "PrecompileTools", "Reexport", "SciMLBase", "SparseArrays", "Sundials_jll"] @@ -2046,17 +2046,17 @@ version = "1.0.1" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.3.1+0" +version = "1.3.1+1" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.11.0+0" +version = "5.11.1+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.60.0+0" +version = "1.63.0+1" [[deps.oneTBB_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -2067,4 +2067,4 @@ version = "2021.12.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.5.0+0" +version = "17.5.0+1" diff --git a/src/analysis/compiler.jl b/src/analysis/compiler.jl index 8b91323..b2d8884 100644 --- a/src/analysis/compiler.jl +++ b/src/analysis/compiler.jl @@ -442,6 +442,9 @@ function dae_result_for_inst(interp, inst::CC.Instruction) info = inst[:info] stmt = inst[:stmt] mi = stmt.args[1] + if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt === Const(nothing) + info = info.info + end if isa(info, CC.ConstCallInfo) if length(info.results) != 1 # TODO: When does this happen? Union split? diff --git a/src/analysis/interpreter.jl b/src/analysis/interpreter.jl index 66f5dc1..cdbe75d 100644 --- a/src/analysis/interpreter.jl +++ b/src/analysis/interpreter.jl @@ -4,7 +4,7 @@ using Core: CodeInfo, MethodInstance, CodeInstance, SimpleVector, MethodMatch, MethodTable using .CC: AbstractInterpreter, NativeInterpreter, InferenceParams, OptimizationParams, InferenceResult, InferenceState, OptimizationState, WorldRange, WorldView, ArgInfo, - StmtInfo, MethodCallResult, ConstCallResults, ConstPropResult, MethodTableView, + StmtInfo, MethodCallResult, ConstCallResult, ConstPropResult, MethodTableView, CachedMethodTable, InternalMethodTable, OverlayMethodTable, CallMeta, CallInfo, IRCode, LazyDomtree, IRInterpretationState, set_inlineable!, block_for_inst, BitSetBoundedMinPrioritySet, AbsIntState, Future @@ -117,7 +117,7 @@ struct DAEInterpreter <: AbstractInterpreter ipo_analysis_mode::Bool = false, in_analysis::Bool = false) if code_cache === nothing - code_cache = get_code_cache(method_table, ipo_analysis_mode) + code_cache = get_code_cache(world, method_table, ipo_analysis_mode) end if method_table !== nothing method_table = CachedMethodTable(OverlayMethodTable(world, method_table)) @@ -315,13 +315,10 @@ end return Future{MethodCallResult}(mret, interp, sv) do ret, interp, sv edge = ret.edge if edge !== nothing - cache = CC.get(CC.code_cache(interp), edge, nothing) - if cache !== nothing - src = @atomic :monotonic cache.inferred - if isa(src, DAECache) - info = src.info - merge_daeinfo!(interp, sv.result, info) - end + src = @atomic :monotonic edge.inferred + if isa(src, DAECache) + info = src.info + merge_daeinfo!(interp, sv.result, info) end end return ret @@ -330,11 +327,11 @@ end @override function CC.const_prop_call(interp::DAEInterpreter, mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, - sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResults}) + sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResult}) ret = @invoke CC.const_prop_call(interp::AbstractInterpreter, mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, - sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResults}) - if isa(ret, ConstCallResults) + sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResult}) + if isa(ret, ConstCallResult) const_result = ret.const_result::ConstPropResult info = interp.dae_cache[const_result.result] merge_daeinfo!(interp, sv.result, info) @@ -353,18 +350,12 @@ struct DAECache new(inferred, ir, info) end -@override CC.transform_result_for_cache(interp::DAEInterpreter, - mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool) = - _transform_result_for_cache(interp, mi, valid_worlds, result, cond) - -function _transform_result_for_cache(interp::DAEInterpreter, - mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool=false) +function CC.transform_result_for_cache(interp::DAEInterpreter, result::InferenceResult) src = result.src if isa(src, DAECache) return src end - inferred = @invoke CC.transform_result_for_cache(interp::AbstractInterpreter, - mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool) + inferred = @invoke CC.transform_result_for_cache(interp::AbstractInterpreter, result) return DAECache(inferred, nothing, interp.dae_cache[result]) end @@ -372,7 +363,7 @@ end # -------- function dae_inlining_policy(@nospecialize(src), @nospecialize(info::CallInfo), raise::Bool=true) - if isa(info, Diffractor.FRuleCallInfo) + if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt !== Const(nothing) return nothing end osrc = src @@ -502,13 +493,10 @@ end result::MethodCallResult, si::StmtInfo, sv::InferenceState, force::Bool) edge = result.edge if edge !== nothing - cache = CC.get(CC.code_cache(interp), edge, nothing) - if cache !== nothing - src = @atomic :monotonic cache.inferred - if isa(src, DAECache) - src.info.has_dae_intrinsics && return true - src.info.has_scoperead && return true - end + src = @atomic :monotonic edge.inferred + if isa(src, DAECache) + src.info.has_dae_intrinsics && return true + src.info.has_scoperead && return true end end return @invoke CC.const_prop_rettype_heuristic(interp::AbstractInterpreter, @@ -1052,7 +1040,7 @@ end using Cthulhu function Cthulhu.get_optimized_codeinst(interp::DAEInterpreter, curs::Cthulhu.CthulhuCursor) - interp.code_cache.cache[curs.mi] + CC.getindex(CC.code_cache(interp), curs.mi) end function Cthulhu.lookup(interp::DAEInterpreter, curs::Cthulhu.CthulhuCursor, optimize::Bool) @@ -1109,7 +1097,7 @@ function lookup_optimized(interp::DAEInterpreter, mi::MethodInstance, allow_no_s end Cthulhu.can_descend(interp::DAEInterpreter, @nospecialize(key), optimize::Bool) = - haskey(optimize ? interp.code_cache.cache : interp.unopt, key) + optimize ? CC.haskey(CC.code_cache(interp), key) : haskey(interp.unopt, key) # TODO: Why does Cthulhu have this separately from the lookup logic, which already # returns effects diff --git a/src/transform/common.jl b/src/transform/common.jl index 97d9129..236bb2b 100644 --- a/src/transform/common.jl +++ b/src/transform/common.jl @@ -51,13 +51,16 @@ end function remap_info(remap_ir!, info) # TODO: This is pretty aweful, but it works for now. # It'll go away when we switch to IPO. + if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt === Const(nothing) + info = info.info + end isa(info, CC.ConstCallInfo) || return info results = map(info.results) do result result === nothing && return result if isa(result, CC.SemiConcreteResult) let ir = copy(result.ir) remap_ir!(ir) - CC.SemiConcreteResult(result.mi, ir, result.effects, result.spec_info) + CC.SemiConcreteResult(result.edge, ir, result.effects, result.spec_info) end elseif isa(result, CC.ConstPropResult) if isa(result.result.src, DAECache) @@ -76,10 +79,9 @@ function widen_extra_info!(ir) for i = 1:length(ir.stmts) info = ir.stmts[i][:info] if isa(info, Diffractor.FRuleCallInfo) - ir.stmts[i][:info] = info.info - else - ir.stmts[i][:info] = remap_info(widen_extra_info!, info) + info = info.info end + ir.stmts[i][:info] = remap_info(widen_extra_info!, info) inst = ir.stmts[i][:inst] if isa(inst, PiNode) ir.stmts[i][:inst] = PiNode(inst.val, widenconst(inst.typ)) diff --git a/test/ipo.jl b/test/ipo.jl index a139c53..e1abe28 100644 --- a/test/ipo.jl +++ b/test/ipo.jl @@ -3,6 +3,7 @@ module ipo using Test using DAECompiler using DAECompiler.Intrinsics +using DAECompiler.Intrinsics: state_ddt using SciMLBase, OrdinaryDiffEq, Sundials include(joinpath(Base.pkgdir(DAECompiler), "test", "testutils.jl"))