Skip to content

Commit 997b428

Browse files
authored
Merge pull request #2266 from SciML/myb/fix
Fix analysis point and handle jacobian and tgrad correctly when parameters are split
2 parents 1344c1d + 97e558f commit 997b428

File tree

2 files changed

+46
-13
lines changed

2 files changed

+46
-13
lines changed

src/systems/connectors.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,14 +299,10 @@ function generate_connection_set!(connectionsets, domain_csets,
299299
else
300300
if lhs isa Number || lhs isa Symbolic
301301
push!(eqs, eq) # split connections and equations
302-
elseif lhs isa Connection
303-
if get_systems(lhs) === :domain
304-
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
305-
else
306-
push!(cts, get_systems(rhs))
307-
end
302+
elseif lhs isa Connection && get_systems(lhs) === :domain
303+
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
308304
else
309-
error("$eq is not a legal equation!")
305+
push!(cts, get_systems(rhs))
310306
end
311307
end
312308
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,37 @@ function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = paramete
8484
simplify = false, kwargs...)
8585
tgrad = calculate_tgrad(sys, simplify = simplify)
8686
pre = get_preprocess_constants(tgrad)
87-
return build_function(tgrad, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
87+
if ps isa Tuple
88+
return build_function(tgrad,
89+
dvs,
90+
ps...,
91+
get_iv(sys);
92+
postprocess_fbody = pre,
93+
kwargs...)
94+
else
95+
return build_function(tgrad,
96+
dvs,
97+
ps,
98+
get_iv(sys);
99+
postprocess_fbody = pre,
100+
kwargs...)
101+
end
88102
end
89103

90104
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
91105
simplify = false, sparse = false, kwargs...)
92106
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
93107
pre = get_preprocess_constants(jac)
94-
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
108+
if ps isa Tuple
109+
return build_function(jac,
110+
dvs,
111+
ps...,
112+
get_iv(sys);
113+
postprocess_fbody = pre,
114+
kwargs...)
115+
else
116+
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
117+
end
95118
end
96119

97120
function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys),
@@ -364,8 +387,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
364387
tgrad_oop, tgrad_iip = eval_expression ?
365388
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
366389
tgrad_gen
367-
_tgrad(u, p, t) = tgrad_oop(u, p, t)
368-
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
390+
if p isa Tuple
391+
__tgrad(u, p, t) = tgrad_oop(u, p..., t)
392+
__tgrad(J, u, p, t) = tgrad_iip(J, u, p..., t)
393+
_tgrad = __tgrad
394+
else
395+
___tgrad(u, p, t) = tgrad_oop(u, p, t)
396+
___tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
397+
_tgrad = ___tgrad
398+
end
369399
else
370400
_tgrad = nothing
371401
end
@@ -379,8 +409,15 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
379409
jac_oop, jac_iip = eval_expression ?
380410
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
381411
jac_gen
382-
_jac(u, p, t) = jac_oop(u, p, t)
383-
_jac(J, u, p, t) = jac_iip(J, u, p, t)
412+
if p isa Tuple
413+
__jac(u, p, t) = jac_oop(u, p..., t)
414+
__jac(J, u, p, t) = jac_iip(J, u, p..., t)
415+
_jac = __jac
416+
else
417+
___jac(u, p, t) = jac_oop(u, p, t)
418+
___jac(J, u, p, t) = jac_iip(J, u, p, t)
419+
_jac = ___jac
420+
end
384421
else
385422
_jac = nothing
386423
end

0 commit comments

Comments
 (0)