Skip to content

Commit 95b5033

Browse files
Merge pull request #963 from oscardssmith/add-rate_prototype
add `_rate_prototype` function for getting the type of `du`
2 parents 0532345 + 91c822a commit 95b5033

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

ext/DiffEqBaseUnitfulExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ end
3030
real(abs2(x) / oneunit(x) * oneunit(x))
3131
end
3232

33+
DiffEqBase._rate_prototype(u, t, onet) = u / unit(t)
3334
end

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@ function default_logger(logger)
6161

6262
LoggingExtras.TeeLogger(logger1, logger2)
6363
end
64+
65+
# for the non-unitful case the correct type is just u
66+
_rate_prototype(u, t::T, onet::T) where {T} = u

test/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Test
22

33
using DiffEqBase, ForwardDiff
4-
using DiffEqBase: prob2dtmin, timedepentdtmin
4+
using DiffEqBase: prob2dtmin, timedepentdtmin, _rate_prototype
5+
using Unitful
6+
using ForwardDiff: Dual, Tag
57

68
@testset "tspan2dtmin" begin
79
# we only need to test very rough equality since timestepping isn't science.
@@ -32,3 +34,12 @@ end
3234
@test prob2dtmin((0.0f0, 10.0f0), 1.0f0, false) == eps(Float32)
3335
@test prob2dtmin((0.0, 10.0), ForwardDiff.Dual(1.0), false) == eps(Float64)
3436
end
37+
38+
@testset "_rate_prototype" begin
39+
@test _rate_prototype([1f0], 1.0, 1.0) isa Vector{Float32}
40+
td = Dual{Tag{typeof(+), Float64}}(2.0,1.0)
41+
@test _rate_prototype([1f0], td, td) isa Vector{Float32}
42+
xd = [Dual{Tag{typeof(+), Float32}}(2.0,1.0)]
43+
@test _rate_prototype(xd, 1.0, 1.0) isa typeof(xd)
44+
@test _rate_prototype([u"1f0m"], u"1.0s", 1.0) isa typeof([u"1f0m/s"])
45+
end

0 commit comments

Comments
 (0)