Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
950 changes: 437 additions & 513 deletions Manifest.toml

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,37 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[sources]
Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"}
Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl"}
Diffractor = {rev = "main", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
Diffractor = {rev = "cthulhu", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
SimpleNonlinearSolve = {rev = "master", subdir = "lib/SimpleNonlinearSolve", url = "https://github.com/SciML/NonlinearSolve.jl.git"}
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl.git"}

[extensions]
DAECompilerCthulhuExt = ["Compiler", "Cthulhu"]

[compat]
Accessors = "0.1.36"
AutoHashEquals = "2.2.0"
CentralizedCaches = "1.1.0"
ChainRules = "1.50"
ChainRulesCore = "1.20"
Compiler = "0"
Cthulhu = "3.0.0"
DiffEqBase = "6.149.2"
DifferentiationInterface = "0.6.52"
Diffractor = "0.2.7"
DifferentiationInterface = "0.7.9"
ForwardDiff = "0.10.36"
InteractiveUtils = "1.11.0"
NonlinearSolve = "3.5.0, 4"
OrderedCollections = "1.6.3"
PrecompileTools = "1"
Preferences = "1.4"
SciMLBase = "2.86.2"
SimpleNonlinearSolve = "2.3.0"
StateSelection = "0.2.0"
StaticArraysCore = "1.4.2"
Sundials = "4.19"
Expand Down
145 changes: 145 additions & 0 deletions ext/DAECompilerCthulhuExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
module DAECompilerCthulhuExt

using Core.IR
using DAECompiler: DAECompiler, DAEIPOResult, UncompilableIPOResult, Settings, ADAnalyzer, structural_analysis!, find_matching_ci, matched_system_structure, StructureCache, ir_to_src, get_method_instance, MappingInfo, AnalyzedSource
using Compiler: Compiler, InferenceResult, NativeInterpreter, SOURCE_MODE_GET_SOURCE, typeinf_ext, Effects, get_ci_mi, NoCallInfo
using Accessors: setproperties
using Diffractor: FRuleCallInfo

using Cthulhu: Cthulhu, get_module_for_compiler_integration, CthulhuState, AbstractProvider, Command, generate_code_instance, lookup, value_for_default_command, perform_action, get_inference_world, cached_return_type, cached_exception_type
const CompilerIntegration = get_module_for_compiler_integration(; use_compiler_stdlib = true)
using .CompilerIntegration: LookupResult, get_effects, InferenceDict, PC2Remarks, PC2CallMeta, PC2Effects, PC2Excts, ConstPropCallInfo, SemiConcreteCallInfo, OCCallInfo

mutable struct DAEProvider <: AbstractProvider
world::UInt
settings::Settings
remarks::InferenceDict{PC2Remarks}
calls::InferenceDict{PC2CallMeta}
effects::InferenceDict{PC2Effects}
exception_types::InferenceDict{PC2Excts}
end
DAEProvider(; world = Base.tls_world_age(), settings = Settings()) = DAEProvider(world, settings, InferenceDict{PC2Remarks}(), InferenceDict{PC2CallMeta}(), InferenceDict{PC2Effects}(), InferenceDict{PC2Excts}())

Cthulhu.get_inference_world(provider::DAEProvider) = provider.world

function Cthulhu.find_method_instance(provider::DAEProvider, @nospecialize(tt::Type{<:Tuple}), world::UInt)
return get_method_instance(tt, world)
end

function check_result(ci::CodeInstance)
isa(ci.inferred, UncompilableIPOResult) && throw(ci.inferred.error)
return true
end

function Cthulhu.generate_code_instance(provider::DAEProvider, mi::MethodInstance)
world = get_inference_world(provider)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)
# XXX: We should not cache the CodeInstance this way, or at least invalidate in the provider in `toggle_setting!`.
if ci !== nothing
haskey(provider.remarks, ci) && return ci
else
provider.settings.force_inline_all && @warn "`force_inline_all=true` is not supported yet; this setting will be ignored"
analyzer = ADAnalyzer(; world)
ci_pre = typeinf_ext(analyzer, mi, SOURCE_MODE_GET_SOURCE)
result = structural_analysis!(ci_pre, world, provider.settings)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)::CodeInstance
end

check_result(ci)
provider.remarks[ci] = PC2Remarks()
provider.calls[ci] = PC2CallMeta()
provider.effects[ci] = PC2Effects()
provider.exception_types[ci] = PC2Excts()

return ci
end

get_override(provider::DAEProvider, info::ConstPropCallInfo) = nothing
get_override(provider::DAEProvider, info::SemiConcreteCallInfo) = nothing
get_override(provider::DAEProvider, info::OCCallInfo) = nothing

