Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[extensions]
DAECompilerCthulhuExt = ["Compiler", "Cthulhu"]

[sources]
Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"}
Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
Expand Down
132 changes: 132 additions & 0 deletions ext/DAECompilerCthulhuExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
module DAECompilerCthulhuExt

using Core.IR
using DAECompiler: DAECompiler, DAEIPOResult, UncompilableIPOResult, Settings, ADAnalyzer, structural_analysis!, find_matching_ci, StructureCache, ir_to_src, get_method_instance, MappingInfo, AnalyzedSource
using Compiler: Compiler, InferenceResult, NativeInterpreter, SOURCE_MODE_GET_SOURCE, get_inference_world, typeinf_ext, Effects, get_ci_mi, NoCallInfo
using Accessors: setproperties
using Diffractor: FRuleCallInfo

import Cthulhu as _Cthulhu
const Cthulhu = Base.get_extension(_Cthulhu, :CthulhuCompilerExt)
using .Cthulhu: CthulhuState, AbstractProvider, Command, InferenceKey, InferenceDict, PC2Remarks, PC2CallMeta, PC2Effects, PC2Excts, LookupResult, generate_code_instance, value_for_default_command

mutable struct DAEProvider <: AbstractProvider
world::UInt
settings::Settings
remarks::InferenceDict{PC2Remarks}
calls::InferenceDict{PC2CallMeta}
effects::InferenceDict{PC2Effects}
exception_types::InferenceDict{PC2Excts}
end
DAEProvider(; world = Base.tls_world_age(), settings = Settings()) = DAEProvider(world, settings, InferenceDict{PC2Remarks}(), InferenceDict{PC2CallMeta}(), InferenceDict{PC2Effects}(), InferenceDict{PC2Excts}())

Cthulhu.get_inference_world(provider::DAEProvider) = provider.world

function Cthulhu.find_method_instance(provider::DAEProvider, @nospecialize(tt::Type{<:Tuple}), world::UInt)
return get_method_instance(tt, world)
end

function check_result(ci::CodeInstance)
isa(ci.inferred, UncompilableIPOResult) && throw(ci.inferred.error)
return true
end

function Cthulhu.generate_code_instance(provider::DAEProvider, mi::MethodInstance)
world = get_inference_world(provider)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)
# XXX: We should not cache the CodeInstance this way, or at least invalidate in the provider in `toggle_setting!`.
if ci !== nothing
haskey(provider.remarks, ci) && return ci
else
provider.settings.force_inline_all && @warn "`force_inline_all=true` is not supported yet; this setting will be ignored"
analyzer = ADAnalyzer(; world)
ci_pre = typeinf_ext(analyzer, mi, SOURCE_MODE_GET_SOURCE)
result = structural_analysis!(ci_pre, world, provider.settings)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)::CodeInstance
end

check_result(ci)
provider.remarks[ci] = PC2Remarks()
provider.calls[ci] = PC2CallMeta()
provider.effects[ci] = PC2Effects()
provider.exception_types[ci] = PC2Excts()

@eval Main global result = $(ci.inferred)

return ci
end

Cthulhu.get_override(provider::DAEProvider, @nospecialize(info)) = nothing

Cthulhu.get_pc_remarks(provider::DAEProvider, key::CodeInstance) = get(provider.remarks, key, nothing)
Cthulhu.get_pc_effects(provider::DAEProvider, key::CodeInstance) = get(provider.effects, key, nothing)
Cthulhu.get_pc_excts(provider::DAEProvider, key::CodeInstance) = get(provider.exception_types, key, nothing)

