Skip to content

Commit 1e25a68

Browse files
authored
Demo: Gaussian process regression (#632)
1 parent ecfbb83 commit 1e25a68

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

examples/kernelregression.dx

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import linalg
12
import plot
23

34
-- Conjugate gradients solver
4-
def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
5+
def solve' (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
56
x0 = for i:m. 0.0
67
ax = mat **. x0
78
r0 = b - ax
@@ -16,6 +17,11 @@ def solve (mat:m=>m=>Float) (b:m=>Float) : m=>Float =
1617
(x', r', p')
1718
xOut
1819

20+
def chol_solve (l:LowerTriMat m Float) (b:m=>Float) : m=>Float =
21+
b' = forward_substitute l b
22+
u = transposeLowerToUpper l
23+
backward_substitute u b'
24+
1925
' # Kernel ridge regression
2026

2127
' To learn a function $f_{true}: \mathcal{X} \to \mathbb R$
@@ -40,7 +46,7 @@ ys : Nx=>Float = for i. trueFun xs.i + noise * randn (ixkey k2 i)
4046
-- Kernel ridge regression
4147
def regress (kernel: a -> a -> Float) (xs: Nx=>a) (ys: Nx=>Float) : a -> Float =
4248
gram = for i j. kernel xs.i xs.j + select (i==j) 0.0001 0.0
43-
alpha = solve gram ys
49+
alpha = solve' gram ys
4450
predict = \x. sum for i. alpha.i * kernel xs.i x
4551
predict
4652

@@ -59,3 +65,35 @@ preds = map predict xtest
5965

6066
:html showPlot $ xyPlot xtest preds
6167
> <html output>
68+
69+
' # Gaussian process regression
70+
71+
' GP regression (kriging) works in a similar way. Compared with kernel ridge regression, GP regression assumes a Gaussian-distributed prior. This, combined
72+
with Bayes' rule, gives the variance of the prediction.
73+
74+
' In this implementation, the conjugate gradient solver is replaced with the
75+
Cholesky solver from `lib/linalg.dx` for efficiency.
76+
77+
def gp_regress (kernel: a -> a -> Float) (xs: n=>a) (ys: n=>Float)
78+
: (a -> (Float&Float)) =
79+
noise_var = 0.0001
80+
gram = for i j. kernel xs.i xs.j
81+
c = chol (gram + eye *. noise_var)
82+
alpha = chol_solve c ys
83+
predict = \x.
84+
k' = for i. kernel xs.i x
85+
mu = sum for i. alpha.i * k'.i
86+
alpha' = chol_solve c k'
87+
var = kernel x x + noise_var - sum for i. k'.i * alpha'.i
88+
(mu, var)
89+
predict
90+
91+
gp_predict = gp_regress (rbf 0.2) xs ys
92+
93+
(gp_preds, vars) = unzip (map gp_predict xtest)
94+
95+
:html showPlot $ xycPlot xtest gp_preds (map sqrt vars)
96+
> <html output>
97+
98+
:html showPlot $ xyPlot xtest vars
99+
> <html output>

0 commit comments

Comments
 (0)