Skip to content

Commit ead49da

Browse files
Merge pull request #925 from SciML/u/itpfix
Fix unwanted type promotion in InternalITP
2 parents 8ae17f1 + 841ef94 commit ead49da

File tree

8 files changed

+67
-58
lines changed

8 files changed

+67
-58
lines changed

src/callbacks.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,26 +118,26 @@ end
118118

119119
# Use a generated function for type stability even when many callbacks are given
120120
@inline function find_first_continuous_callback(integrator,
121-
callbacks::Vararg{
122-
AbstractContinuousCallback,
123-
N}) where {N}
121+
callbacks::Vararg{
122+
AbstractContinuousCallback,
123+
N}) where {N}
124124
find_first_continuous_callback(integrator, tuple(callbacks...))
125125
end
126126
@generated function find_first_continuous_callback(integrator,
127-
callbacks::NTuple{N,
128-
AbstractContinuousCallback
129-
}) where {N}
127+
callbacks::NTuple{N,
128+
AbstractContinuousCallback,
129+
}) where {N}
130130
ex = quote
131131
tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator,
132-
callbacks[1], 1)
132+
callbacks[1], 1)
133133
identified_idx = 1
134134
end
135135
for i in 2:N
136136
ex = quote
137137
$ex
138138
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
139-
callbacks[$i],
140-
$i)
139+
callbacks[$i],
140+
$i)
141141
if event_occurred2 && (tmin2 < tmin || !event_occurred)
142142
tmin = tmin2
143143
upcrossing = upcrossing2

src/internal_falsi.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::InternalFalsi, arg
3636

3737
if iszero(fr)
3838
return SciMLBase.build_solution(prob, alg, right, fr;
39-
retcode = ReturnCode.ExactSolutionLeft, left = left,
40-
right = right)
39+
retcode = ReturnCode.ExactSolutionLeft, left = left,
40+
right = right)
4141
end
4242

4343
i = 1
@@ -129,7 +129,7 @@ function scalar_nlsolve_ad(prob, alg::InternalFalsi, args...; kwargs...)
129129
end
130130

131131
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
132-
<:ForwardDiff.Dual{T, V, P}},
132+
<:ForwardDiff.Dual{T, V, P}},
133133
alg::InternalFalsi, args...;
134134
kwargs...) where {uType, iip, T, V, P}
135135
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
@@ -140,15 +140,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
140140
end
141141

142142
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
143-
<:AbstractArray{
144-
<:ForwardDiff.Dual{T,
145-
V,
146-
P},
147-
}},
143+
<:AbstractArray{
144+
<:ForwardDiff.Dual{T,
145+
V,
146+
P},
147+
}},
148148
alg::InternalFalsi, args...;
149149
kwargs...) where {uType, iip, T, V, P}
150150
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
151-
151+
152152
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
153153
sol.resid; retcode = sol.retcode,
154154
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),