function Cthulhu.LookupResult(provider::DAEProvider, ci::CodeInstance, optimize::Bool)
if isa(ci.inferred, AnalyzedSource)
mi = get_ci_mi(ci)
new_ci = generate_code_instance(provider, mi)
check_result(new_ci)
@assert isa(new_ci.inferred, DAEIPOResult) "Inferred type of newly generated `CodeInstance` must be `DAEIPOResult`, got `$(typeof(new_ci.inferred))`"
return LookupResult(provider, new_ci, optimize)
end
result = ci.inferred::DAEIPOResult
ir = copy(result.ir)
pushfirst!(ir.argtypes, Tuple)
src = ir_to_src(ir, provider.settings; widen = false)
src.ssavaluetypes = copy(ir.stmts.type)
src.min_world = @atomic ci.min_world
src.max_world = @atomic ci.max_world
optimized = true
rt = Cthulhu.cached_return_type(ci)
exct = Cthulhu.cached_exception_type(ci)
infos = widen_call_infos(ir.stmts.info)
return LookupResult(ir, src, rt, exct, infos, src.slottypes, Cthulhu.get_effects(ci), optimized)
end

function widen_call_infos(infos)
infos = copy(infos)
for (i, info) in enumerate(infos)
while true
isa(info, FRuleCallInfo) && (info = info.info; continue)
isa(info, MappingInfo) && (info = info.info; continue)
break
end
infos[i] = info
end
return infos
end

function toggle_setting(provider::DAEProvider, setting::Symbol, value)
return setproperties(provider.settings, NamedTuple((setting => value,)))
end

function Cthulhu.menu_commands(provider::DAEProvider)
commands = Cthulhu.default_menu_commands(provider)
filter!(x -> !in(x.name, (:optimize, :dump_params, :llvm, :native)), commands)
push!(commands, toggle_setting(provider, 'f', :force_inline_all, "force inline all"))
return commands
end

function toggle_setting(provider::DAEProvider, key::Char, name::Symbol, description::String = string(name))
callback = state -> toggle_setting!(state, name)
Command(callback, key, name, description, :toggles)
end

function Cthulhu.value_for_command(provider::DAEProvider, state::CthulhuState, command::Command)
hasproperty(provider.settings, command.name) &&
return getproperty(provider.settings, command.name)
return value_for_default_command(provider, state, command)
end

function toggle_setting!(state::CthulhuState, name::Symbol)
(; provider) = state
(; settings) = provider
value = !getproperty(settings, name)::Bool
provider.settings = setproperties(settings, NamedTuple((name => value,)))
state.display_code = true
end

DAECompiler.dae_provider(args...; kwargs...) = DAEProvider(args...; kwargs...)

end # module
2 changes: 2 additions & 0 deletions src/DAECompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ module DAECompiler
include("analysis/consistency.jl")
include("interface.jl")
include("problem_interface.jl")

export dae_provider # use with Cthulhu, `@descend provider=dae_provider() pingpong()`
end
3 changes: 3 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ function refresh()
return nothing
end
refresh()

# methods are to be added via the Cthulhu extension
function dae_provider end
13 changes: 5 additions & 8 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,23 @@ function widen_extra_info!(ir)
end
end

function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing)
isva = false
function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing, widen = true, isva = false)
ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure")
maybe_rewrite_debuginfo!(ir, settings)
nargtypes = length(ir.argtypes)
nargs = nargtypes-1
sig = Compiler.compute_oc_signature(ir, nargs, isva)
rt = Compiler.compute_ir_rettype(ir)
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
if slotnames === nothing
src.slotnames = Symbol[Symbol("arg$i") for i = 1:nargtypes]
else
length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`")
src.slotnames = slotnames
end
src.nargs = length(ir.argtypes)
src.isva = false
src.nargs = nargtypes
src.isva = isva
src.slotflags = fill(zero(UInt8), nargtypes)
src.slottypes = copy(ir.argtypes)
src = Compiler.ir_to_codeinf!(src, ir)
Compiler.replace_code_newstyle!(src, ir)
widen && Compiler.widen_all_consts!(src)
return src
end

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ include("regression.jl")
include("errors.jl")
include("invalidation.jl")
include("validation.jl")
include("cthulhu.jl")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's actually the include that I forgot to remove. I planned to have tests (and still do), but until now I've just tested manually.


using Pkg
Pkg.activate(joinpath(dirname(@__DIR__), "benchmark")) do
Expand Down
Loading