Skip to content

Commit a556e80

Browse files
Merge pull request #2983 from AayushSabharwal/as/use-union-false
refactor: default to `use_union = false` for `ODEProblem`s
2 parents dd7d557 + 8a0415b commit a556e80

File tree

4 files changed

+9
-10
lines changed

4 files changed

+9
-10
lines changed

ext/MTKChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import ChainRulesCore: NoTangent
66

77
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
88
function mtp_pullback(dt)
9-
(NoTangent(), dt.tunable[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...)
9+
(NoTangent(), dt.tunable[1:length(tunables)],
10+
ntuple(_ -> NoTangent(), length(args))...)
1011
end
1112
MTK.MTKParameters(tunables, args...), mtp_pullback
1213
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
772772
linenumbers = true, parallel = SerialForm(),
773773
eval_expression = false,
774774
eval_module = @__MODULE__,
775-
use_union = true,
775+
use_union = false,
776776
tofloat = true,
777777
symbolic_u0 = false,
778778
u0_constructor = identity,

src/systems/jumps/jumpsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ oprob = ODEProblem(complete(js), u₀map, tspan, parammap)
446446
"""
447447
function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
448448
parammap = DiffEqBase.NullParameters();
449-
use_union = true,
449+
use_union = false,
450450
eval_expression = false,
451451
eval_module = @__MODULE__,
452452
kwargs...)

test/extensions/ad.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@ using SciMLSensitivity
88

99
@variables x(t)[1:3] y(t)
1010
@parameters p[1:3, 1:3] q
11-
eqs = [
12-
D(x) ~ p * x
13-
D(y) ~ sum(p) + q * y
14-
]
11+
eqs = [D(x) ~ p * x
12+
D(y) ~ sum(p) + q * y]
1513
u0 = [x => zeros(3),
16-
y => 1.]
14+
y => 1.0]
1715
ps = [p => zeros(3, 3),
18-
q => 1.]
19-
tspan = (0., 10.)
16+
q => 1.0]
17+
tspan = (0.0, 10.0)
2018
@mtkbuild sys = ODESystem(eqs, t)
2119
prob = ODEProblem(sys, u0, tspan, ps)
2220
sol = solve(prob, Tsit5())

0 commit comments

Comments
 (0)