diff --git a/Project.toml b/Project.toml index c1d8752..962c0dc 100644 --- a/Project.toml +++ b/Project.toml @@ -39,11 +39,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd" [weakdeps] ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +[extensions] +DAECompilerCthulhuExt = ["Compiler", "Cthulhu"] + [sources] Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"} Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"} diff --git a/ext/DAECompilerCthulhuExt.jl b/ext/DAECompilerCthulhuExt.jl new file mode 100644 index 0000000..159f440 --- /dev/null +++ b/ext/DAECompilerCthulhuExt.jl @@ -0,0 +1,132 @@ +module DAECompilerCthulhuExt + +using Core.IR +using DAECompiler: DAECompiler, DAEIPOResult, UncompilableIPOResult, Settings, ADAnalyzer, structural_analysis!, find_matching_ci, StructureCache, ir_to_src, get_method_instance, MappingInfo, AnalyzedSource +using Compiler: Compiler, InferenceResult, NativeInterpreter, SOURCE_MODE_GET_SOURCE, get_inference_world, typeinf_ext, Effects, get_ci_mi, NoCallInfo +using Accessors: setproperties +using Diffractor: FRuleCallInfo + +import Cthulhu as _Cthulhu +const Cthulhu = Base.get_extension(_Cthulhu, :CthulhuCompilerExt) +using .Cthulhu: CthulhuState, AbstractProvider, Command, InferenceKey, InferenceDict, PC2Remarks, PC2CallMeta, PC2Effects, PC2Excts, LookupResult, generate_code_instance, value_for_default_command + +mutable struct DAEProvider <: AbstractProvider + world::UInt + settings::Settings + remarks::InferenceDict{PC2Remarks} + calls::InferenceDict{PC2CallMeta} + effects::InferenceDict{PC2Effects} + exception_types::InferenceDict{PC2Excts} +end +DAEProvider(; world = Base.tls_world_age(), settings = Settings()) = DAEProvider(world, settings, InferenceDict{PC2Remarks}(), InferenceDict{PC2CallMeta}(), InferenceDict{PC2Effects}(), InferenceDict{PC2Excts}()) + +Cthulhu.get_inference_world(provider::DAEProvider) = provider.world + +function Cthulhu.find_method_instance(provider::DAEProvider, @nospecialize(tt::Type{<:Tuple}), world::UInt) + return get_method_instance(tt, world) +end + +function check_result(ci::CodeInstance) + isa(ci.inferred, UncompilableIPOResult) && throw(ci.inferred.error) + return true +end + +function Cthulhu.generate_code_instance(provider::DAEProvider, mi::MethodInstance) + world = get_inference_world(provider) + ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world) + # XXX: We should not cache the CodeInstance this way, or at least invalidate in the provider in `toggle_setting!`. + if ci !== nothing + haskey(provider.remarks, ci) && return ci + else + provider.settings.force_inline_all && @warn "`force_inline_all=true` is not supported yet; this setting will be ignored" + analyzer = ADAnalyzer(; world) + ci_pre = typeinf_ext(analyzer, mi, SOURCE_MODE_GET_SOURCE) + result = structural_analysis!(ci_pre, world, provider.settings) + ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)::CodeInstance + end + + check_result(ci) + provider.remarks[ci] = PC2Remarks() + provider.calls[ci] = PC2CallMeta() + provider.effects[ci] = PC2Effects() + provider.exception_types[ci] = PC2Excts() + + @eval Main global result = $(ci.inferred) + + return ci +end + +Cthulhu.get_override(provider::DAEProvider, @nospecialize(info)) = nothing + +Cthulhu.get_pc_remarks(provider::DAEProvider, key::CodeInstance) = get(provider.remarks, key, nothing) +Cthulhu.get_pc_effects(provider::DAEProvider, key::CodeInstance) = get(provider.effects, key, nothing) +Cthulhu.get_pc_excts(provider::DAEProvider, key::CodeInstance) = get(provider.exception_types, key, nothing) + +function Cthulhu.LookupResult(provider::DAEProvider, ci::CodeInstance, optimize::Bool) + if isa(ci.inferred, AnalyzedSource) + mi = get_ci_mi(ci) + new_ci = generate_code_instance(provider, mi) + check_result(new_ci) + @assert isa(new_ci.inferred, DAEIPOResult) "Inferred type of newly generated `CodeInstance` must be `DAEIPOResult`, got `$(typeof(new_ci.inferred))`" + return LookupResult(provider, new_ci, optimize) + end + result = ci.inferred::DAEIPOResult + ir = copy(result.ir) + pushfirst!(ir.argtypes, Tuple) + src = ir_to_src(ir, provider.settings; widen = false) + src.ssavaluetypes = copy(ir.stmts.type) + src.min_world = @atomic ci.min_world + src.max_world = @atomic ci.max_world + optimized = true + rt = Cthulhu.cached_return_type(ci) + exct = Cthulhu.cached_exception_type(ci) + infos = widen_call_infos(ir.stmts.info) + return LookupResult(ir, src, rt, exct, infos, src.slottypes, Cthulhu.get_effects(ci), optimized) +end + +function widen_call_infos(infos) + infos = copy(infos) + for (i, info) in enumerate(infos) + while true + isa(info, FRuleCallInfo) && (info = info.info; continue) + isa(info, MappingInfo) && (info = info.info; continue) + break + end + infos[i] = info + end + return infos +end + +function toggle_setting(provider::DAEProvider, setting::Symbol, value) + return setproperties(provider.settings, NamedTuple((setting => value,))) +end + +function Cthulhu.menu_commands(provider::DAEProvider) + commands = Cthulhu.default_menu_commands(provider) + filter!(x -> !in(x.name, (:optimize, :dump_params, :llvm, :native)), commands) + push!(commands, toggle_setting(provider, 'f', :force_inline_all, "force inline all")) + return commands +end + +function toggle_setting(provider::DAEProvider, key::Char, name::Symbol, description::String = string(name)) + callback = state -> toggle_setting!(state, name) + Command(callback, key, name, description, :toggles) +end + +function Cthulhu.value_for_command(provider::DAEProvider, state::CthulhuState, command::Command) + hasproperty(provider.settings, command.name) && + return getproperty(provider.settings, command.name) + return value_for_default_command(provider, state, command) +end + +function toggle_setting!(state::CthulhuState, name::Symbol) + (; provider) = state + (; settings) = provider + value = !getproperty(settings, name)::Bool + provider.settings = setproperties(settings, NamedTuple((name => value,))) + state.display_code = true +end + +DAECompiler.dae_provider(args...; kwargs...) = DAEProvider(args...; kwargs...) + +end # module diff --git a/src/DAECompiler.jl b/src/DAECompiler.jl index a1d7d7f..321e757 100644 --- a/src/DAECompiler.jl +++ b/src/DAECompiler.jl @@ -43,4 +43,6 @@ module DAECompiler include("analysis/consistency.jl") include("interface.jl") include("problem_interface.jl") + + export dae_provider # use with Cthulhu, `@descend provider=dae_provider() pingpong()` end diff --git a/src/interface.jl b/src/interface.jl index 767cf9a..67924d0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -113,3 +113,6 @@ function refresh() return nothing end refresh() + +# methods are to be added via the Cthulhu extension +function dae_provider end diff --git a/src/transform/common.jl b/src/transform/common.jl index 042a29e..a31e142 100644 --- a/src/transform/common.jl +++ b/src/transform/common.jl @@ -40,14 +40,10 @@ function widen_extra_info!(ir) end end -function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing) - isva = false +function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing, widen = true, isva = false) ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure") maybe_rewrite_debuginfo!(ir, settings) nargtypes = length(ir.argtypes) - nargs = nargtypes-1 - sig = Compiler.compute_oc_signature(ir, nargs, isva) - rt = Compiler.compute_ir_rettype(ir) src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ()) if slotnames === nothing src.slotnames = Symbol[Symbol("arg$i") for i = 1:nargtypes] @@ -55,11 +51,12 @@ function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing) length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`") src.slotnames = slotnames end - src.nargs = length(ir.argtypes) - src.isva = false + src.nargs = nargtypes + src.isva = isva src.slotflags = fill(zero(UInt8), nargtypes) src.slottypes = copy(ir.argtypes) - src = Compiler.ir_to_codeinf!(src, ir) + Compiler.replace_code_newstyle!(src, ir) + widen && Compiler.widen_all_consts!(src) return src end