Skip to content

Commit 31365ca

Browse files
authored
Merge pull request #989 from SciML/s/array-fixes
Fixes for symbolic arrays change
2 parents ca65334 + fe58c0a commit 31365ca

18 files changed

+94
-77
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ SciMLBase = "1.3"
6969
Setfield = "0.7"
7070
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
7171
StaticArrays = "0.10, 0.11, 0.12, 1.0"
72-
SymbolicUtils = "0.11.0"
73-
Symbolics = "0.1.21"
72+
SymbolicUtils = "0.12"
73+
Symbolics = "1"
7474
UnPack = "0.1, 1.0"
7575
Unitful = "1.1"
7676
julia = "1.2"
@@ -83,10 +83,11 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
8383
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
8484
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8585
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
86+
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
8687
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
8788
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
8889
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
8990
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9091

9192
[targets]
92-
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
93+
test = ["BenchmarkTools", "ForwardDiff", "GalacticOptim", "NonlinearSolve", "OrdinaryDiffEq", "Optim", "Random", "ReferenceTests", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

src/ModelingToolkit.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import JuliaFormatter
4040
using Reexport
4141
@reexport using Symbolics
4242
export @derivatives
43-
using Symbolics: _parse_vars, value, makesym, @derivatives, get_variables,
43+
using Symbolics: _parse_vars, value, @derivatives, get_variables,
4444
exprs_occur_in, solve_for, build_expr
4545
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
4646
jacobian_sparsity, islinear, _iszero, _isone,
@@ -49,7 +49,7 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
4949
ParallelForm, SerialForm, MultithreadedForm, build_function,
5050
unflatten_long_ops, rhss, lhss, prettify_expr, gradient,
5151
jacobian, hessian, derivative, sparsejacobian, sparsehessian,
52-
substituter
52+
substituter, scalarize
5353

5454
import DiffEqBase: @add_kwonly
5555

src/parameters.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,17 @@ isparameter(x) = false
1010
1111
Maps the variable to a paramter.
1212
"""
13-
toparam(s::Symbolic) = setmetadata(s, MTKParameterCtx, true)
13+
function toparam(s)
14+
if s isa Symbolics.Arr
15+
Symbolics.wrap(toparam(Symbolics.unwrap(s)))
16+
elseif s isa AbstractArray
17+
map(toparam, s)
18+
elseif symtype(s) <: AbstractArray
19+
Symbolics.recurse_and_apply(toparam, s)
20+
else
21+
setmetadata(s, MTKParameterCtx, true)
22+
end
23+
end
1424
toparam(s::Num) = Num(toparam(value(s)))
1525

1626
"""
@@ -30,6 +40,6 @@ macro parameters(xs...)
3040
Symbolics._parse_vars(:parameters,
3141
Real,
3242
xs,
33-
x -> x isa Array ? toparam.(x) : toparam(x)
43+
toparam,
3444
) |> esc
3545
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ function generate_function(
7979
# substitute x(t) by just x
8080
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
8181
[eq.rhs for eq in eqs]
82-
#obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq ∈ observed(sys)]
8382
#rhss = Let(obss, rhss)
8483

8584
# TODO: add an optional check on the ordering of observed equations

src/systems/diffeqs/odesystem.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ function ODESystem(
9191
defaults=_merge(Dict(default_u0), Dict(default_p)),
9292
connection_type=nothing,
9393
)
94-
iv′ = value(iv)
95-
dvs′ = value.(dvs)
96-
ps′ = value.(ps)
94+
iv′ = value(scalarize(iv))
95+
dvs′ = value.(scalarize(dvs))
96+
ps′ = value.(scalarize(ps))
9797

9898
if !(isempty(default_u0) && isempty(default_p))
9999
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ODESystem, force=true)
@@ -117,11 +117,19 @@ vars(exprs::Symbolic) = vars([exprs])
117117
vars(exprs) = foldl(vars!, exprs; init = Set())
118118
vars!(vars, eq::Equation) = (vars!(vars, eq.lhs); vars!(vars, eq.rhs); vars)
119119
function vars!(vars, O)
120-
isa(O, Sym) && return push!(vars, O)
120+
if isa(O, Sym)
121+
return push!(vars, O)
122+
end
121123
!istree(O) && return vars
122124

123125
operation(O) isa Differential && return push!(vars, O)
124126

127+
if operation(O) === (getindex) &&
128+
first(arguments(O)) isa Symbolic
129+
130+
return push!(vars, O)
131+
end
132+
125133
operation(O) isa Sym && push!(vars, O)
126134
for arg in arguments(O)
127135
vars!(vars, arg)

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys), ps = par
175175
u0 = nothing;
176176
version = nothing, tgrad=false, sparse = false,
177177
jac = false, Wfact = false, eval_expression = true, kwargs...) where {iip}
178+
dvs = scalarize.(dvs)
179+
ps = scalarize.(ps)
180+
178181
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
179182
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
180183
g_gen = generate_diffusion_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,11 @@ end
103103
function generate_function(sys::NonlinearSystem, dvs = states(sys), ps = parameters(sys); kwargs...)
104104
#obsvars = map(eq->eq.lhs, observed(sys))
105105
#fulldvs = [dvs; obsvars]
106-
fulldvs = dvs
107-
fulldvs′ = makesym.(value.(fulldvs))
108106

