Skip to content

Commit 90f7e09

Browse files
committed
git issue 335 test passing; concentrate all direct calls to ForwardDiff to one file
1 parent eec6d38 commit 90f7e09

File tree

6 files changed

+30
-11
lines changed

6 files changed

+30
-11
lines changed

src/Interpolations.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,29 @@ maybe_clamp(itp, xs) = maybe_clamp(BoundsCheckStyle(itp), itp, xs)
447447
maybe_clamp(::NeedsCheck, itp, xs) = map(clamp, xs, lbounds(itp), ubounds(itp))
448448
maybe_clamp(::CheckWillPass, itp, xs) = xs
449449

450+
using ForwardDiff
451+
# TODO
452+
function maybe_clamp(::NeedsCheck, itp, xs::Tuple{Vararg{ForwardDiff.Dual}})
453+
xs_values = just_dual_value.(xs)
454+
clamped_vals = maybe_clamp(NeedsCheck(), itp, xs_values)
455+
apply_partials.(xs, clamped_vals)
456+
end
457+
450458
Base.hash(x::AbstractInterpolation, h::UInt) = Base.hash_uint(3h - objectid(x))
451459
Base.hash(x::AbstractExtrapolation, h::UInt) = Base.hash_uint(3h - objectid(x))
452460

461+
# TODO use this, and define a method in a ForwardDiff package extension
462+
# stip off arbitrary layers of ForwardDiff.Dual, returning the innermost value
463+
just_dual_value(x::Number) = x
464+
just_dual_value(x::ForwardDiff.Dual) = just_dual_value(ForwardDiff.value(x))
465+
466+
# apply partials from arbitrarily nested ForwardDiff.Dual to a value
467+
function apply_partials(x_dual::D, val::Number) where D <: ForwardDiff.Dual
468+
∂s = ForwardDiff.partials(x_dual)
469+
apply_partials(ForwardDiff.value(x_dual), D(val, ∂s))
470+
end
471+
apply_partials(x_dual::Number, val::Number) = val
472+
453473
include("nointerp/nointerp.jl")
454474
include("b-splines/b-splines.jl")
455475
include("gridded/gridded.jl")

src/b-splines/constant.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,35 +67,35 @@ to `A[ceil(Int,x)]` without scaling.
6767
Constant
6868

6969
function positions(c::Constant{Previous}, ax, x) # discontinuity occurs at integer locations
70-
x_value = ForwardDiff.value(ForwardDiff.value(x))
70+
x_value = just_dual_value.(x)
7171
xm = floorbounds(x_value, ax)
7272
δx = x - xm
7373
fast_trunc(Int, xm), δx
7474
end
7575
function positions(c::Constant{Next}, ax, x) # discontinuity occurs at integer locations
76-
x_value = ForwardDiff.value(ForwardDiff.value(x))
76+
x_value = just_dual_value.(x)
7777
xm = ceilbounds(x_value, ax)
7878
δx = x - xm
7979
fast_trunc(Int, xm), δx
8080
end
8181
function positions(c::Constant{Nearest}, ax, x) # discontinuity occurs at half-integer locations
82-
x_value = ForwardDiff.value(ForwardDiff.value(x))
82+
x_value = just_dual_value.(x)
8383
xm = roundbounds(x_value, ax)
8484
δx = x - xm
8585
i = fast_trunc(Int, xm)
8686
i, δx
8787
end
8888

8989
function positions(c::Constant{Previous,Periodic{OnCell}}, ax, x)
90-
x_value = ForwardDiff.value(ForwardDiff.value(x))
90+
x_value = just_dual_value.(x)
9191
# We do not use floorbounds because we do not want to add a half at
9292
# the lowerbound to round up.
9393
xm = floor(x_value)
9494
δx = x - xm
9595
modrange(fast_trunc(Int, xm), ax), δx
9696
end
9797
function positions(c::Constant{Next,Periodic{OnCell}}, ax, x) # discontinuity occurs at integer locations
98-
x_value = ForwardDiff.value(ForwardDiff.value(x))
98+
x_value = just_dual_value.(x)
9999
xm = ceilbounds(x_value, ax)
100100
δx = x - xm
101101
modrange(fast_trunc(Int, xm), ax), δx

src/b-splines/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
itpinfo(itp) = (tcollect(itpflag, itp), axes(itp))
44

55
@inline function (itp::BSplineInterpolation{T,N})(x::Vararg{Number,N}) where {T,N}
6-
@boundscheck (checkbounds(Bool, itp, ForwardDiff.value.(ForwardDiff.value.(x))...) || Base.throw_boundserror(itp, x))
6+
@boundscheck (checkbounds(Bool, itp, just_dual_value.(x)...) || Base.throw_boundserror(itp, x))
77
wis = weightedindexes((value_weights,), itpinfo(itp)..., x)
88
InterpGetindex(itp)[wis...]
99
end

src/b-splines/linear.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using ForwardDiff # TODO
2-
31
struct Linear{BC<:Union{Throw{OnGrid},Periodic{OnCell}}} <: DegreeBC{1}
42
bc::BC
53
function Linear{BC}(bc::BC=BC()) where BC<:Union{Throw{OnGrid},Periodic{OnCell}}
@@ -43,7 +41,7 @@ a piecewise linear function connecting each pair of neighboring data points.
4341
Linear
4442

4543
function positions(deg::Linear, ax::AbstractUnitRange{<:Integer}, x)
46-
x_value = ForwardDiff.value(ForwardDiff.value(x))
44+
x_value = just_dual_value.(x)
4745
f = floor(x_value)
4846
# When x == last(ax) we want to use the x-1, x pair
4947
f = ifelse(x_value == last(ax), f - oneunit(f), f)

src/monotonic/monotonic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function interpolate(
207207
end
208208

209209
function (itp::MonotonicInterpolation)(x::Number)
210-
x_value = ForwardDiff.value(ForwardDiff.value(x))
210+
x_value = just_dual_value.(x)
211211
@boundscheck (checkbounds(Bool, itp, x_value) || Base.throw_boundserror(itp, (x_value,)))
212212
k = searchsortedfirst(itp.knots, x_value)
213213
if k > 1

src/scaling/scaling.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ ubound(ax::AbstractRange, ::DegreeBC, ::OnGrid) = last(ax)
7676

7777
# For (), we scale the evaluation point
7878
@propagate_inbounds function (sitp::ScaledInterpolation{T,N})(xs::Vararg{Number,N}) where {T,N}
79-
@boundscheck (checkbounds(Bool, sitp, xs...) || Base.throw_boundserror(sitp, xs))
79+
xs_values = just_dual_value.(xs)
80+
@boundscheck (checkbounds(Bool, sitp, xs_values...) || Base.throw_boundserror(sitp, xs_values))
8081
xl = maybe_clamp(sitp.itp, coordslookup(itpflag(sitp.itp), sitp.ranges, xs))
8182
@inbounds sitp.itp(xl...)
8283
end

0 commit comments

Comments
 (0)