Skip to content

Commit 04c7737

Browse files
committed
Refactored collect_constants.
1 parent 2e61eb5 commit 04c7737

File tree

1 file changed

+28
-38
lines changed

1 file changed

+28
-38
lines changed

src/utils.jl

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -504,56 +504,46 @@ function collect_var!(states, parameters, var, iv)
504504
return nothing
505505
end
506506

507-
function collect_constants(eqs::Vector{Equation}) #For get_substitutions_and_solved_states
508-
constants = []
509-
for eq in eqs
510-
collect_constants!(constants, eq.lhs)
511-
collect_constants!(constants, eq.rhs)
512-
end
513-
return constants
514-
end
515-
516-
function collect_constants(eqs::AbstractArray{T}) where {T <: Union{Num, Symbolic}} # For generate_tgrad / generate_jacobian / generate_difference_cb
517-
constants = T[]
518-
for eq in eqs
519-
collect_constants!(constants, unwrap(eq))
520-
end
507+
function collect_constants(x)
508+
constants = Symbolics.Sym[]
509+
collect_constants!(constants, x)
521510
return constants
522511
end
523512

524-
function collect_constants(eqs::Vector{Matrix{Num}}) # For nonlinear hessian
525-
constants = Num[]
526-
for m in eqs
527-
for n in m
528-
collect_constants!(constants, unwrap(n))
529-
end
513+
function collect_constants!(constants, arr::AbstractArray{T}) where {T}
514+
for el in arr
515+
collect_constants!(constants, el)
530516
end
531-
return constants
532517
end
533518

534-
collect_constants(x::Num) = collect_constants(unwrap(x))
535-
function collect_constants(expr::Symbolic{T}) where {T} # For jump system affect / rate
536-
constants = Symbolic[]
537-
collect_constants!(constants, expr)
538-
return constants
519+
function collect_constants!(constants, eq::Equation)
520+
collect_constants!(constants, eq.lhs)
521+
collect_constants!(constants, eq.rhs)
539522
end
540523

541-
function collect_constant!(constants, var)
542-
if isconstant(var)
543-
push!(constants, var)
544-
end
545-
return nothing
546-
end
524+
collect_constants!(constants, x::Num) = collect_constants!(constants, unwrap(x))
525+
collect_constants!(constants, x::Real) = nothing
526+
collect_constants(n::Nothing) = Symbolics.Sym[]
547527

548-
function collect_constants!(constants, expr)
549-
if expr isa Sym
550-
collect_constant!(constants, expr)
528+
function collect_constants!(constants, expr::Symbolics.Symbolic{T}) where {T}
529+
if expr isa Sym && isconstant(expr)
530+
push!(constants, expr)
551531
else
552-
for var in vars(expr)
553-
collect_constant!(constants, var)
532+
evars = vars(expr)
533+
if length(evars) == 1 && isequal(only(evars), expr)
534+
return nothing #avoid infinite recursion for vars(x(t)) == [x(t)]
535+
else
536+
for var in evars
537+
collect_constants!(constants, var)
538+
end
554539
end
555540
end
556-
return nothing
541+
end
542+
543+
""" Replace symbolic constants with their literal values """
544+
function eliminate_constants(eqs::AbstractArray{<:Union{Equation, Symbolic}}, cs::Vector{Sym})
545+
cmap = Dict(x => getdefault(x) for x in cs)
546+
return substitute(eqs, cmap)
557547
end
558548

559549
function get_preprocess_constants(eqs)

0 commit comments

Comments
 (0)