6767
6868# helper function for setting up min/max heaps for tstops and saveat
6969function tstops_and_saveat_heaps (t0, tf, tstops, saveat = [])
70- FT = typeof (tf)
70+ # We promote to a common type to ensure that t0 and tf have the same type
71+ FT = typeof (first (promote (t0, tf)))
7172 ordering = tf > t0 ? DataStructures. FasterForward : DataStructures. FasterReverse
7273
7374 # ensure that tstops includes tf and only has values ahead of t0
@@ -81,7 +82,7 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat = [])
8182 return tstops, saveat
8283end
8384
84- compute_tdir (ts) = ts[1 ] > ts[end ] ? sign (ts[end ] - ts[1 ]) : eltype (ts)( 1 )
85+ compute_tdir (ts) = ts[1 ] > ts[end ] ? sign (ts[end ] - ts[1 ]) : oneunit (ts[ 1 ] )
8586
8687# called by DiffEqBase.init and DiffEqBase.solve
8788function DiffEqBase. __init (
@@ -102,8 +103,10 @@ function DiffEqBase.__init(
102103)
103104 (; u0, p) = prob
104105 t0, tf = prob. tspan
106+ t0, tf, dt = promote (t0, tf, dt)
105107
106- dt > zero (dt) || error (" dt must be positive" )
108+ # We need zero(oneunit()) because there's no zerounit
109+ dt > zero (oneunit (dt)) || error (" dt must be positive" )
107110 _dt = dt
108111 dt = tf > t0 ? dt : - dt
109112
@@ -243,8 +246,9 @@ function __step!(integrator)
243246 # is taken from OrdinaryDiffEq.jl
244247 t_plus_dt = integrator. t + integrator. dt
245248 t_unit = oneunit (integrator. t)
246- max_t_error = 100 * eps (float (integrator. t / t_unit)) * t_unit
247- integrator. t = ! isempty (tstops) && abs (first (tstops) - t_plus_dt) < max_t_error ? first (tstops) : t_plus_dt
249+ max_t_error = 100 * eps (float (integrator. t / t_unit)) * float (t_unit)
250+ integrator. t =
251+ ! isempty (tstops) && abs (float (first (tstops)) - float (t_plus_dt)) < max_t_error ? first (tstops) : t_plus_dt
248252
249253 # apply callbacks
250254 discrete_callbacks = integrator. callback. discrete_callbacks
0 commit comments