|
| 1 | +""" |
| 2 | + WienerKernel{i}() |
| 3 | +
|
| 4 | +i-times integrated Wiener process kernel function. |
| 5 | +
|
| 6 | +- For i=-1, this is just the white noise covariance, see [`WhiteKernel`](@ref). |
| 7 | +- For i= 0, this is the Wiener process covariance, |
| 8 | +- For i= 1, this is the integrated Wiener process covariance (velocity), |
| 9 | +- For i= 2, this is the twice-integrated Wiener process covariance (accel.), |
| 10 | +- For i= 3, this is the thrice-integrated Wiener process covariance, |
| 11 | +
|
| 12 | +where ``κᵢ`` is given by |
| 13 | +
|
| 14 | +```math |
| 15 | + κ₋₁(x, y) = δ(x, y) |
| 16 | + κ₀(x, y) = min(x, y) |
| 17 | +``` |
| 18 | +and for ``i >= 1``, |
| 19 | +```math |
| 20 | + κᵢ(x, y) = 1 / aᵢ * min(x, y)^(2i + 1) + bᵢ * min(x, y)^(i + 1) * |x - y| * rᵢ(x, y), |
| 21 | +``` |
| 22 | + with the coefficients ``aᵢ``, ``bᵢ`` and the residual ``rᵢ(x, y)`` defined as follows: |
| 23 | +```math |
| 24 | + a₁ = 3, b₁ = 1/2, r₁(x, y) = 1, |
| 25 | + a₂ = 20, b₂ = 1/12, r₂(x, y) = x + y - min(x, y) / 2, |
| 26 | + a₃ = 252, b₃ = 1/720, r₃(x, y) = 5 * max(x, y)² + 2 * x * z + 3 * min(x, y)² |
| 27 | +
|
| 28 | +``` |
| 29 | +
|
| 30 | +# References: |
| 31 | +See the paper *Probabilistic ODE Solvers with Runge-Kutta Means* by Schober, Duvenaud and |
| 32 | +Hennig, NIPS, 2014, for more details. |
| 33 | +
|
| 34 | +""" |
| 35 | +struct WienerKernel{I} <: BaseKernel |
| 36 | + function WienerKernel{I}() where I |
| 37 | + @assert I ∈ (-1, 0, 1, 2, 3) "Invalid parameter i=$(I). Should be -1, 0, 1, 2 or 3." |
| 38 | + if I == -1 |
| 39 | + return WhiteKernel() |
| 40 | + end |
| 41 | + return new{I}() |
| 42 | + end |
| 43 | +end |
| 44 | + |
| 45 | +function WienerKernel(;i::Integer=0) |
| 46 | + return WienerKernel{i}() |
| 47 | +end |
| 48 | + |
| 49 | +function (::WienerKernel{0})(x, y) |
| 50 | + X = sqrt(sum(abs2, x)) |
| 51 | + Y = sqrt(sum(abs2, y)) |
| 52 | + return min(X, Y) |
| 53 | +end |
| 54 | + |
| 55 | +function (::WienerKernel{1})(x, y) |
| 56 | + X = sqrt(sum(abs2, x)) |
| 57 | + Y = sqrt(sum(abs2, y)) |
| 58 | + minXY = min(X, Y) |
| 59 | + return 1 / 3 * minXY^3 + 1 / 2 * minXY^2 * euclidean(x, y) |
| 60 | +end |
| 61 | + |
| 62 | +function (::WienerKernel{2})(x, y) |
| 63 | + X = sqrt(sum(abs2, x)) |
| 64 | + Y = sqrt(sum(abs2, y)) |
| 65 | + minXY = min(X, Y) |
| 66 | + return 1 / 20 * minXY^5 + 1 / 12 * minXY^3 * euclidean(x, y) * |
| 67 | + ( X + Y - 1 / 2 * minXY ) |
| 68 | +end |
| 69 | + |
| 70 | +function (::WienerKernel{3})(x, y) |
| 71 | + X = sqrt(sum(abs2, x)) |
| 72 | + Y = sqrt(sum(abs2, y)) |
| 73 | + minXY = min(X, Y) |
| 74 | + return 1 / 252 * minXY^7 + 1 / 720 * minXY^4 * euclidean(x, y) * |
| 75 | + ( 5 * max(X, Y)^2 + 2 * X * Y + 3 * minXY^2 ) |
| 76 | +end |
| 77 | + |
| 78 | +Base.show(io::IO, ::WienerKernel{I}) where I = print(io, I, "-times integrated Wiener kernel") |
0 commit comments