Cthulhu.get_pc_remarks(provider::DAEProvider, key::CodeInstance) = get(provider.remarks, key, nothing)
Cthulhu.get_pc_effects(provider::DAEProvider, key::CodeInstance) = get(provider.effects, key, nothing)
Cthulhu.get_pc_excts(provider::DAEProvider, key::CodeInstance) = get(provider.exception_types, key, nothing)

Cthulhu.lookup(provider::DAEProvider, result::InferenceResult, optimize::Bool) = nothing
function Cthulhu.lookup(provider::DAEProvider, ci::CodeInstance, optimize::Bool)
if isa(ci.inferred, AnalyzedSource)
mi = get_ci_mi(ci)
new_ci = generate_code_instance(provider, mi)
check_result(new_ci)
@assert isa(new_ci.inferred, DAEIPOResult) "Inferred type of newly generated `CodeInstance` must be `DAEIPOResult`, got `$(typeof(new_ci.inferred))`"
return lookup(provider, new_ci, optimize)
end
result = ci.inferred::DAEIPOResult
ir = copy(result.ir)
pushfirst!(ir.argtypes, Tuple)
src = ir_to_src(ir, provider.settings; widen = false)
src.ssavaluetypes = copy(ir.stmts.type)
src.min_world = @atomic ci.min_world
src.max_world = @atomic ci.max_world
optimized = true
rt = cached_return_type(ci)
exct = cached_exception_type(ci)
infos = widen_call_infos(ir.stmts.info)
return LookupResult(ir, src, rt, exct, infos, src.slottypes, get_effects(ci), optimized)
end

function widen_call_infos(infos)
infos = copy(infos)
for (i, info) in enumerate(infos)
while true
isa(info, FRuleCallInfo) && (info = info.info; continue)
isa(info, MappingInfo) && (info = info.info; continue)
break
end
infos[i] = info
end
return infos
end

function toggle_setting(provider::DAEProvider, setting::Symbol, value)
return setproperties(provider.settings, NamedTuple((setting => value,)))
end

function Cthulhu.menu_commands(provider::DAEProvider)
commands = Cthulhu.default_menu_commands(provider)
filter!(x -> !in(x.name, (:optimize, :dump_params, :llvm, :native, :inlining_costs)), commands)
push!(commands, toggle_setting(provider, 'f', :force_inline_all, "force inline all"))
push!(commands, perform_action(show_mss, 'm', :show_mss, :actions, "Show system structure"))
return commands
end

function show_mss(state::CthulhuState)
result = state.ci.inferred::DAEIPOResult
terminal = state.terminal
io = terminal.out_stream::IO
mss = matched_system_structure(result, state.provider.settings.mode)
(_, width) = displaysize(terminal)
printstyled(io, '\n', '-'^((width - 26) ÷ 2), " Showing system structure ", '-'^((width - 26) ÷ 2), '\n'; color = :light_black)
show(io, MIME"text/plain"(), mss)
printstyled(io, '\n', '-'^width, "\n\n"; color = :light_black)
end

function toggle_setting(provider::DAEProvider, key::Char, name::Symbol, description::String = string(name))
callback = state -> toggle_setting!(state, name)
Command(callback, key, name, description, :toggles)
end

function Cthulhu.value_for_command(provider::DAEProvider, state::CthulhuState, command::Command)
hasproperty(provider.settings, command.name) &&
return getproperty(provider.settings, command.name)
return value_for_default_command(provider, state, command)
end

function toggle_setting!(state::CthulhuState, name::Symbol)
(; provider) = state
(; settings) = provider
value = !getproperty(settings, name)::Bool
provider.settings = setproperties(settings, NamedTuple((name => value,)))
state.display_code = true
end

DAECompiler.dae_provider(args...; kwargs...) = DAEProvider(args...; kwargs...)

end # module
2 changes: 2 additions & 0 deletions src/DAECompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ module DAECompiler
include("analysis/consistency.jl")
include("interface.jl")
include("problem_interface.jl")

export dae_provider # use with Cthulhu, `@descend provider=dae_provider() pingpong()`
end
2 changes: 1 addition & 1 deletion src/analysis/ADAnalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
end

@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult)
@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
if Compiler.result_is_constabi(interp, result)
return nothing
end
Expand Down
17 changes: 17 additions & 0 deletions src/analysis/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ function make_structure_from_ipo(ipo::DAEIPOResult)

structure = DAESystemStructure(StateSelection.complete(var_to_diff), StateSelection.complete(eq_to_diff), graph, solvable_graph)
end

