Skip to content

Commit 348b7cf

Browse files
authored
Merge pull request #1341 from SciML/myb/scalarize
collect -> scalarize
2 parents 5446af9 + d6fdb73 commit 348b7cf

File tree

7 files changed

+18
-15
lines changed

7 files changed

+18
-15
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ jobs:
1717
julia-version: [1]
1818
os: [ubuntu-latest]
1919
package:
20+
- {user: SciML, repo: SciMLBase.jl, group: Downstream}
2021
- {user: SciML, repo: Catalyst.jl, group: All}
2122
- {user: SciML, repo: CellMLToolkit.jl, group: All}
2223
- {user: SciML, repo: SBMLToolkit.jl, group: All}
2324
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
2425
- {user: SciML, repo: DataDrivenDiffEq.jl, group: Standard}
2526
- {user: SciML, repo: StructuralIdentifiability.jl, group: All}
26-
2727
steps:
2828
- uses: actions/checkout@v2
2929
- uses: julia-actions/setup-julia@v1

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys), p
173173

174174
rhss = map(x->x.rhs, eqs)
175175
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
176-
176+
177177
u = map(x->time_varying_as_func(value(x), sys), dvs)
178178
p = map(x->time_varying_as_func(value(x), sys), ps)
179179
t = get_iv(sys)
180180
rf_oop, rf_ip = build_function(rhss, u, p, t; expression=Val{false}, kwargs...)
181-
181+
182182
affect_functions = map(cbs) do cb # Keep affect function separate
183183
eq_aff = affect_equations(cb)
184184
affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...)
@@ -233,7 +233,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
233233
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
234234
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
235235
vars = states(sys)
236-
236+
237237
u = map(x->time_varying_as_func(value(x), sys), vars)
238238
p = map(x->time_varying_as_func(value(x), sys), ps)
239239
t = get_iv(sys)

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function ODESystem(
120120
checks = true,
121121
)
122122
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
123-
deqs = collect(deqs)
123+
deqs = scalarize(deqs)
124124
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
125125

126126
iv′ = value(scalarize(iv))
@@ -152,7 +152,7 @@ function ODESystem(
152152
end
153153

154154
function ODESystem(eqs, iv=nothing; kwargs...)
155-
eqs = collect(eqs)
155+
eqs = scalarize(eqs)
156156
# NOTE: this assumes that the order of algebric equations doesn't matter
157157
diffvars = OrderedSet()
158158
allstates = OrderedSet()
@@ -186,7 +186,7 @@ function ODESystem(eqs, iv=nothing; kwargs...)
186186
end
187187
algevars = setdiff(allstates, diffvars)
188188
# the orders here are very important!
189-
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
189+
return ODESystem(append!(diffeq, algeeq), iv, collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
190190
end
191191

192192
# NOTE: equality does not check cached Jacobian
@@ -248,7 +248,7 @@ function build_explicit_observed_function(
248248
vars = Set()
249249
foreach(Base.Fix1(vars!, vars), ts)
250250
ivs = independent_variables(sys)
251-
dep_vars = collect(setdiff(vars, ivs))
251+
dep_vars = scalarize(setdiff(vars, ivs))
252252

253253
obs = observed(sys)
254254
sts = Set(states(sys))

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
109109
checks = true,
110110
)
111111
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
112-
deqs = collect(deqs)
112+
deqs = scalarize(deqs)
113113
iv′ = value(iv)
114114
dvs′ = value.(dvs)
115115
ps′ = value.(ps)

src/systems/discrete_system/discrete_system.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function DiscreteSystem(
8383
kwargs...,
8484
)
8585
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
86-
eqs = collect(eqs)
86+
eqs = scalarize(eqs)
8787
iv′ = value(iv)
8888
dvs′ = value.(dvs)
8989
ps′ = value.(ps)
@@ -108,7 +108,7 @@ end
108108

109109

110110
function DiscreteSystem(eqs, iv=nothing; kwargs...)
111-
eqs = collect(eqs)
111+
eqs = scalarize(eqs)
112112
# NOTE: this assumes that the order of algebric equations doesn't matter
113113
diffvars = OrderedSet()
114114
allstates = OrderedSet()
@@ -142,7 +142,7 @@ function DiscreteSystem(eqs, iv=nothing; kwargs...)
142142
end
143143
algevars = setdiff(allstates, diffvars)
144144
# the orders here are very important!
145-
return DiscreteSystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
145+
return DiscreteSystem(append!(diffeq, algeeq), iv, collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
146146
end
147147

148148
"""

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ function JumpSystem(eqs, iv, states, ps;
7474
checks = true,
7575
kwargs...)
7676
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
77-
eqs = collect(eqs)
77+
eqs = scalarize(eqs)
7878
sysnames = nameof.(systems)
7979
if length(unique(sysnames)) != length(sysnames)
8080
throw(ArgumentError("System names must be unique."))

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ function NonlinearSystem(eqs, states, ps;
7777
throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events"))
7878
name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
7979
# Move things over, but do not touch array expressions
80-
eqs = [0 ~ x.rhs - x.lhs for x in collect(eqs)]
80+
#
81+
# # we cannot scalarize in the loop because `eqs` itself might require
82+
# scalarization
83+
eqs = [0 ~ x.rhs - x.lhs for x in scalarize(eqs)]
8184

8285
if !(isempty(default_u0) && isempty(default_p))
8386
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :NonlinearSystem, force=true)
@@ -90,7 +93,7 @@ function NonlinearSystem(eqs, states, ps;
9093
defaults = todict(defaults)
9194
defaults = Dict{Any,Any}(value(k) => value(v) for (k, v) in pairs(defaults))
9295

93-
states = collect(states)
96+
states = scalarize(states)
9497
states, ps = value.(states), value.(ps)
9598
var_to_name = Dict()
9699
process_variables!(var_to_name, defaults, states)

0 commit comments

Comments
 (0)