Skip to content

Commit 8010eb8

Browse files
Merge pull request #514 from SciML/smalls
don't rebuild things if you're not doing anything fancy
2 parents 4e64595 + 53987dd commit 8010eb8

File tree

4 files changed

+27
-6
lines changed

4 files changed

+27
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.34.0"
4+
version = "6.34.1"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/init.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ cuify(x) = error("To use LinSolveGPUFactorize, you must do `using CuArrays`")
33
promote_u0(u0,p,t0) = u0
44
promote_tspan(u0,p,tspan,prob,kwargs) = tspan
55
get_tmp(x) = nothing
6+
isdistribution(u0) = false
67

78
if VERSION < v"1.4.0-DEV.635"
89
# Piracy, should get upstreamed
@@ -18,6 +19,7 @@ function __init__()
1819

1920
@require Distributions="31c24e10-a181-5473-b8eb-7969acd0382f" begin
2021
handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0)
22+
isdistribution(_u0::Distributions.Sampleable) = true
2123
end
2224

2325
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
@@ -54,7 +56,7 @@ function __init__()
5456

5557
@inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{<:Any,<:ForwardDiff.Dual}},::ForwardDiff.Dual) = sqrt(sum(UNITLESS_ABS2value,u) / length(u))
5658
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual{<:Any,ForwardDiff.Dual},::ForwardDiff.Dual) = abs(value(u))
57-
59+
5860
@inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{<:Any,<:ForwardDiff.Dual}},::ForwardDiff.Dual{<:Any,ForwardDiff.Dual}) = sqrt(sum(UNITLESS_ABS2,u) / length(u))
5961
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual{<:Any,ForwardDiff.Dual},::ForwardDiff.Dual{<:Any,ForwardDiff.Dual}) = abs(u)
6062

src/solve.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
4444
callbacks = NamedTuple{(:callback,)}( [DiffEqBase.CallbackSet(_prob.kwargs[:callback], values(kwargs).callback )] )
4545
kwargs = merge(kwargs_temp, callbacks)
4646
end
47-
kwargs = merge(values(_prob.kwargs), kwargs)
47+
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
4848
end
4949

5050
T = Core.Compiler.return_type(__solve,Tuple{typeof(_prob),map(typeof, args)...})
@@ -121,9 +121,14 @@ function discretize end
121121
function get_concrete_problem(prob, kwargs)
122122
tspan = get_concrete_tspan(prob, kwargs)
123123
u0 = get_concrete_u0(prob, tspan[1], kwargs)
124-
u0 = promote_u0(u0, prob.p, tspan[1])
125-
tspan = promote_tspan(u0, prob.p, tspan, prob, kwargs)
126-
remake(prob; u0 = u0, tspan = tspan)
124+
u0_promote = promote_u0(u0, prob.p, tspan[1])
125+
tspan_promote = promote_tspan(u0, prob.p, tspan, prob, kwargs)
126+
if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(u0) &&
127+
prob.tspan == tspan && typeof(tspan) === typeof(tspan_promote)
128+
return prob
129+
else
130+
return remake(prob; u0 = u0_promote, tspan = tspan_promote)
131+
end
127132
end
128133

129134
function get_concrete_problem(prob::DDEProblem, kwargs)
@@ -159,6 +164,10 @@ function get_concrete_tspan(prob, kwargs)
159164
tspan
160165
end
161166

167+
function isconcreteu0(prob, t0, kwargs)
168+
!eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0)
169+
end
170+
162171
function get_concrete_u0(prob, t0, kwargs)
163172
if eval_u0(prob.u0)
164173
u0 = prob.u0(prob.p, t0)

test/downstream/inference.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,13 @@ function solve_ode(f::F, p::P) where {F,P}
4444
return sol
4545
end
4646
@test_broken @inferred solve_ode(f, (a = 1, b = 1))
47+
48+
using StochasticDiffEq, Test
49+
u0=1/2
50+
ff(u,p,t) = u
51+
gg(u,p,t) = u
52+
dt = 1//2^(4)
53+
tspan = (0.0,1.0)
54+
prob = SDEProblem(ff,gg,u0,(0.0,1.0))
55+
sol = solve(prob,EM(),dt=dt)
56+
@inferred solve(prob,EM(),dt=dt)

0 commit comments

Comments
 (0)