Skip to content

Commit d020612

Browse files
add unitfulvalue
1 parent 8b36969 commit d020612

8 files changed

+33
-0
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,9 @@ end
497497
value(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
498498
value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x))
499499

500+
unitfulvalue(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
501+
unitfulvalue(x::ForwardDiff.Dual) = unitfulvalue(ForwardDiff.unitfulvalue(x))
502+
500503
sse(x::Number) = abs2(x)
501504
sse(x::ForwardDiff.Dual) = sse(ForwardDiff.value(x)) + sum(sse, ForwardDiff.partials(x))
502505
totallength(x::Number) = 1

ext/DiffEqBaseGTPSAExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ end
1313
value(x::TPS) = scalar(x)
1414
value(::Type{<:TPS{T}}) where {T} = T
1515

16+
unitfulvalue(x::TPS) = scalar(x)
17+
unitfulvalue(::Type{<:TPS{T}}) where {T} = T
18+
1619
ODE_DEFAULT_NORM(u::TPS, t) = normTPS(u)
1720
ODE_DEFAULT_NORM(f::F, u::TPS, t) where {F} = normTPS(f(u))
1821

ext/DiffEqBaseMeasurementsExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = el
1919
value(x::Type{Measurements.Measurement{T}}) where {T} = T
2020
value(x::Measurements.Measurement) = Measurements.value(x)
2121

22+
unitfulvalue(x::Type{Measurements.Measurement{T}}) where {T} = T
23+
unitfulvalue(x::Measurements.Measurement) = Measurements.value(x)
24+
2225
# Support adaptive steps should be errorless
2326
@inline function DiffEqBase.ODE_DEFAULT_NORM(
2427
u::AbstractArray{

ext/DiffEqBaseMonteCarloMeasurementsExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ end
2626

2727
DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T
2828
DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
29+
DiffEqBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T
30+
DiffEqBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
2931

3032
# Support adaptive steps should be errorless
3133
@inline function DiffEqBase.ODE_DEFAULT_NORM(

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ end
2323
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
2424
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value
2525

26+
27+
DiffEqBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
28+
function DiffEqBase.unitfulvalue(x::Type{
29+
ReverseDiff.TrackedArray{V, D, N, VA, DA},
30+
}) where {V, D,
31+
N, VA,
32+
DA}
33+
Array{V, N}
34+
end
35+
DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value
36+
DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value
37+
2638
# Force TrackedArray from TrackedReal when reshaping W\b
2739
DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v)
2840

ext/DiffEqBaseTrackerExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array
1515
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
1616
DiffEqBase.value(x::Tracker.TrackedArray) = x.data
1717

18+
DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedReal{T}}) where {T} = T
19+
DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
20+
DiffEqBase.unitfulvalue(x::Tracker.TrackedReal) = x.data
21+
DiffEqBase.unitfulvalue(x::Tracker.TrackedArray) = x.data
22+
1823
DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
1924
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
2025
p::Tracker.TrackedArray, t0)

ext/DiffEqBaseUnitfulExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ end
1212
# Support adaptive errors should be errorless for exponentiation
1313
value(x::Type{Unitful.AbstractQuantity{T, D, U}}) where {T, D, U} = T
1414
value(x::Unitful.AbstractQuantity) = x.val
15+
16+
unitfulvalue(x::Type{T}) where {T <: Unitful.AbstractQuantity} = T
17+
unitfulvalue(x::Unitful.AbstractQuantity) = x
18+
1519
@inline function DiffEqBase.ODE_DEFAULT_NORM(
1620
u::AbstractArray{
1721
<:Unitful.AbstractQuantity,

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Handled in Extensions
22
value(x) = x
3+
unitfulvalue(x) = x
34
isdistribution(u0) = false
45

56
_vec(v) = vec(v)

0 commit comments

Comments
 (0)