Skip to content

Commit c311ce1

Browse files
Simplify type handling in non-concrete varmap
It's hard to come up with a useful example, but: ```julia using OrdinaryDiffEq, Catalyst, DiffEqFlux, ModelingToolkit NN(S, I, R, p) = FastChain(FastDense(3,10,tanh),FastDense(10,1))([S, I, R], p)[1] @register NN(S, I, R, p) rn = @reaction_network begin β, S + I --> 2I γ, I --> R NN(S, I, R, p3n), I --> Q δ, Q --> R end β γ δ p3n _p3n = Float64.(initial_params(FastChain(FastDense(3,10,tanh),FastDense(10,1)))) @parameters β γ δ p3n p = [β=>1.0,γ=>1.0,δ=>1.0,p3n=>_p3n] # [α,β] tspan = (0.0,250.0) u0 = [999.0,0.0,1.0,0.0] # [S,I,R] at t=0 op = ODEProblem(rn, u0, tspan, p) sol = solve(op,Tsit5()) ``` is such a thing. We can try to handle this better in the future but at least for now it gives `Any` and keeps `Any`.
1 parent 7876118 commit c311ce1

File tree

3 files changed

+6
-14
lines changed

3 files changed

+6
-14
lines changed

src/systems/reaction/reactionsystem.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -487,24 +487,18 @@ end
487487

488488
# ODEProblem from AbstractReactionNetwork
489489
function DiffEqBase.ODEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p=DiffEqBase.NullParameters(), args...; kwargs...)
490-
u0 = typeof(u0) <: Array{<:Pair} ? u0 : Pair.(rs.states,u0)
491-
p = typeof(p) <: Union{Array{<:Pair},DiffEqBase.NullParameters} ? p : Pair.(rs.ps,p)
492490
return ODEProblem(convert(ODESystem,rs),u0,tspan,p, args...; kwargs...)
493491
end
494492

495493
# SDEProblem from AbstractReactionNetwork
496494
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan, p=DiffEqBase.NullParameters(), args...; noise_scaling=nothing, kwargs...)
497495
sde_sys = convert(SDESystem,rs,noise_scaling=noise_scaling)
498-
u0 = typeof(u0) <: Array{<:Pair} ? u0 : Pair.(rs.states,u0)
499-
p = typeof(p) <: Union{Array{<:Pair},DiffEqBase.NullParameters} ? p : Pair.(sde_sys.ps,p)
500496
p_matrix = zeros(length(rs.states), length(rs.eqs))
501497
return SDEProblem(sde_sys,u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
502498
end
503499

504500
# DiscreteProblem from AbstractReactionNetwork
505501
function DiffEqBase.DiscreteProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, tspan::Tuple, p=DiffEqBase.NullParameters(), args...; kwargs...)
506-
u0 = typeof(u0) <: Array{<:Pair} ? u0 : Pair.(rs.states,u0)
507-
p = typeof(p) <: Union{Array{<:Pair},DiffEqBase.NullParameters} ? p : Pair.(rs.ps,p)
508502
return DiscreteProblem(convert(JumpSystem,rs), u0,tspan,p, args...; kwargs...)
509503
end
510504

@@ -515,8 +509,6 @@ end
515509

516510
# SteadyStateProblem from AbstractReactionNetwork
517511
function DiffEqBase.SteadyStateProblem(rs::ReactionSystem, u0::Union{AbstractArray, Number}, p=DiffEqBase.NullParameters(), args...; kwargs...)
518-
#u0 = typeof(u0) <: Array{<:Pair} ? u0 : Pair.(rs.states,u0)
519-
#p = typeof(p) <: Union{Array{<:Pair},DiffEqBase.NullParameters} ? p : Pair.(rs.ps,p)
520512
return SteadyStateProblem(ODEFunction(convert(ODESystem,rs)),u0,p, args...; kwargs...)
521513
end
522514

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,11 @@ function lower_varname(t::Term, iv)
271271
end
272272
lower_varname(t::Sym, iv) = t
273273

274-
function lower_mapnames(umap::AbstractArray{<:Pair}) where T
275-
[value(k) => value(v) for (k, v) in umap]
274+
function lower_mapnames(umap::AbstractArray{T}) where {T<:Pair}
275+
T[value(k) => value(v) for (k, v) in umap]
276276
end
277-
function lower_mapnames(umap::AbstractArray{<:Pair},name) where T
278-
[lower_varname(value(k), name) => value(v) for (k, v) in umap]
277+
function lower_mapnames(umap::AbstractArray{T},name) where {T<:Pair}
278+
T[lower_varname(value(k), name) => value(v) for (k, v) in umap]
279279
end
280280
lower_mapnames(umap::AbstractArray{<:Number}) = umap # Ambiguity
281281
lower_mapnames(umap::AbstractArray{<:Number},name) = umap

src/variables.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ varmap_to_vars(varmap,varlist)
211211
Takes a list of pairs of variables=>values and an ordered list of variables and
212212
creates the array of values in the correct order
213213
"""
214-
function varmap_to_vars(varmap::AbstractArray{<:Pair},varlist)
215-
out = similar(varmap,typeof(last(first(varmap))))
214+
function varmap_to_vars(varmap::AbstractArray{Pair{T,S}},varlist) where {T,S}
215+
out = similar(varmap,S)
216216
for (ivar, ival) in varmap
217217
j = findfirst(isequal(ivar),varlist)
218218
if isnothing(j)

0 commit comments

Comments
 (0)