Skip to content

Commit 953f965

Browse files
committed
put new functionality in a package extention
1 parent 90f7e09 commit 953f965

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@ WoodburyMatrices = "efce3f68-66dc-5838-9240-27a6d6f5f9b6"
1919

2020
[weakdeps]
2121
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2223

2324
[extensions]
2425
InterpolationsUnitfulExt = "Unitful"
26+
InterpolationsForwardDiffExt = "ForwardDiff"
2527

2628
[compat]
2729
Adapt = "2, 3, 4.0"
2830
AxisAlgorithms = "0.3, 1"
2931
ChainRulesCore = "0.10, 1.0, 1.2, 1.3"
30-
ForwardDiff = "1.0.1"
32+
ForwardDiff = "0.10, 1.0"
3133
OffsetArrays = "0.10, 0.11, 1.0.1"
3234
Ratios = "0.3, 0.4"
3335
Requires = "1.1"

ext/InterpolationsForwardDiffExt.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module InterpolationsForwardDiffExt
2+
3+
import Interpolations
4+
using ForwardDiff
5+
6+
# this strips arbitrary layers of ForwardDiff.Dual, returning the innermost value
7+
Interpolations.just_dual_value(x::ForwardDiff.Dual) = Interpolations.just_dual_value(ForwardDiff.value(x))
8+
9+
function Interpolations.maybe_clamp(::Interpolations.NeedsCheck, itp, xs::Tuple{Vararg{ForwardDiff.Dual}})
10+
xs_values = Interpolations.just_dual_value.(xs)
11+
clamped_vals = Interpolations.maybe_clamp(Interpolations.NeedsCheck(), itp, xs_values)
12+
apply_partials.(xs, clamped_vals)
13+
end
14+
15+
# apply partials from arbitrarily nested ForwardDiff.Dual to a value
16+
# used in maybe_clamp, above
17+
function apply_partials(x_dual::D, val::Number) where D <: ForwardDiff.Dual
18+
∂s = ForwardDiff.partials(x_dual)
19+
apply_partials(ForwardDiff.value(x_dual), D(val, ∂s))
20+
end
21+
apply_partials(_::Number, val::Number) = val
22+
23+
end

src/Interpolations.jl

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -447,29 +447,13 @@ maybe_clamp(itp, xs) = maybe_clamp(BoundsCheckStyle(itp), itp, xs)
447447
maybe_clamp(::NeedsCheck, itp, xs) = map(clamp, xs, lbounds(itp), ubounds(itp))
448448
maybe_clamp(::CheckWillPass, itp, xs) = xs
449449

450-
using ForwardDiff
451-
# TODO
452-
function maybe_clamp(::NeedsCheck, itp, xs::Tuple{Vararg{ForwardDiff.Dual}})
453-
xs_values = just_dual_value.(xs)
454-
clamped_vals = maybe_clamp(NeedsCheck(), itp, xs_values)
455-
apply_partials.(xs, clamped_vals)
456-
end
450+
# this strips arbitrary layers of ForwardDiff.Dual, returning the innermost value
451+
# it's other methods are defined in InterpolationsForwardDiffExt.jl
452+
just_dual_value(x::Number) = x
457453

458454
Base.hash(x::AbstractInterpolation, h::UInt) = Base.hash_uint(3h - objectid(x))
459455
Base.hash(x::AbstractExtrapolation, h::UInt) = Base.hash_uint(3h - objectid(x))
460456

461-
# TODO use this, and define a method in a ForwardDiff package extension
462-
# stip off arbitrary layers of ForwardDiff.Dual, returning the innermost value
463-
just_dual_value(x::Number) = x
464-
just_dual_value(x::ForwardDiff.Dual) = just_dual_value(ForwardDiff.value(x))
465-
466-
# apply partials from arbitrarily nested ForwardDiff.Dual to a value
467-
function apply_partials(x_dual::D, val::Number) where D <: ForwardDiff.Dual
468-
∂s = ForwardDiff.partials(x_dual)
469-
apply_partials(ForwardDiff.value(x_dual), D(val, ∂s))
470-
end
471-
apply_partials(x_dual::Number, val::Number) = val
472-
473457
include("nointerp/nointerp.jl")
474458
include("b-splines/b-splines.jl")
475459
include("gridded/gridded.jl")

0 commit comments

Comments
 (0)