function matched_system_structure(result::DAEIPOResult, mode)
structure = make_structure_from_ipo(result)

tstate = TransformationState(result, structure)
err = StateSelection.check_consistency(tstate, nothing)
err !== nothing && throw(err)

ret = top_level_state_selection!(tstate)
isa(ret, UncompilableIPOResult) && throw(ret.error)

(diff_key, init_key) = ret
key = in(mode, (DAE, DAENoInit, ODE, ODENoInit)) ? diff_key : init_key

var_eq_matching = matching_for_key(tstate, key)
return StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)
end
3 changes: 3 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ function refresh()
return nothing
end
refresh()

# methods are to be added via the Cthulhu extension
function dae_provider end
17 changes: 1 addition & 16 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,7 @@ function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_
_result = structural_analysis!(ci, world, settings)
isa(_result, UncompilableIPOResult) && throw(_result.error)
!matched && return result ? _result : _result.ir
result = _result

structure = make_structure_from_ipo(result)

tstate = TransformationState(result, structure)
err = StateSelection.check_consistency(tstate, nothing)
err !== nothing && throw(err)

ret = top_level_state_selection!(tstate)
isa(ret, UncompilableIPOResult) && throw(ret.error)

(diff_key, init_key) = ret
key = in(mode, (DAE, DAENoInit, ODE, ODENoInit)) ? diff_key : init_key

var_eq_matching = matching_for_key(tstate, key)
return StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)
return matched_system_structure(_result, mode)
end

"""
Expand Down
13 changes: 5 additions & 8 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,23 @@ function widen_extra_info!(ir)
end
end

function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing)
isva = false
function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing, widen = true, isva = false)
ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure")
maybe_rewrite_debuginfo!(ir, settings)
nargtypes = length(ir.argtypes)
nargs = nargtypes-1
sig = Compiler.compute_oc_signature(ir, nargs, isva)
rt = Compiler.compute_ir_rettype(ir)
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
if slotnames === nothing
src.slotnames = Symbol[Symbol("arg$i") for i = 1:nargtypes]
else
length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`")
src.slotnames = slotnames
end
src.nargs = length(ir.argtypes)
src.isva = false
src.nargs = nargtypes
src.isva = isva
src.slotflags = fill(zero(UInt8), nargtypes)
src.slottypes = copy(ir.argtypes)
src = Compiler.ir_to_codeinf!(src, ir)
Compiler.replace_code_newstyle!(src, ir)
widen && Compiler.widen_all_consts!(src)
return src
end

Expand Down
4 changes: 0 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,10 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XSteam = "95ff35a0-be81-11e9-2ca3-5b4e338e8476"

[sources]
SciMLSensitivity = {rev = "kf/mindep4", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}
57 changes: 57 additions & 0 deletions test/cthulhu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
module CthulhuExt

using Test
using DAECompiler
using DAECompiler.Intrinsics
using Cthulhu
using Cthulhu.Testing
using Cthulhu.Testing: wait_for

function test_descend_for_provider(provider, args...; io = nothing, callsite = 1)
terminal = VirtualTerminal()
harness = @run terminal descend(args...; terminal, provider)
write(terminal, 'A')
write(terminal, 'T')
write(terminal, 'L') # should be a no-op because we disable the LLVM view (which otherwise segfaults)
write(terminal, 'd') # debuginfo: :source
for _ in 2:callsite write(terminal, :down) end
write(terminal, :enter)
write(terminal, 'i') # inlining costs: on
write(terminal, 'S')
write(terminal, :up)
write(terminal, :enter)
write(terminal, 'q')
if io !== nothing
wait_for(harness.task)
displayed = String(readavailable(harness.io))
println(io, displayed)
end
@test end_terminal_session(harness)
end

@noinline function ping(a, b, c, d)
always!(b - sin(a))
always!(d - sin(c))
end

@noinline function pong(a, b, c, d)
always!(b - asin(a))
always!(ddt(d) - asin(c))
end

function pingpong()
a = continuous()
b = continuous()
c = continuous()
d = continuous()
ping(a, b, c, d)
pong(b, c, d, a)
end

io = IOBuffer()
test_descend_for_provider(dae_provider(), pingpong; io, callsite = 5)
text = String(take!(io))
@test contains(text, "Incidence(u₄)")
@test contains(text, "Eq(1)")

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ include("regression.jl")
include("errors.jl")
include("invalidation.jl")
include("validation.jl")
include("cthulhu.jl")

using Pkg
Pkg.activate(joinpath(dirname(@__DIR__), "benchmark")) do
Expand Down
Loading