src/internal_itp.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
`InternalITP`: A non-allocating ITP method, internal to DiffEqBase for
33
simpler dependencies.
44
"""
5-
struct InternalITP
5+
struct InternalITP
66
k1::Float64
77
k2::Float64
88
n0::Int
99
end
1010

1111
InternalITP() = InternalITP(0.007, 1.5, 10)
1212

13-
function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::InternalITP, args...;
13+
function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T, T}}, alg::InternalITP,
14+
args...;
1415
maxiters = 1000, kwargs...) where {IP, T}
1516
f = Base.Fix2(prob.f, prob.p)
1617
left, right = prob.tspan # a and b
@@ -26,9 +27,9 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::In
2627
right = right)
2728
end
2829
#defining variables/cache
29-
k1 = alg.k1
30-
k2 = alg.k2
31-
n0 = alg.n0
30+
k1 = T(alg.k1)
31+
k2 = T(alg.k2)
32+
n0 = T(alg.n0)
3233
n_h = ceil(log2(abs(right - left) / (2 * ϵ)))
3334
mid = (left + right) / 2
3435
x_f = (fr * left - fl * right) / (fr - fl)
@@ -46,7 +47,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::In
4647
δ = k1 * (span^k2)
4748

4849
## Interpolation step ##
49-
x_f = left + (right - left) * (fl/(fl - fr))
50+
x_f = left + (right - left) * (fl / (fl - fr))
5051

5152
## Truncation step ##
5253
σ = sign(mid - x_f)
@@ -79,8 +80,8 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T,T}}, alg::In
7980
left = prevfloat_tdir(xp, prob.tspan...)
8081
right = xp
8182
return SciMLBase.build_solution(prob, alg, left, f(left);
82-
retcode = ReturnCode.Success, left = left,
83-
right = right)
83+
retcode = ReturnCode.Success, left = left,
84+
right = right)
8485
end
8586
i += 1
8687
mid = (left + right) / 2
@@ -127,7 +128,7 @@ function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...)
127128
end
128129

129130
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
130-
<:ForwardDiff.Dual{T, V, P}},
131+
<:ForwardDiff.Dual{T, V, P}},
131132
alg::InternalITP, args...;
132133
kwargs...) where {uType, iip, T, V, P}
133134
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
@@ -138,15 +139,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
138139
end
139140

140141
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,
141-
<:AbstractArray{
142-
<:ForwardDiff.Dual{T,
143-
V,
144-
P},
145-
}},
142+
<:AbstractArray{
143+
<:ForwardDiff.Dual{T,
144+
V,
145+
P},
146+
}},
146147
alg::InternalITP, args...;
147148
kwargs...) where {uType, iip, T, V, P}
148149
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
149-
150+
150151
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials),
151152
sol.resid; retcode = sol.retcode,
152153
left = ForwardDiff.Dual{T, V, P}(sol.left, partials),

src/solve.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ const NOISE_SIZE_MESSAGE = """
203203
be double checked.
204204
"""
205205

206-
struct NoiseSizeIncompatabilityError <: Exception
206+
struct NoiseSizeIncompatabilityError <: Exception
207207
prototypesize::Int
208208
noisesize::Int
209209
end
@@ -1025,7 +1025,7 @@ function solve(prob::EnsembleProblem, args...; kwargs...)
10251025
end
10261026
end
10271027
function solve(prob::SciMLBase.WeightedEnsembleProblem, args...; kwargs...)
1028-
SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights)
1028+
SciMLBase.WeightedEnsembleSolution(solve(prob.ensembleprob), prob.weights)
10291029
end
10301030
function solve(prob::AbstractNoiseProblem, args...; kwargs...)
10311031
__solve(prob, args...; kwargs...)
@@ -1307,8 +1307,10 @@ function check_prob_alg_pairing(prob, alg)
13071307
end
13081308

13091309
if prob isa SDEProblem && prob.noise_rate_prototype !== nothing &&
1310-
prob.noise !== nothing && size(prob.noise_rate_prototype,2) != length(prob.noise.W[1])
1311-
throw(NoiseSizeIncompatabilityError(size(prob.noise_rate_prototype,2), length(prob.noise.W[1])))
1310+
prob.noise !== nothing &&
1311+
size(prob.noise_rate_prototype, 2) != length(prob.noise.W[1])
1312+
throw(NoiseSizeIncompatabilityError(size(prob.noise_rate_prototype, 2),
1313+
length(prob.noise.W[1])))
13121314
end
13131315

13141316
# Complex number support comes before arbitrary number support for a more direct

test/callbacks.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ struct EmptyIntegrator
5858
u::Vector{Float64}
5959
end
6060
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
61-
callback::ContinuousCallback, counter)
61+
callback::ContinuousCallback, counter)
6262
1.0 + counter, 0.9 + counter, true, counter
6363
end
6464
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
65-
callback::VectorContinuousCallback, counter)
65+
callback::VectorContinuousCallback, counter)
6666
1.0 + counter, 0.9 + counter, true, counter
6767
end
6868
find_first_integrator = EmptyIntegrator([1.0, 2.0])
@@ -82,21 +82,21 @@ cond_9(u, t, integrator) = t - 1.8
8282
cond_10(u, t, integrator) = t - 1.9
8383
# Setup a lot of callbacks so the recursive inference failure happens
8484
callbacks = (ContinuousCallback(cond_1, affect!),
85-
ContinuousCallback(cond_2, affect!),
86-
ContinuousCallback(cond_3, affect!),
87-
ContinuousCallback(cond_4, affect!),
88-
ContinuousCallback(cond_5, affect!),
89-
ContinuousCallback(cond_6, affect!),
90-
ContinuousCallback(cond_7, affect!),
91-
ContinuousCallback(cond_8, affect!),
92-
ContinuousCallback(cond_9, affect!),
93-
ContinuousCallback(cond_10, affect!),
94-
VectorContinuousCallback(cond_1, vector_affect!, 2),
95-
VectorContinuousCallback(cond_2, vector_affect!, 2),
96-
VectorContinuousCallback(cond_3, vector_affect!, 2),
97-
VectorContinuousCallback(cond_4, vector_affect!, 2),
98-
VectorContinuousCallback(cond_5, vector_affect!, 2),
99-
VectorContinuousCallback(cond_6, vector_affect!, 2));
85+
ContinuousCallback(cond_2, affect!),
86+
ContinuousCallback(cond_3, affect!),
87+
ContinuousCallback(cond_4, affect!),
88+
ContinuousCallback(cond_5, affect!),
89+
ContinuousCallback(cond_6, affect!),
90+
ContinuousCallback(cond_7, affect!),
91+
ContinuousCallback(cond_8, affect!),
92+
ContinuousCallback(cond_9, affect!),
93+
ContinuousCallback(cond_10, affect!),
94+
VectorContinuousCallback(cond_1, vector_affect!, 2),
95+
VectorContinuousCallback(cond_2, vector_affect!, 2),
96+
VectorContinuousCallback(cond_3, vector_affect!, 2),
97+
VectorContinuousCallback(cond_4, vector_affect!, 2),
98+
VectorContinuousCallback(cond_5, vector_affect!, 2),
99+
VectorContinuousCallback(cond_6, vector_affect!, 2));
100100
function test_find_first_callback(callbacks, int)
101101
@timed(DiffEqBase.find_first_continuous_callback(int, callbacks...))
102102
end

test/downstream/solve_error_handling.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,10 @@ function g(du, u, p, t)
6161
du[2, 4] = 1.8u[2]
6262
end
6363

64-
prob = SDEProblem(f, g, randn(ComplexF64,2), (0.0, 1.0), noise_rate_prototype =complex(zeros(2, 4)),noise=StochasticDiffEq.RealWienerProcess(0.0,zeros(3)))
64+
prob = SDEProblem(f,
65+
g,
66+
randn(ComplexF64, 2),
67+
(0.0, 1.0),
68+
noise_rate_prototype = complex(zeros(2, 4)),
69+
noise = StochasticDiffEq.RealWienerProcess(0.0, zeros(3)))
6570
@test_throws DiffEqBase.NoiseSizeIncompatabilityError solve(prob, LambaEM())

test/forwarddiff_dual_detection.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ p_possibilities17 = [
8585
(Mod, ForwardDiff.Dual(2.0)), (() -> 2.0, ForwardDiff.Dual(2.0)),
8686
(Base.pointer([2.0]), ForwardDiff.Dual(2.0)),
8787
]
88-
VERSION >= v"1.7" && push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0))))
88+
VERSION >= v"1.7" &&
89+
push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0))))
8990

9091
for p in p_possibilities17
9192
@show p

test/internal_rootfinder.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ for Rootfinder in (InternalFalsi, InternalITP)
2020

2121
# https://github.com/SciML/DiffEqBase.jl/issues/916
2222
inp = IntervalNonlinearProblem((t, p) -> min(-1.0 + 0.001427344607477125 * t, 1e-9),
23-
(699.0079267259368, 700.6176418816023))
23+
(699.0079267259368, 700.6176418816023))
2424
@test solve(inp, rf).u 700.6016590257979
2525

2626
# Flipped signs & reversed tspan test for bracketing algorithms
@@ -36,4 +36,4 @@ for Rootfinder in (InternalFalsi, InternalITP)
3636
@test abs.(solve(inp3, rf).u) sqrt.(p)
3737
@test abs.(solve(inp4, rf).u) sqrt.(p)
3838
end
39-
end
39+
end

0 commit comments

Comments
 (0)