Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
CentralizedCaches = "d1073d05-2d26-4019-b855-dfa0385fef5e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand All @@ -33,17 +34,17 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[sources]
SciMLBase = {rev = "os/dae-get-du2", url = "https://github.com/CedarEDA/SciMLBase.jl"}
SciMLSensitivity = {rev = "kf/mindep2", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[extensions]
DAECompilerModelingToolkitExt = "ModelingToolkit"

Expand All @@ -52,6 +53,7 @@ Accessors = "0.1.36"
CentralizedCaches = "1.1.0"
ChainRules = "1.50"
ChainRulesCore = "1.20"
Compiler = "0.0.1"
Cthulhu = "2.10.1"
DiffEqBase = "6.149.2"
Diffractor = "0.2.7"
Expand Down
3 changes: 2 additions & 1 deletion src/DAECompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ function reconstruct_sensitivities(args...)
error("This method requires SciMLSensitivity")
end

const CC = Core.Compiler
import Compiler
const CC = Compiler
import .CC: get_inference_world
using Base: get_world_counter

Expand Down
36 changes: 18 additions & 18 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ForwardDiff
using Base.Meta
using Graphs
using Core.IR
using Core.Compiler: InferenceState, bbidxiter, dominates, tmerge, typeinf_lattice
using .CC: InferenceState, bbidxiter, dominates, tmerge, typeinf_lattice

@breadcrumb "ir_levels" function run_dae_passes(
interp::DAEInterpreter, ir::IRCode, debug_config::DebugConfig = DebugConfig())
Expand Down Expand Up @@ -317,7 +317,7 @@ has_any_genscope(sc::Scope) = isdefined(sc, :parent) && has_any_genscope(sc.pare
has_any_genscope(sc::PartialScope) = false
has_any_genscope(sc::PartialStruct) = false # TODO

function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
function _make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
if isa(argt, Const)
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
return argt
Expand All @@ -331,7 +331,7 @@ function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_v
inc = Incidence(add_variable!(which))
return argt === Float64 ? inc : Incidence(argt, inc.row, inc.eps)
elseif isa(argt, PartialStruct)
return PartialStruct(argt.typ, Any[make_argument_lattice_elem(which, f, add_variable!, add_equation!, add_scope!) for f in argt.fields])
return PartialStruct(𝕃, argt.typ, Any[make_argument_lattice_elem(𝕃, which, f, add_variable!, add_equation!, add_scope!) for f in argt.fields])
elseif isabstracttype(argt) || ismutabletype(argt) || !isa(argt, DataType)
return nothing
else
Expand All @@ -344,20 +344,20 @@ function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_v
for i = 1:length(fieldtypes(argt))
# TODO: Can we make this lazy?
ft = fieldtype(argt, i)
mft = _make_argument_lattice_elem(which, ft, add_variable!, add_equation!, add_scope!)
mft = _make_argument_lattice_elem(𝕃, which, ft, add_variable!, add_equation!, add_scope!)
if mft === nothing
push!(fields, Incidence(ft))
else
any = true
push!(fields, mft)
end
end
return any ? PartialStruct(argt, fields) : nothing
return any ? PartialStruct(𝕃, argt, fields) : nothing
end
end

function make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
mft = _make_argument_lattice_elem(which, argt, add_variable!, add_equation!, add_scope!)
function make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
mft = _make_argument_lattice_elem(𝕃, which, argt, add_variable!, add_equation!, add_scope!)
mft === nothing ? Incidence(argt) : mft
end

Expand Down Expand Up @@ -532,7 +532,7 @@ end
nexternalvars = 0 # number of variables that we expect to come in
nexternaleqs = 0 # number of equation references that we expect to come in
if caller !== nothing
argtypes = Any[make_argument_lattice_elem(Argument(i), argt, add_variable!, add_equation!, add_scope!) for (i, argt) in enumerate(ir.argtypes)]
argtypes = Any[make_argument_lattice_elem(CC.typeinf_lattice(interp), Argument(i), argt, add_variable!, add_equation!, add_scope!) for (i, argt) in enumerate(ir.argtypes)]
nexternalvars = length(var_to_diff)
nexternaleqs = length(eqssas)
else
Expand Down Expand Up @@ -571,7 +571,7 @@ end
end
end

cur_scope_lattice = PartialStruct(Base.ScopedValues.Scope,
cur_scope_lattice = PartialStruct(CC.typeinf_lattice(interp), Base.ScopedValues.Scope,
Any[PartialKeyValue(Incidence(Base.PersistentDict{Base.ScopedValues.ScopedValue, Any}))])

# Scan the IR, computing equations, variables, diffgraph, etc.
Expand Down Expand Up @@ -1017,7 +1017,7 @@ end
for eq = 1:length(result.eq_kind)
mapped_eq = mapping.eqs[eq]
mapped_eq == 0 && continue
mapped_inc = apply_linear_incidence(result.total_incidence[eq], result, var_to_diff, var_kind, eq_kind, mapping)
mapped_inc = apply_linear_incidence(CC.typeinf_lattice(interp), result.total_incidence[eq], result, var_to_diff, var_kind, eq_kind, mapping)
if isassigned(total_incidence, mapped_eq)
total_incidence[mapped_eq] = tfunc(Val(Core.Intrinsics.add_float),
total_incidence[mapped_eq],
Expand All @@ -1033,7 +1033,7 @@ end

for (ieq, inc) in enumerate(result.total_incidence[(result.nexternaleqs+1):end])
mapping.eqs[ieq] == 0 || continue
push!(total_incidence, apply_linear_incidence(inc, result, var_to_diff, var_kind, eq_kind, mapping))
push!(total_incidence, apply_linear_incidence(CC.typeinf_lattice(interp), inc, result, var_to_diff, var_kind, eq_kind, mapping))
push!(eq_callee_mapping, [SSAValue(i)=>ieq])
push!(eq_kind, CalleeInternal)
mapping.eqs[ieq] = length(total_incidence)
Expand Down Expand Up @@ -1115,7 +1115,7 @@ end

nimplicitoutpairs = 0
if caller !== nothing
ultimate_rt, nimplicitoutpairs = process_ipo_return!(ultimate_rt, eq_kind, var_kind,
ultimate_rt, nimplicitoutpairs = process_ipo_return!(CC.typeinf_lattice(interp), ultimate_rt, eq_kind, var_kind,
var_to_diff, total_incidence, eq_callee_mapping)
end

Expand All @@ -1135,7 +1135,7 @@ end
Dict{TornCacheKey, CodeInstance}())
end

function process_ipo_return!(ultimate_rt::Incidence, eq_kind, var_kind, var_to_diff, total_incidence, eq_callee_mapping)
function process_ipo_return!(𝕃, ultimate_rt::Incidence, eq_kind, var_kind, var_to_diff, total_incidence, eq_callee_mapping)
nonlinrepl = nothing
nimplicitoutpairs = 0
function get_nonlinrepl()
Expand Down Expand Up @@ -1179,20 +1179,20 @@ function process_ipo_return!(ultimate_rt::Incidence, eq_kind, var_kind, var_to_d
return ultimate_rt, nimplicitoutpairs
end

function process_ipo_return!(ultimate_rt::Eq, eq_kind, args...)
function process_ipo_return!(𝕃, ultimate_rt::Eq, eq_kind, args...)
eq_kind[ultimate_rt.id] = External
return ultimate_rt, 0
end
process_ipo_return!(ultimate_rt::Union{Type, PartialScope, PartialKeyValue, Const}, args...) = ultimate_rt, 0
function process_ipo_return!(ultimate_rt::PartialStruct, args...)
process_ipo_return!(𝕃, ultimate_rt::Union{Type, PartialScope, PartialKeyValue, Const}, args...) = ultimate_rt, 0
function process_ipo_return!(𝕃, ultimate_rt::PartialStruct, args...)
nimplicitoutpairs = 0
fields = Any[]
for f in ultimate_rt.fields
(rt, n) = process_ipo_return!(f, args...)
(rt, n) = process_ipo_return!(𝕃, f, args...)
nimplicitoutpairs += n
push!(fields, rt)
end
return PartialStruct(ultimate_rt.typ, fields), nimplicitoutpairs
return PartialStruct(𝕃, ultimate_rt.typ, fields), nimplicitoutpairs
end

function get_variable_name(names::OrderedDict, var_to_diff, var_idx)
Expand Down
23 changes: 1 addition & 22 deletions src/analysis/compiler_reexports.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,9 @@
using Core.IR
using Core.Compiler: IRCode, Instruction, InstructionStream, IncrementalCompact,
using .CC: IRCode, Instruction, InstructionStream, IncrementalCompact,
NewInstruction, DomTree, BBIdxIter, AnySSAValue, UseRef, UseRefIterator,
block_for_inst, cfg_simplify!, is_known_call, argextype, getfield_tfunc, finish,
singleton_type, widenconst, dominates_ssa, ⊑, userefs

# TODO: This really needs to go into a uniform compiler stdlib.
Base.iterate(compact::IncrementalCompact, state) = Core.Compiler.iterate(compact, state)
Base.iterate(compact::IncrementalCompact) = Core.Compiler.iterate(compact)
Base.iterate(abu::CC.AbsIntStackUnwind, state...) = CC.iterate(abu, state...)

Base.setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) = Core.Compiler.setindex!(compact,v,idx)
Base.setindex!(ir::IRCode, @nospecialize(v), idx::SSAValue) = Core.Compiler.setindex!(ir,v,idx)
Base.setindex!(inst::Instruction, @nospecialize(v), sym::Symbol) = Core.Compiler.setindex!(inst,v,sym)
Base.getindex(compact::IncrementalCompact, idx::AnySSAValue) = Core.Compiler.getindex(compact,idx)

Base.setindex!(urs::InstructionStream, @nospecialize args...) = Core.Compiler.setindex!(urs, args...)
Base.setindex!(ir::IRCode, @nospecialize args...) = Core.Compiler.setindex!(ir, args...)
Base.getindex(ir::IRCode, @nospecialize args...) = Core.Compiler.getindex(ir, args...)

Base.IteratorSize(::Type{CC.AbsIntStackUnwind}) = Base.SizeUnknown()

# TODO: Move this to Core.Compiler
CC.block_for_inst(ir::IRCode, s::SSAValue) = block_for_inst(ir, s.id)
function CC.dominates_ssa(ir::IRCode, domtree::DomTree, x::SSAValue, y::SSAValue; dominates_after=false)
xb = block_for_inst(ir, x)
yb = block_for_inst(ir, y)
Expand Down Expand Up @@ -82,6 +64,3 @@ function replace_argument!(compact::IncrementalCompact, idx::Int, argn::Argument
compact[ssa] = urs[]
end

Base.copy(phi::PhiNode) = Core.PhiNode(copy(phi.edges), copy(phi.values))
Base.push!(bs::CC.BitSet, i::Int) = CC.push!(bs, i)
Base.push!(bs::CC.BitSetBoundedMinPrioritySet, i::Int) = CC.push!(bs, i)
Loading
Loading