Skip to content
Draft
29 changes: 20 additions & 9 deletions src/diffKernel.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using OneHotArrays: OneHotVector
import ForwardDiff as FD
import LinearAlgebra as LA

Expand Down Expand Up @@ -33,18 +32,30 @@ function DiffPt(x::T, partials::NTuple{Order,KeyT}) where {T,Order,KeyT}
return DiffPt{Order,KeyT,T}(x, partials)
end

partial(func) = func
function partial(func, idx::Int)
return x -> FD.derivative(0) do dx
return func(x .+ dx * OneHotVector(idx, length(x)))
"""
tangentCurve(x₀, i::IndexType)
returns the function (t ↦ x₀ + teᵢ) where eᵢ is the unit vector at index i
"""
function tangentCurve(x0::AbstractArray{N,T}, idx::IndexType) where {N, T}
return t -> begin
x = similar(x0)
copyto!(x, x0)
x[idx] +=t
return x
end
end
function partial(func, partials::Int...)
function tangentCurve(x0::Number, ::IndexType)
return t -> x0 + t
end

partial(func) = func
function partial(func, idx::IndexType)
return x -> FD.derivative(func ∘ tangentCurve(x, idx), 0)
end
function partial(func, partials::IndexType...)
idx, state = iterate(partials)
return partial(
x -> FD.derivative(0) do dx
return func(x .+ dx * OneHotVector(idx, length(x)))
end,
x -> FD.derivative(func ∘ tangentCurve(x, idx), 0),
Base.rest(partials, state)...,
)
end
Expand Down