@@ -19,42 +19,37 @@ for higher order derivatives partial can be any iterable, i.e.
19
19
```
20
20
"""
21
21
struct DiffPt{Dim}
22
- pos # the actual position
23
- partial
22
+ pos # the actual position
23
+ partial
24
24
end
25
25
26
- DiffPt (x;partial= ()) = DiffPt {length(x)} (x, partial) # convenience constructor
26
+ DiffPt (x; partial= ()) = DiffPt {length(x)} (x, partial) # convenience constructor
27
27
28
28
"""
29
29
Take the partial derivative of a function `fun` with input dimesion `dim`.
30
30
If partials=(i,j), then (∂ᵢ∂ⱼ fun) is returned.
31
31
"""
32
32
function partial (fun, dim, partials= ())
33
- if ! isnothing (local next = iterate (partials))
34
- idx, state = next
35
- return partial (
36
- x -> FD. derivative (0 ) do dx
37
- fun (x .+ dx * OneHotVector (idx, dim))
38
- end ,
39
- dim,
40
- Base. rest (partials, state),
41
- )
42
- end
43
- return fun
33
+ if ! isnothing (local next = iterate (partials))
34
+ idx, state = next
35
+ return partial (
36
+ x -> FD. derivative (0 ) do dx
37
+ fun (x .+ dx * OneHotVector (idx, dim))
38
+ end , dim, Base. rest (partials, state)
39
+ )
40
+ end
41
+ return fun
44
42
end
45
43
46
44
"""
47
45
Take the partial derivative of a function with two dim-dimensional inputs,
48
46
i.e. 2*dim dimensional input
49
47
"""
50
48
function partial (k, dim; partials_x= (), partials_y= ())
51
- local f (x,y) = partial (t -> k (t,y), dim, partials_x)(x)
52
- return (x,y) -> partial (t -> f (x,t), dim, partials_y)(y)
49
+ local f (x, y) = partial (t -> k (t, y), dim, partials_x)(x)
50
+ return (x, y) -> partial (t -> f (x, t), dim, partials_y)(y)
53
51
end
54
52
55
-
56
-
57
-
58
53
"""
59
54
_evaluate(k::T, x::DiffPt{Dim}, y::DiffPt{Dim}) where {Dim, T<:Kernel}
60
55
@@ -65,15 +60,10 @@ redirection over `_evaluate` is necessary
65
60
unboxes the partial instructions from DiffPt and applies them to k,
66
61
evaluates them at the positions of DiffPt
67
62
"""
68
- function _evaluate (k:: T , x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim, T<: Kernel }
69
- return partial (
70
- k, Dim,
71
- partials_x= x. partial, partials_y= y. partial
72
- )(x. pos, y. pos)
63
+ function _evaluate (k:: T , x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim,T<: Kernel }
64
+ return partial (k, Dim; partials_x= x. partial, partials_y= y. partial)(x. pos, y. pos)
73
65
end
74
66
75
-
76
-
77
67
#=
78
68
This is a hack to work around the fact that the `where {T<:Kernel}` clause is
79
69
not allowed for the `(::T)(x,y)` syntax. If we were to only implement
@@ -85,8 +75,7 @@ then julia would not know whether to use
85
75
```
86
76
=#
87
77
for T in [SimpleKernel, Kernel] # subtypes(Kernel)
88
- (k:: T )(x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, x, y)
89
- (k:: T )(x:: DiffPt{Dim} , y) where {Dim} = _evaluate (k, x, DiffPt (y))
90
- (k:: T )(x, y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, DiffPt (x), y)
78
+ (k:: T )(x:: DiffPt{Dim} , y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, x, y)
79
+ (k:: T )(x:: DiffPt{Dim} , y) where {Dim} = _evaluate (k, x, DiffPt (y))
80
+ (k:: T )(x, y:: DiffPt{Dim} ) where {Dim} = _evaluate (k, DiffPt (x), y)
91
81
end
92
-
0 commit comments