Skip to content

Commit 02f3ae0

Browse files
don't rebuild things if you're not doing anything fancy
1 parent 4e64595 commit 02f3ae0

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

src/init.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ cuify(x) = error("To use LinSolveGPUFactorize, you must do `using CuArrays`")
33
promote_u0(u0,p,t0) = u0
44
promote_tspan(u0,p,tspan,prob,kwargs) = tspan
55
get_tmp(x) = nothing
6+
isdistribution(u0) = false
67

78
if VERSION < v"1.4.0-DEV.635"
89
# Piracy, should get upstreamed
@@ -18,6 +19,7 @@ function __init__()
1819

1920
@require Distributions="31c24e10-a181-5473-b8eb-7969acd0382f" begin
2021
handle_distribution_u0(_u0::Distributions.Sampleable) = rand(_u0)
22+
isdistribution(_u0::Distributions.Sampleable) = true
2123
end
2224

2325
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
@@ -54,7 +56,7 @@ function __init__()
5456

5557
@inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{<:Any,<:ForwardDiff.Dual}},::ForwardDiff.Dual) = sqrt(sum(UNITLESS_ABS2value,u) / length(u))
5658
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual{<:Any,ForwardDiff.Dual},::ForwardDiff.Dual) = abs(value(u))
57-
59+
5860
@inline ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{<:Any,<:ForwardDiff.Dual}},::ForwardDiff.Dual{<:Any,ForwardDiff.Dual}) = sqrt(sum(UNITLESS_ABS2,u) / length(u))
5961
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual{<:Any,ForwardDiff.Dual},::ForwardDiff.Dual{<:Any,ForwardDiff.Dual}) = abs(u)
6062

src/solve.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
4444
callbacks = NamedTuple{(:callback,)}( [DiffEqBase.CallbackSet(_prob.kwargs[:callback], values(kwargs).callback )] )
4545
kwargs = merge(kwargs_temp, callbacks)
4646
end
47-
kwargs = merge(values(_prob.kwargs), kwargs)
47+
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
4848
end
4949

5050
T = Core.Compiler.return_type(__solve,Tuple{typeof(_prob),map(typeof, args)...})
@@ -121,9 +121,13 @@ function discretize end
121121
function get_concrete_problem(prob, kwargs)
122122
tspan = get_concrete_tspan(prob, kwargs)
123123
u0 = get_concrete_u0(prob, tspan[1], kwargs)
124-
u0 = promote_u0(u0, prob.p, tspan[1])
125-
tspan = promote_tspan(u0, prob.p, tspan, prob, kwargs)
126-
remake(prob; u0 = u0, tspan = tspan)
124+
u0_promote = promote_u0(u0, prob.p, tspan[1])
125+
tspan_promote = promote_tspan(u0, prob.p, tspan, prob, kwargs)
126+
if isconcreteu0(prob, t0, kwargs) && typeof(u0_promote) === typeof(u0) && typeof(tspan) === typeof(tspan_promote)
127+
return prob
128+
else
129+
return remake(prob; u0 = u0_promote, tspan = tspan_promote)
130+
end
127131
end
128132

129133
function get_concrete_problem(prob::DDEProblem, kwargs)
@@ -159,6 +163,10 @@ function get_concrete_tspan(prob, kwargs)
159163
tspan
160164
end
161165

166+
function isconcreteu0(prob, t0, kwargs)
167+
!eval_u0(prob.u0) && prob.u0 !== nothing && !isdistribution(prob.u0)
168+
end
169+
162170
function get_concrete_u0(prob, t0, kwargs)
163171
if eval_u0(prob.u0)
164172
u0 = prob.u0(prob.p, t0)

0 commit comments

Comments
 (0)