Skip to content

Commit e787af4

Browse files
committed
simplify partial with suggestions from devmotion
1 parent 9c4ff2e commit e787af4

File tree

1 file changed

+39
-13
lines changed

1 file changed

+39
-13
lines changed

src/diffKernel.jl

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,33 @@ end
2525

2626
DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
2727

28+
"""
29+
partial(fun, idx)
30+
31+
Return ∂ᵢf where
32+
f = fun
33+
i = idx
34+
"""
35+
function partial(fun, idx)
36+
return x -> FD.derivative(0) do dx
37+
y = similar(x)
38+
y = copyto!(y, x)
39+
y[idx] += dx
40+
fun(y)
41+
end
42+
end
43+
2844
"""
29-
Take the partial derivative of a function `fun` with input dimesion `dim`.
30-
If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned.
45+
partial(fun, indices...)
46+
47+
Return the partial derivative with respect to all indices, e.g.
48+
```julia
49+
partial(f, i, j) # = ∂ᵢ∂ⱼf
50+
```
3151
"""
32-
function partial(fun, dim, partials=())
33-
if !isnothing(local next = iterate(partials))
34-
idx, state = next
35-
return partial(
36-
x -> FD.derivative(0) do dx
37-
fun(x .+ dx * OneHotVector(idx, dim))
38-
end, dim, Base.rest(partials, state)
39-
)
40-
end
41-
return fun
52+
function partial(fun, indices...)
53+
idx, state = iterate(indices)
54+
return partial(partial(fun, idx), Base.rest(indices, state)...)
4255
end
4356

4457
"""
@@ -74,7 +87,20 @@ then julia would not know whether to use
7487
`(::SpecialKernel)(x,y)` or `(::Kernel)(x::DiffPt, y::DiffPt)`
7588
```
7689
=#
77-
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
90+
for T in [
91+
SimpleKernel,
92+
Kernel,
93+
ZeroKernel,
94+
NeuralNetworkKernel,
95+
NeuralKernelNetwork,
96+
GibbsKernel,
97+
WienerKernel,
98+
WienerKernel{2},
99+
TransformedKernel,
100+
KernelSum,
101+
NormalizedKernel,
102+
KernelTensorProduct
103+
] #subtypes(Kernel)
78104
(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
79105
(k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
80106
(k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)

0 commit comments

Comments
 (0)