109-
sub = Dict(fulldvs .=> fulldvs′)
110-
# substitute x(t) by just x
111-
rhss = [substitute(deq.rhs, sub) for deq equations(sys)]
112-
#obss = [makesym(value(eq.lhs)) ~ substitute(eq.rhs, sub) for eq ∈ observed(sys)]
107+
rhss = [deq.rhs for deq equations(sys)]
113108
#rhss = Let(obss, rhss)
114109

115-
dvs′ = fulldvs′[1:length(dvs)]
116-
ps′ = makesym.(value.(ps), states=())
117-
return build_function(rhss, dvs′, ps′;
110+
return build_function(rhss, value.(dvs), value.(ps);
118111
conv = AbstractSysToExpr(sys), kwargs...)
119112
end
120113

src/systems/reaction/reactionsystem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,13 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
467467
noise_scaling=nothing, name=nameof(rs), combinatoric_ratelaws=true,
468468
include_zero_odes=true, kwargs...)
469469

470-
if noise_scaling isa Vector
470+
if noise_scaling isa AbstractArray
471471
(length(noise_scaling)!=length(equations(rs))) &&
472472
error("The number of elements in 'noise_scaling' must be equal " *
473473
"to the number of reactions in the reaction system.")
474-
noise_scaling = value.(noise_scaling)
474+
if !(noise_scaling isa Symbolics.Arr)
475+
noise_scaling = value.(noise_scaling)
476+
end
475477
elseif !isnothing(noise_scaling)
476478
noise_scaling = fill(value(noise_scaling),length(equations(rs)))
477479
end
@@ -482,7 +484,7 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
482484
combinatoric_ratelaws=combinatoric_ratelaws)
483485
systems = convert.(SDESystem, get_systems(rs))
484486
SDESystem(eqs, noiseeqs, get_iv(rs), get_states(rs),
485-
(noise_scaling===nothing) ? get_ps(rs) : union(get_ps(rs), toparam.(noise_scaling));
487+
(noise_scaling===nothing) ? get_ps(rs) : union(get_ps(rs), toparam(noise_scaling));
486488
name=name,
487489
systems=systems,
488490
kwargs...)

src/systems/systemstructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ function initialize_system_structure(sys)
180180
for algvar in algvars
181181
# it could be that a variable appeared in the states, but never appeared
182182
# in the equations.
183-
algvaridx = get(var2idx, algvar, 0)
183+
algvaridx = var2idx[algvar]
184184
vartype[algvaridx] = ALGEBRAIC_VARIABLE
185185
end
186186

test/bigsystem.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using ModelingToolkit, LinearAlgebra, SparseArrays
2+
using Symbolics
3+
using Symbolics: scalarize
24

35
# Define the constants for the PDE
46
const α₂ = 1.0
@@ -27,6 +29,9 @@ My[end,end-1] = 2.0
2729
# Define the initial condition as normal arrays
2830
@variables du[1:N,1:N,1:3] u[1:N,1:N,1:3] MyA[1:N,1:N] AMx[1:N,1:N] DA[1:N,1:N]
2931

32+
du,u,MyA,AMx,DA = scalarize.((du,u,MyA,AMx,DA))
33+
@show typeof.((du,u,MyA,AMx,DA))
34+
3035
# Define the discretized PDE as an ODE function
3136
function f(du,u,p,t)
3237
A = @view u[:,:,1]

0 commit comments

Comments
 (0)