Skip to content

Commit 06ffeda

Browse files
author
oscarddssmith
committed
fix and add tests
1 parent 5329ed6 commit 06ffeda

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

ext/DiffEqBaseUnitfulExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ end
3030
real(abs2(x) / oneunit(x) * oneunit(x))
3131
end
3232

33-
_rate_prototype(u, t, onet) = u / unit(t)
33+
DiffEqBase._rate_prototype(u, t, onet) = u / unit(t)
3434
end

src/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,5 @@ function default_logger(logger)
6262
LoggingExtras.TeeLogger(logger1, logger2)
6363
end
6464

65-
_rate_prototype(u, t::T, onet::T) where {T} = u / oneunit(t)
65+
# for the non-unitful case the correct type is just u
66+
_rate_prototype(u, t::T, onet::T) where {T} = u

test/utils.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using Test
22

33
using DiffEqBase, ForwardDiff
4-
using DiffEqBase: prob2dtmin, timedepentdtmin
4+
using DiffEqBase: prob2dtmin, timedepentdtmin, _rate_prototype
5+
using Unitful
56

67
@testset "tspan2dtmin" begin
78
# we only need to test very rough equality since timestepping isn't science.
@@ -32,3 +33,12 @@ end
3233
@test prob2dtmin((0.0f0, 10.0f0), 1.0f0, false) == eps(Float32)
3334
@test prob2dtmin((0.0, 10.0), ForwardDiff.Dual(1.0), false) == eps(Float64)
3435
end
36+
37+
@testset "_rate_prototype" begin
38+
@test _rate_prototype([1f0], 1.0, 1.0) isa Vector{Float32}
39+
td = Dual{Tag{typeof(+), Float64}}(2.0,1.0)
40+
@test _rate_prototype([1f0], td, td) isa Vector{Float32}
41+
xd = [Dual{Tag{typeof(+), Float32}}(2.0,1.0)]
42+
@test _rate_prototype(xd, 1.0, 1.0) isa typeof(xd)
43+
@test _rate_prototype([u"1f0m"], u"1.0s", 1.0) isa typeof([u"1f0m/s"])
44+
end

0 commit comments

Comments
 (0)