Skip to content

Commit 8f74495

Browse files
committed
run formatter
1 parent 21980a9 commit 8f74495

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

src/diffKernel.jl

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,37 @@ for higher order derivatives partial can be any iterable, i.e.
1919
```
2020
"""
2121
struct DiffPt{Dim}
22-
pos # the actual position
23-
partial
22+
pos # the actual position
23+
partial
2424
end
2525

26-
DiffPt(x;partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
26+
DiffPt(x; partial=()) = DiffPt{length(x)}(x, partial) # convenience constructor
2727

2828
"""
2929
Take the partial derivative of a function `fun` with input dimesion `dim`.
3030
If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned.
3131
"""
3232
function partial(fun, dim, partials=())
33-
if !isnothing(local next = iterate(partials))
34-
idx, state = next
35-
return partial(
36-
x -> FD.derivative(0) do dx
37-
fun(x .+ dx * OneHotVector(idx, dim))
38-
end,
39-
dim,
40-
Base.rest(partials, state),
41-
)
42-
end
43-
return fun
33+
if !isnothing(local next = iterate(partials))
34+
idx, state = next
35+
return partial(
36+
x -> FD.derivative(0) do dx
37+
fun(x .+ dx * OneHotVector(idx, dim))
38+
end, dim, Base.rest(partials, state)
39+
)
40+
end
41+
return fun
4442
end
4543

4644
"""
4745
Take the partial derivative of a function with two dim-dimensional inputs,
4846
i.e. 2*dim dimensional input
4947
"""
5048
function partial(k, dim; partials_x=(), partials_y=())
51-
local f(x,y) = partial(t -> k(t,y), dim, partials_x)(x)
52-
return (x,y) -> partial(t -> f(x,t), dim, partials_y)(y)
49+
local f(x, y) = partial(t -> k(t, y), dim, partials_x)(x)
50+
return (x, y) -> partial(t -> f(x, t), dim, partials_y)(y)
5351
end
5452

55-
56-
57-
5853
"""
5954
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
6055
@@ -65,15 +60,10 @@ redirection over `_evaluate` is necessary
6560
unboxes the partial instructions from DiffPt and applies them to k,
6661
evaluates them at the positions of DiffPt
6762
"""
68-
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
69-
return partial(
70-
k, Dim,
71-
partials_x=x.partial, partials_y=y.partial
72-
)(x.pos, y.pos)
63+
function _evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim,T<:Kernel}
64+
return partial(k, Dim; partials_x=x.partial, partials_y=y.partial)(x.pos, y.pos)
7365
end
7466

75-
76-
7767
#=
7868
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
7969
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
@@ -85,8 +75,7 @@ then julia would not know whether to use
8575
```
8676
=#
8777
for T in [SimpleKernel, Kernel] #subtypes(Kernel)
88-
(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
89-
(k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
90-
(k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
78+
(k::T)(x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim} = _evaluate(k, x, y)
79+
(k::T)(x::DiffPt{Dim}, y) where {Dim} = _evaluate(k, x, DiffPt(y))
80+
(k::T)(x, y::DiffPt{Dim}) where {Dim} = _evaluate(k, DiffPt(x), y)
9181
end
92-

0 commit comments

Comments
 (0)