Skip to content

Commit 0673148

Browse files
Merge pull request #1784 from AayushSabharwal/paramsyms
Use `paramsyms` from new `DiffEqFunction`s
2 parents aac6d80 + 2e1422b commit 0673148

File tree

7 files changed

+37
-12
lines changed

7 files changed

+37
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ NonlinearSolve = "0.3.8"
7070
RecursiveArrayTools = "2.3"
7171
Reexport = "0.2, 1"
7272
RuntimeGeneratedFunctions = "0.4.3, 0.5"
73-
SciMLBase = "1.56.1"
73+
SciMLBase = "1.58.0"
7474
Setfield = "0.7, 0.8, 1"
7575
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
7676
StaticArrays = "0.10, 0.11, 0.12, 1.0"

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
364364
jac_prototype = jac_prototype,
365365
syms = Symbol.(states(sys)),
366366
indepsym = Symbol(get_iv(sys)),
367+
paramsyms = Symbol.(ps),
367368
observed = observedfun,
368369
sparsity = sparsity ? jacobian_sparsity(sys) : nothing)
369370
end
@@ -449,9 +450,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
449450
sys = sys,
450451
jac = _jac === nothing ? nothing : _jac,
451452
syms = Symbol.(dvs),
453+
indepsym = Symbol(get_iv(sys)),
454+
paramsyms = Symbol.(ps),
452455
jac_prototype = jac_prototype,
453-
# missing fields in `DAEFunction`
454-
#indepsym = Symbol(get_iv(sys)),
455456
observed = observedfun)
456457
end
457458

@@ -534,7 +535,8 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
534535
mass_matrix = M,
535536
jac_prototype = $jp_expr,
536537
syms = $(Symbol.(states(sys))),
537-
indepsym = $(QuoteNode(Symbol(get_iv(sys)))))
538+
indepsym = $(QuoteNode(Symbol(get_iv(sys)))),
539+
paramsyms = $(QuoteNode(Symbol.(parameters(sys)))))
538540
end
539541
!linenumbers ? striplines(ex) : ex
540542
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,8 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
421421
Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t,
422422
mass_matrix = _M,
423423
syms = Symbol.(states(sys)),
424+
indepsym = Symbol(get_iv(sys)),
425+
paramsyms = Symbol.(ps),
424426
observed = observedfun)
425427
end
426428

@@ -505,7 +507,9 @@ function SDEFunctionExpr{iip}(sys::SDESystem, dvs = states(sys),
505507
Wfact = Wfact,
506508
Wfact_t = Wfact_t,
507509
mass_matrix = M,
508-
syms = $(Symbol.(states(sys))))
510+
syms = $(Symbol.(states(sys))),
511+
indepsym = $(Symbol(get_iv(sys))),
512+
paramsyms = $(Symbol.(parameters(sys))))
509513
end
510514
!linenumbers ? striplines(ex) : ex
511515
end

src/systems/discrete_system/discrete_system.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map, tspan,
206206
expression_module = eval_module)
207207
f_oop, _ = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)
208208
f(u, p, iv) = f_oop(u, p, iv)
209-
fd = DiscreteFunction(f; syms = Symbol.(dvs), sys = sys)
209+
fd = DiscreteFunction(f; syms = Symbol.(dvs), indepsym = Symbol(iv),
210+
paramsyms = Symbol.(ps), sys = sys)
210211
DiscreteProblem(fd, u0, tspan, p; kwargs...)
211212
end
212213

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
291291
end
292292
end
293293

