Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[extensions]
DAECompilerCthulhuExt = ["Compiler", "Cthulhu"]

[sources]
Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"}
Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
Expand Down
132 changes: 132 additions & 0 deletions ext/DAECompilerCthulhuExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
module DAECompilerCthulhuExt

using Core.IR
using DAECompiler: DAECompiler, DAEIPOResult, UncompilableIPOResult, Settings, ADAnalyzer, structural_analysis!, find_matching_ci, StructureCache, ir_to_src, get_method_instance, MappingInfo, AnalyzedSource
using Compiler: Compiler, InferenceResult, NativeInterpreter, SOURCE_MODE_GET_SOURCE, get_inference_world, typeinf_ext, Effects, get_ci_mi, NoCallInfo
using Accessors: setproperties
using Diffractor: FRuleCallInfo

import Cthulhu as _Cthulhu
const Cthulhu = Base.get_extension(_Cthulhu, :CthulhuCompilerExt)
using .Cthulhu: CthulhuState, AbstractProvider, Command, InferenceKey, InferenceDict, PC2Remarks, PC2CallMeta, PC2Effects, PC2Excts, LookupResult, generate_code_instance, value_for_default_command

mutable struct DAEProvider <: AbstractProvider
world::UInt
settings::Settings
remarks::InferenceDict{PC2Remarks}
calls::InferenceDict{PC2CallMeta}
effects::InferenceDict{PC2Effects}
exception_types::InferenceDict{PC2Excts}
end
DAEProvider(; world = Base.tls_world_age(), settings = Settings()) = DAEProvider(world, settings, InferenceDict{PC2Remarks}(), InferenceDict{PC2CallMeta}(), InferenceDict{PC2Effects}(), InferenceDict{PC2Excts}())

Cthulhu.get_inference_world(provider::DAEProvider) = provider.world

function Cthulhu.find_method_instance(provider::DAEProvider, @nospecialize(tt::Type{<:Tuple}), world::UInt)
return get_method_instance(tt, world)
end

function check_result(ci::CodeInstance)
isa(ci.inferred, UncompilableIPOResult) && throw(ci.inferred.error)
return true
end

function Cthulhu.generate_code_instance(provider::DAEProvider, mi::MethodInstance)
world = get_inference_world(provider)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)
# XXX: We should not cache the CodeInstance this way, or at least invalidate in the provider in `toggle_setting!`.
if ci !== nothing
haskey(provider.remarks, ci) && return ci
else
provider.settings.force_inline_all && @warn "`force_inline_all=true` is not supported yet; this setting will be ignored"
analyzer = ADAnalyzer(; world)
ci_pre = typeinf_ext(analyzer, mi, SOURCE_MODE_GET_SOURCE)
result = structural_analysis!(ci_pre, world, provider.settings)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)::CodeInstance
end

check_result(ci)
provider.remarks[ci] = PC2Remarks()
provider.calls[ci] = PC2CallMeta()
provider.effects[ci] = PC2Effects()
provider.exception_types[ci] = PC2Excts()

@eval Main global result = $(ci.inferred)

return ci
end

Cthulhu.get_override(provider::DAEProvider, @nospecialize(info)) = nothing

Cthulhu.get_pc_remarks(provider::DAEProvider, key::CodeInstance) = get(provider.remarks, key, nothing)
Cthulhu.get_pc_effects(provider::DAEProvider, key::CodeInstance) = get(provider.effects, key, nothing)
Cthulhu.get_pc_excts(provider::DAEProvider, key::CodeInstance) = get(provider.exception_types, key, nothing)

function Cthulhu.LookupResult(provider::DAEProvider, ci::CodeInstance, optimize::Bool)
if isa(ci.inferred, AnalyzedSource)
mi = get_ci_mi(ci)
new_ci = generate_code_instance(provider, mi)
check_result(new_ci)
@assert isa(new_ci.inferred, DAEIPOResult) "Inferred type of newly generated `CodeInstance` must be `DAEIPOResult`, got `$(typeof(new_ci.inferred))`"
return LookupResult(provider, new_ci, optimize)
end
result = ci.inferred::DAEIPOResult
ir = copy(result.ir)
pushfirst!(ir.argtypes, Tuple)
src = ir_to_src(ir, provider.settings; widen = false)
src.ssavaluetypes = copy(ir.stmts.type)
src.min_world = @atomic ci.min_world
src.max_world = @atomic ci.max_world
optimized = true
rt = Cthulhu.cached_return_type(ci)
exct = Cthulhu.cached_exception_type(ci)
infos = widen_call_infos(ir.stmts.info)
return LookupResult(ir, src, rt, exct, infos, src.slottypes, Cthulhu.get_effects(ci), optimized)
end

function widen_call_infos(infos)
infos = copy(infos)
for (i, info) in enumerate(infos)
while true
isa(info, FRuleCallInfo) && (info = info.info; continue)
isa(info, MappingInfo) && (info = info.info; continue)
break
end
infos[i] = info
end
return infos
end

function toggle_setting(provider::DAEProvider, setting::Symbol, value)
return setproperties(provider.settings, NamedTuple((setting => value,)))
end

function Cthulhu.menu_commands(provider::DAEProvider)
commands = Cthulhu.default_menu_commands(provider)
filter!(x -> !in(x.name, (:optimize, :dump_params, :llvm, :native)), commands)
push!(commands, toggle_setting(provider, 'f', :force_inline_all, "force inline all"))
return commands
end

function toggle_setting(provider::DAEProvider, key::Char, name::Symbol, description::String = string(name))
callback = state -> toggle_setting!(state, name)
Command(callback, key, name, description, :toggles)
end

function Cthulhu.value_for_command(provider::DAEProvider, state::CthulhuState, command::Command)
hasproperty(provider.settings, command.name) &&
return getproperty(provider.settings, command.name)
return value_for_default_command(provider, state, command)
end

function toggle_setting!(state::CthulhuState, name::Symbol)
(; provider) = state
(; settings) = provider
value = !getproperty(settings, name)::Bool
provider.settings = setproperties(settings, NamedTuple((name => value,)))
state.display_code = true
end

DAECompiler.dae_provider(args...; kwargs...) = DAEProvider(args...; kwargs...)

end # module
2 changes: 2 additions & 0 deletions src/DAECompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ module DAECompiler
include("analysis/consistency.jl")
include("interface.jl")
include("problem_interface.jl")

export dae_provider # use with Cthulhu, `@descend provider=dae_provider() pingpong()`
end
3 changes: 3 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ function refresh()
return nothing
end
refresh()

# methods are to be added via the Cthulhu extension
function dae_provider end
13 changes: 5 additions & 8 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,23 @@ function widen_extra_info!(ir)
end
end

function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing)
isva = false
function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing, widen = true, isva = false)
ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure")
maybe_rewrite_debuginfo!(ir, settings)
nargtypes = length(ir.argtypes)
nargs = nargtypes-1
sig = Compiler.compute_oc_signature(ir, nargs, isva)
rt = Compiler.compute_ir_rettype(ir)
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
if slotnames === nothing
src.slotnames = Symbol[Symbol("arg$i") for i = 1:nargtypes]
else
length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`")
src.slotnames = slotnames
end
src.nargs = length(ir.argtypes)
src.isva = false
src.nargs = nargtypes
src.isva = isva
src.slotflags = fill(zero(UInt8), nargtypes)
src.slottypes = copy(ir.argtypes)
src = Compiler.ir_to_codeinf!(src, ir)
Compiler.replace_code_newstyle!(src, ir)
widen && Compiler.widen_all_consts!(src)
return src
end

Expand Down
Loading