Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
[sources]
Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"}
Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl#main"}
DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl"}
Diffractor = {rev = "main", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
SimpleNonlinearSolve = {rev = "master", subdir = "lib/SimpleNonlinearSolve", url = "https://github.com/SciML/NonlinearSolve.jl.git"}
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl.git"}
Expand Down
3 changes: 3 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
[deps]
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
DAECompiler = "32805668-c3d0-42c2-aafd-0d0a9857a104"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
XSteam = "95ff35a0-be81-11e9-2ca3-5b4e338e8476"

[sources]
Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"}
DAECompiler = {path = "."}
XSteam = {rev = "master", url = "https://github.com/hzgzh/XSteam.jl.git"}
110 changes: 78 additions & 32 deletions src/analysis/ipoincidence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,49 +45,95 @@ end
apply_linear_incidence(𝕃, ret::Type, caller::CallerMappingState, mapping::CalleeMapping) = ret
apply_linear_incidence(𝕃, ret::Const, caller::CallerMappingState, mapping::CalleeMapping) = ret
function apply_linear_incidence(𝕃, ret::Incidence, caller::Union{CallerMappingState, Nothing}, mapping::CalleeMapping)
coeffs = mapping.var_coeffs

const_val = ret.typ
new_row = _zero_row()

for (v_offset, coeff) in zip(rowvals(ret.row), nonzeros(ret.row))
v = v_offset - 1
# Substitute variables returned by the callee with the incidence defined by the caller.
# The composition will be additive in the constant terms, and multiplicative for linear coefficients.
caller_variables = mapping.var_coeffs

typ = ret.typ
row = _zero_row()

used_caller_variables = Int[]
for i in rowvals(ret.row)
i == 1 && continue # skip time
v = i - 1
if !isassigned(caller_variables, v)
compute_missing_coeff!(caller_variables, caller::CallerMappingState, v)
end
substitution = caller_variables[i - 1]
isa(substitution, Incidence) || continue
for j in rowvals(substitution.row)
push!(used_caller_variables, j)
end
end
did_use_time = in(1, used_caller_variables)

for (i, coeff) in zip(rowvals(ret.row), nonzeros(ret.row))
# Time dependence persists as itself
if v == 0
new_row[v_offset] += coeff
if i == 1
row[i] = coeff
continue
end

if !isassigned(coeffs, v)
@assert caller !== nothing
compute_missing_coeff!(coeffs, caller, v)
end

replacement = coeffs[v]
if isa(replacement, Incidence)
new_row .+= replacement.row .* coeff
else
if isa(replacement, Const)
if isa(const_val, Const)
new_const_val = const_val.val + replacement.val * coeff
if isa(new_const_val, Float64)
const_val = Const(new_const_val)
else
const_val = widenconst(const_val)
end
v = i - 1
substitution = caller_variables[v]
if isa(substitution, Incidence)
# Distribute the coefficient to all terms.
# Because the coefficient is expressed in the reference of the callee,
# state dependence must be processed carefully.
typ = compose_additive_term(typ, substitution.typ, coeff)
for (j, substitute) in zip(rowvals(substitution.row), nonzeros(substitution.row))
row[j] === nonlinear && continue # no more information to be gained
if substitute === nonlinear || coeff === nonlinear
row[j] = nonlinear
elseif isa(coeff, Float64)
row[j] += coeff * substitute
else
const_val = widenconst(const_val)
time_dependent = coeff.time_dependent
state_dependent = false
if isa(substitute, Linearity)
time_dependent |= substitute.time_dependent
state_dependent |= substitute.state_dependent
end
if coeff.state_dependent
if coeff.time_dependent && did_use_time
# The term is at least bilinear in another state, and this state
# from the callee may alias time from the caller, so we must mark
# time as nonlinear.
row[1] = nonlinear
end
if count(==(j), used_caller_variables) > 1
# The term is at least bilinear in another state, but we don't
# know which state, so we must fall back to nonlinear.
row[j] = nonlinear
continue
end
# We'll only be state-dependent if variables from the callee
# map to at least one other variable than `j`.
state_dependent |= length(used_caller_variables) - did_use_time > 1
# If another state may contain time, we may be time-dependent too.
time_dependent |= did_use_time
end
j == 1 && (time_dependent = false)
row[j] += Linearity(; nonlinear = false, state_dependent, time_dependent)
end
else
# The replacement has some unknown type - we need to widen
# all the way here.
return widenconst(const_val)
end
elseif isa(substitution, Const)
typ = compose_additive_term(typ, substitution, coeff)
else
return widenconst(typ) # unknown lattice element, we should widen
end
end

return Incidence(const_val, new_row)
return Incidence(typ, row)
end

function compose_additive_term(@nospecialize(a), @nospecialize(b), coeff)
isa(a, Const) || return widenconst(a)
isa(b, Const) || return widenconst(a)
isa(coeff, Linearity) && return widenconst(a)
val = a.val + b.val * coeff
isa(val, Float64) || return widenconst(a)
return Const(val)
end

function apply_linear_incidence(𝕃, ret::Eq, caller::CallerMappingState, mapping::CalleeMapping)
Expand Down
170 changes: 115 additions & 55 deletions src/analysis/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,56 +39,82 @@ Compiler.widenlattice(::EqStructureLattice) = Compiler.ConstsLattice()
Compiler.is_valid_lattice_norec(::EqStructureLattice, @nospecialize(v)) = isa(v, Incidence) || isa(v, Eq) || isa(v, PartialScope) || isa(v, PartialKeyValue)
Compiler.has_extended_unionsplit(::EqStructureLattice) = true

############################## NonLinear #######################################
############################## Linearity #######################################

# XXX: update docstring
"""
struct NonLinear
struct Linearity
nonlinear::Bool = true
end

This struct expresses linearity information: linear or nonlinear.

This singleton number is similar to `missing` in that arithmatic with it is
Nonlinearity is similar to `missing` in that arithmetic with it is
saturating. When used as a coefficient in the Incidence lattice, it indicates
that the corresponding variable does not have a (constant) linear coefficient.
This may either mean that the variable in question has a non-constant linear
coefficient or that the variable is used non-linearly. We do not currently
distinguish the two situations.
"""
struct NonLinear; end
Base.iszero(::NonLinear) = false
Base.zero(::Type{Union{NonLinear, Float64}}) = 0.
Base.@kwdef struct Linearity
time_dependent::Bool = true
state_dependent::Bool = true
nonlinear::Bool = true
function Linearity(time_dependent, state_dependent, nonlinear)
if nonlinear && (!time_dependent || !state_dependent)
throw(ArgumentError("Modeling of state or time independence is not supported for nonlinearities"))
end
new(time_dependent, state_dependent, nonlinear)
end
end

const linear = Linearity(time_dependent = false, state_dependent = false, nonlinear = false)
const linear_time_dependent = Linearity(state_dependent = false, nonlinear = false)
const linear_state_dependent = Linearity(time_dependent = false, nonlinear = false)
const linear_time_and_state_dependent = Linearity(nonlinear = false)
const nonlinear = Linearity()

join_linearity(a::Linearity, b::Real) = a
join_linearity(a::Real, b::Linearity) = b
join_linearity(a::Real, b::Real) = a == b ? promote(a, b) : Linearity(time_dependent = false, state_dependent = false, nonlinear = false)
function join_linearity(a::Linearity, b::Linearity)
(a.nonlinear | b.nonlinear) && return nonlinear
return Linearity(; time_dependent = a.time_dependent | b.time_dependent, state_dependent = a.state_dependent | b.state_dependent, nonlinear = false)
end

Base.iszero(::Linearity) = false
Base.zero(::Type{Union{Linearity, Float64}}) = 0.
for f in (:+, :-)
@eval begin
Base.$f(a::Real, b::NonLinear) = b
Base.$f(a::NonLinear, b::Real) = a
Base.$f(a::NonLinear, b::NonLinear) = nonlinear
Base.$f(::NonLinear) = nonlinear
Base.$f(a::Real, b::Linearity) = b
Base.$f(a::Linearity, b::Real) = a
Base.$f(a::Linearity, b::Linearity) = join_linearity(a, b)
Base.$f(a::Linearity) = a
end
end

Base.:(*)(a::Real, b::NonLinear) = iszero(a) ? a : b
Base.:(*)(a::NonLinear, b::Real) = iszero(b) ? b : a
Base.:(*)(a::NonLinear, b::NonLinear) = nonlinear
Base.div(a::Real, b::NonLinear) = iszero(a) ? a : b
Base.div(a::NonLinear, b::Real) = a
Base.div(a::NonLinear, b::NonLinear) = a
Base.:(/)(a::Real, b::NonLinear) = iszero(a) ? a : b
Base.:(/)(a::NonLinear, b::Real) = a
Base.:(/)(a::NonLinear, b::NonLinear) = a
Base.rem(a::Real, b::NonLinear) = b
Base.rem(a::NonLinear, b::Real) = a
Base.rem(a::NonLinear, b::NonLinear) = a
Base.abs(a::NonLinear) = nonlinear
Base.isone(a::NonLinear) = false

const nonlinear = NonLinear.instance
Base.Broadcast.broadcastable(::NonLinear) = Ref(nonlinear)
Base.:(*)(a::Real, b::Linearity) = iszero(a) ? a : b
Base.:(*)(a::Linearity, b::Real) = iszero(b) ? b : a
Base.div(a::Real, b::Linearity) = iszero(a) ? a : nonlinear
Base.div(a::Linearity, b::Real) = a
Base.div(a::Linearity, b::Linearity) = nonlinear
Base.:(/)(a::Real, b::Linearity) = iszero(a) ? a : nonlinear
Base.:(/)(a::Linearity, b::Real) = a
Base.:(/)(a::Linearity, b::Linearity) = nonlinear
Base.rem(a::Real, b::Linearity) = b
Base.rem(a::Linearity, b::Real) = a
Base.rem(a::Linearity, b::Linearity) = a
Base.abs(a::Linearity) = nonlinear
Base.isone(a::Linearity) = false

Base.Broadcast.broadcastable(x::Linearity) = Ref(x)

############################## Incidence #######################################
# TODO: Just use Infinities.jl here?
const MAX_EQS = div(typemax(Int), 2)

# For now, we only track exact, integer linearities, because that's what
# MTK can handle, so `nonlinear` includes linear operations with floating
# point values.
const IncidenceVector = SparseVector{Union{Float64, NonLinear}, Int}
const IncidenceValue = Union{Float64, Linearity}
const IncidenceVector = SparseVector{IncidenceValue, Int}

is_non_incidence_type(@nospecialize(type)) = type === Union{} || Base.issingletontype(type)

Expand All @@ -113,6 +139,24 @@ struct Incidence
if is_non_incidence_type(type)
throw(DomainError(type, "Invalid type for Incidence"))
end
row = convert(IncidenceVector, row)
time = row[1]
if in(time, (linear_time_dependent, linear_time_and_state_dependent))
throw(ArgumentError("Time incidence cannot be both linear and time-dependent, otherwise it would be nonlinear"))
end
for (i, coeff) in zip(rowvals(row), nonzeros(row))
isa(coeff, Linearity) || continue
coeff.nonlinear && continue
if coeff.time_dependent && !in(1, rowvals(row))
throw(ArgumentError("Time-dependent incidence annotation for $(subscript_state(i)) is inconsistent with an absence of time incidence"))
end
if coeff.state_dependent && !any(x -> x != 1, rowvals(row))
throw(ArgumentError("State-dependent incidence annotation for $(subscript_state(i)) is inconsistent with an absence of state incidence"))
end
if i > 1 && coeff.time_dependent && (isa(time, Float64) || !time.state_dependent)
throw(ArgumentError("Time-dependent state incidence for $(subscript_state(i)) is inconsistent with an absence of state dependence for time"))
end
end
return new(type, row)
end
end
Expand Down Expand Up @@ -148,33 +192,52 @@ function Base.show(io::IO, inc::Incidence)
else
print(io, minus ? " - " : " + ")
end
time = inc.row[1]
time_linear = time !== nonlinear
is_grouped(v, i) = isa(v, Linearity) && (v.state_dependent || (v.time_dependent || i == 1) && in(time, (linear_state_dependent, nonlinear)))
for (i, v) in zip(rowvals(inc.row), nonzeros(inc.row))
v !== nonlinear || continue
print_plusminus(io, v < 0)
if abs(v) != 1 || i == 1
print(io, abs(v))
is_grouped(v, i) && continue
if isa(v, Float64)
print_plusminus(io, v < 0)
if abs(v) != 1
print(io, abs(v))
end
else
!first && print(io, " + ")
first = false
if !is_grouped(inc.row[1], 1)
ᵢ = i > 1 ? subscript(i - 1) : 'ₜ'
if v.time_dependent
print(io, time_linear ? "∝t" : "f$ᵢ(t)", " * ")
else # unknown constant coefficient
print(io, "c$ᵢ * ")
end
end
end
print(io, subscript_state(i))
end
first_nonlinear = true
for (i, v) in zip(rowvals(inc.row), nonzeros(inc.row))
v === nonlinear || continue
if first_nonlinear
print_plusminus(io)
first_nonlinear = false
print(io, "f(")
else
print(io, ", ")
first_grouped = true
if any(i -> is_grouped(inc.row[i], i), rowvals(inc.row))
for (i, v) in zip(rowvals(inc.row), nonzeros(inc.row))
!is_grouped(v, i) && continue
if first_grouped
print_plusminus(io)
first_grouped = false
print(io, "f(")
else
print(io, ", ")
end
!v.nonlinear && print(io, '∝')
print(io, subscript_state(i))
end
if !first_grouped
print(io, ")")
end
print(io, i == 1 ? "t" : subscript_state(i))
end
if !first_nonlinear
print(io, ")")
end
print(io, ")")
end

_zero_row() = IncidenceVector(MAX_EQS, Int[], Union{Float64, NonLinear}[])
_zero_row() = IncidenceVector(MAX_EQS, Int[], IncidenceValue[])
const _ZERO_ROW = _zero_row()
const _ZERO_CONST = Const(0.0)
Base.zero(::Type{Incidence}) = Incidence(_ZERO_CONST, _zero_row())
Expand Down Expand Up @@ -590,17 +653,14 @@ function _aggressive_incidence_join(@nospecialize(rt), argtypes::Vector{Any})
for (i, v) in zip(rowvals(a.row), nonzeros(a.row))
# as long as they are equal then it is correct for both so nothing to do
if inci.row[i] != v
# Otherwise it can't be either but must allow both. We would ideally represent this as
# `LinearUnion{rr[i], v}` or `Linear`, but we don't have lattice elements like that
# `NonLinear` is our more general representation
inci.row[i] = nonlinear
# Otherwise it can't be either but must allow both.
inci.row[i] = join_linearity(inci.row[i], v)
end
end
# and the the other way: catch places that `rr` is nonzero but `aa` is zero
for (i, v) in zip(rowvals(inci.row), nonzeros(inci.row))
if a.row[i] != v
# mix of nonlinear and linear, or again: a mix of two different linear coefficients
inci.row[i] = nonlinear
inci.row[i] = join_linearity(a.row[i], v)
end
end
end
Expand Down
Loading
Loading