294-
df = DiscreteFunction{true, true}(f; syms = Symbol.(states(sys)), sys = sys,
294+
df = DiscreteFunction{true, true}(f; syms = Symbol.(states(sys)),
295+
indepsym = Symbol(get_iv(sys)),
296+
paramsyms = Symbol.(ps), sys = sys,
295297
observed = observedfun)
296298
DiscreteProblem(df, u0, tspan, p; kwargs...)
297299
end
@@ -331,7 +333,9 @@ function DiscreteProblemExpr(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing
331333
u0 = $u0
332334
p = $p
333335
tspan = $tspan
334-
df = DiscreteFunction{true, true}(f, syms = $(Symbol.(states(sys))))
336+
df = DiscreteFunction{true, true}(f, syms = $(Symbol.(states(sys))),
337+
indepsym = $(Symbol(get_iv(sys))),
338+
paramsyms = $(Symbol.(parameters(sys))))
335339
DiscreteProblem(df, u0, tspan, p)
336340
end
337341
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,9 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys
240240
jac_prototype = sparse ?
241241
similar(calculate_jacobian(sys, sparse = sparse),
242242
Float64) : nothing,
243-
syms = Symbol.(states(sys)), observed = observedfun)
243+
syms = Symbol.(states(sys)),
244+
paramsyms = Symbol.(parameters(sys)),
245+
observed = observedfun)
244246
end
245247

246248
"""
@@ -285,7 +287,8 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = states(sys),
285287
NonlinearFunction{$iip}(f,
286288
jac = jac,
287289
jac_prototype = $jp_expr,
288-
syms = $(Symbol.(states(sys))))
290+
syms = $(Symbol.(states(sys))),
291+
paramsyms = $(Symbol.(parameters(sys))))
289292
end
290293
!linenumbers ? striplines(ex) : ex
291294
end
@@ -314,6 +317,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
314317

315318
f = constructor(sys, dvs, ps, u0; jac = jac, checkbounds = checkbounds,
316319
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
320+
syms = Symbol.(dvs), paramsyms = Symbol.(ps),
317321
sparse = sparse, eval_expression = eval_expression, kwargs...)
318322
return f, u0, p
319323
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,12 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
268268

269269
_f = DiffEqBase.OptimizationFunction{iip}(f,
270270
sys = sys,
271-
syms = Symbol.(states(sys)),
272271
SciMLBase.NoAD();
273272
grad = _grad,
274273
hess = _hess,
275274
hess_prototype = hess_prototype,
275+
syms = Symbol.(states(sys)),
276+
paramsyms = Symbol.(parameters(sys)),
276277
cons = cons,
277278
cons_j = cons_j,
278279
cons_h = cons_h,
@@ -283,10 +284,11 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
283284
else
284285
_f = DiffEqBase.OptimizationFunction{iip}(f,
285286
sys = sys,
286-
syms = Symbol.(states(sys)),
287287
SciMLBase.NoAD();
288288
grad = _grad,
289289
hess = _hess,
290+
syms = Symbol.(states(sys)),
291+
paramsyms = Symbol.(parameters(sys)),
290292
hess_prototype = hess_prototype,
291293
expr = obj_expr)
292294
end
@@ -399,9 +401,13 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
399401
cons = $cons
400402
cons_j = $cons_j
401403
cons_h = $cons_h
404+
syms = $(Symbol.(states(sys)))
405+
paramsyms = $(Symbol.(parameters(sys)))
402406
_f = OptimizationFunction{iip}(f, SciMLBase.NoAD();
403407
grad = grad,
404408
hess = hess,
409+
syms = syms,
410+
paramsyms = paramsyms,
405411
hess_prototype = hess_prototype,
406412
cons = cons,
407413
cons_j = cons_j,
@@ -421,9 +427,13 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
421427
hess = $_hess
422428
lb = $lb
423429
ub = $ub
430+
syms = $(Symbol.(states(sys)))
431+
paramsyms = $(Symbol.(parameters(sys)))
424432
_f = OptimizationFunction{iip}(f, SciMLBase.NoAD();
425433
grad = grad,
426434
hess = hess,
435+
syms = syms,
436+
paramsyms = paramsyms,
427437
hess_prototype = hess_prototype,
428438
expr = obj_expr)
429439
OptimizationProblem{$iip}(_f, u0, p; lb = lb, ub = ub, kwargs...)

0 commit comments

Comments
 (0)