Skip to content

Commit f9157be

Browse files
Merge pull request SciML#1197 from jClugstor/move_to_SciMLBase
Move utilities to SciMLBase
2 parents 89c4664 + f845a85 commit f9157be

16 files changed

+59
-1374
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
5151
[extensions]
5252
DiffEqBaseCUDAExt = "CUDA"
5353
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
54-
DiffEqBaseDistributionsExt = "Distributions"
5554
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
5655
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
5756
DiffEqBaseGTPSAExt = "GTPSA"

ext/DiffEqBaseDistributionsExt.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 17 additions & 414 deletions
Large diffs are not rendered by default.

ext/DiffEqBaseGTPSAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module DiffEqBaseGTPSAExt
22

33
using DiffEqBase
4-
import DiffEqBase: value, ODE_DEFAULT_NORM
4+
import DiffEqBase: ODE_DEFAULT_NORM
5+
import SciMLBase: value, unitfulvalue
56
using GTPSA
67

78
value(x::TPS) = scalar(x)

ext/DiffEqBaseMeasurementsExt.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,6 @@ using DiffEqBase
44
import DiffEqBase: value
55
using Measurements
66

7-
function DiffEqBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement},
8-
p::AbstractArray{<:Measurements.Measurement}, t0)
9-
u0
10-
end
11-
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0)
12-
13-
value(x::Type{Measurements.Measurement{T}}) where {T} = T
14-
value(x::Measurements.Measurement) = Measurements.value(x)
15-
16-
unitfulvalue(x::Type{Measurements.Measurement{T}}) where {T} = T
17-
unitfulvalue(x::Measurements.Measurement) = Measurements.value(x)
18-
197
# Support adaptive steps should be errorless
208
@inline function DiffEqBase.ODE_DEFAULT_NORM(
219
u::AbstractArray{

ext/DiffEqBaseMonteCarloMeasurementsExt.jl

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,6 @@ using DiffEqBase
44
import DiffEqBase: value
55
using MonteCarloMeasurements
66

7-
function DiffEqBase.promote_u0(
8-
u0::AbstractArray{
9-
<:MonteCarloMeasurements.AbstractParticles,
10-
},
11-
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
12-
t0)
13-
u0
14-
end
15-
function DiffEqBase.promote_u0(u0,
16-
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
17-
t0)
18-
eltype(p).(u0)
19-
end
20-
21-
function DiffEqBase.promote_u0(::Nothing,
22-
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
23-
t0)
24-
return nothing
25-
end
26-
27-
DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T
28-
DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
29-
function DiffEqBase.unitfulvalue(x::Type{MonteCarloMeasurements.AbstractParticles{
30-
T, N}}) where {T, N}
31-
T
32-
end
33-
DiffEqBase.unitfulvalue(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)
34-
357
# Support adaptive steps should be errorless
368
@inline function DiffEqBase.ODE_DEFAULT_NORM(
379
u::AbstractArray{

ext/DiffEqBaseMooncakeExt.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,4 @@ import Mooncake: rrule!!, CoDual, zero_fcodual, @is_primitive,
2929
},
3030
true,)
3131

32-
@zero_adjoint MinimalCtx Tuple{typeof(DiffEqBase.numargs), Any}
33-
@is_primitive MinimalCtx Tuple{
34-
typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake), SciMLBase.ChainRulesOriginator
35-
}
36-
37-
@mooncake_overlay DiffEqBase.set_mooncakeoriginator_if_mooncake(x::SciMLBase.ADOriginator) = SciMLBase.MooncakeOriginator()
38-
39-
function rrule!!(
40-
f::CoDual{typeof(DiffEqBase.set_mooncakeoriginator_if_mooncake)},
41-
X::CoDual{SciMLBase.ChainRulesOriginator}
42-
)
43-
return zero_fcodual(SciMLBase.MooncakeOriginator()), NoPullback(f, X)
44-
end
45-
4632
end

ext/DiffEqBaseReverseDiffExt.jl

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,57 +5,6 @@ import DiffEqBase: value
55
import ReverseDiff
66
import DiffEqBase.ArrayInterface
77

8-
function DiffEqBase.anyeltypedual(::Type{T},
9-
::Type{Val{counter}} = Val{0}) where {counter} where {
10-
V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
11-
DiffEqBase.anyeltypedual(V, Val{counter})
12-
end
13-
14-
DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
15-
function DiffEqBase.value(x::Type{
16-
ReverseDiff.TrackedArray{V, D, N, VA, DA},
17-
}) where {V, D,
18-
N, VA,
19-
DA}
20-
Array{V, N}
21-
end
22-
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
23-
DiffEqBase.value(x::ReverseDiff.TrackedArray) = x.value
24-
25-
DiffEqBase.unitfulvalue(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
26-
function DiffEqBase.unitfulvalue(x::Type{
27-
ReverseDiff.TrackedArray{V, D, N, VA, DA},
28-
}) where {V, D,
29-
N, VA,
30-
DA}
31-
Array{V, N}
32-
end
33-
DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedReal) = x.value
34-
DiffEqBase.unitfulvalue(x::ReverseDiff.TrackedArray) = x.value
35-
36-
# Force TrackedArray from TrackedReal when reshaping W\b
37-
DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(vcat, v)
38-
39-
DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
40-
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
41-
p::ReverseDiff.TrackedArray, t0)
42-
u0
43-
end
44-
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
45-
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
46-
u0
47-
end
48-
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
49-
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
50-
u0
51-
end
52-
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
53-
function DiffEqBase.promote_u0(
54-
u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ReverseDiff.ForwardDiff.Dual}
55-
ReverseDiff.track(T.(u0))
56-
end
57-
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)
58-
598
# Support adaptive with non-tracked time
609
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t)
6110
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))

ext/DiffEqBaseTrackerExt.jl

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,6 @@ using DiffEqBase
44
import DiffEqBase: value
55
import Tracker
66

7-
DiffEqBase.value(x::Type{Tracker.TrackedReal{T}}) where {T} = T
8-
DiffEqBase.value(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A} = Array{T, N}
9-
DiffEqBase.value(x::Tracker.TrackedReal) = x.data
10-
DiffEqBase.value(x::Tracker.TrackedArray) = x.data
11-
12-
DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedReal{T}}) where {T} = T
13-
function DiffEqBase.unitfulvalue(x::Type{Tracker.TrackedArray{T, N, A}}) where {T, N, A}
14-
Array{T, N}
15-
end
16-
DiffEqBase.unitfulvalue(x::Tracker.TrackedReal) = x.data
17-
DiffEqBase.unitfulvalue(x::Tracker.TrackedArray) = x.data
18-
19-
DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0
20-
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
21-
p::Tracker.TrackedArray, t0)
22-
u0
23-
end
24-
function DiffEqBase.promote_u0(u0::Tracker.TrackedArray,
25-
p::AbstractArray{<:Tracker.TrackedReal}, t0)
26-
u0
27-
end
28-
function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal},
29-
p::AbstractArray{<:Tracker.TrackedReal}, t0)
30-
u0
31-
end
32-
DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0)
33-
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype(p).(u0)
34-
35-
@inline Base.any(f::Function, x::Tracker.TrackedArray) = any(f, Tracker.data(x))
36-
377
# Support adaptive with non-tracked time
388
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, t)
399
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))

ext/DiffEqBaseUnitfulExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DiffEqBaseUnitfulExt
22

33
using DiffEqBase
4-
import DiffEqBase: value
4+
import SciMLBase: unitfulvalue, value
55
using Unitful
66

77
# Support adaptive errors should be errorless for exponentiation

0 commit comments

Comments
 (0)