Skip to content

Commit ffa664f

Browse files
authored
Move __sum to DiffEqBase
1 parent 9dabe45 commit ffa664f

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

ext/DiffEqBaseForwardDiffExt.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -508,12 +508,12 @@ end
508508
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::Any) = sqrt(sse(u))
509509
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}},
510510
t::Any) where {Tag, T}
511-
sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u))
511+
sqrt(DiffEqBase.__sum(sse, u; init = sse(zero(T))) / totallength(u))
512512
end
513513
@inline ODE_DEFAULT_NORM(u::ForwardDiff.Dual, ::ForwardDiff.Dual) = sqrt(sse(u))
514514
@inline function ODE_DEFAULT_NORM(u::AbstractArray{<:ForwardDiff.Dual{Tag, T}},
515515
::ForwardDiff.Dual) where {Tag, T}
516-
sqrt(__sum(sse, u; init = sse(zero(T))) / totallength(u))
516+
sqrt(DiffEqBase.__sum(sse, u; init = sse(zero(T))) / totallength(u))
517517
end
518518

519519
if !hasmethod(nextfloat, Tuple{ForwardDiff.Dual})
@@ -528,13 +528,6 @@ end
528528

529529
# bisection(f, tup::Tuple{T,T}, t_forward::Bool) where {T<:ForwardDiff.Dual} = find_zero(f, tup, Roots.AlefeldPotraShi())
530530

531-
# Static Arrays don't support the `init` keyword argument for `sum`
532-
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
533-
@inline function __sum(
534-
f::F, a::DiffEqBase.StaticArraysCore.StaticArray...; init, kwargs...) where {F}
535-
return mapreduce(f, +, a...; init, kwargs...)
536-
end
537-
538531
# Differentiation of internal solver
539532

540533
function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...)

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ unitfulvalue(x) = x
44
isdistribution(u0) = false
55
sse(x::Number) = abs2(x)
66

7+
# Static Arrays don't support the `init` keyword argument for `sum`
8+
@inline __sum(f::F, args...; init, kwargs...) where {F} = sum(f, args...; init, kwargs...)
9+
@inline function __sum(
10+
f::F, a::StaticArraysCore.StaticArray...; init, kwargs...) where {F}
11+
return mapreduce(f, +, a...; init, kwargs...)
12+
end
13+
714
totallength(x::Number) = 1
815
totallength(x::AbstractArray) = __sum(totallength, x; init = 0)
916

0 commit comments

Comments
 (0)