Skip to content

Commit 123db3e

Browse files
dburov190simonbyrne
authored andcommitted
GPRWrap (#7)
* Implement 'GPRTools.jl' * Rename vars & funcs; implement 'subsample!' methods * Add dependencies to 'Project.toml' * Rename GPRTools -> GPRWrap * Prettify output; change default kernel parameters * Reset in 'set_data!'; implement 'predict' * Add tests for 'GPRWrap' * force use of Conda Python * Reduce tolerance in GPRWrap tests * Tweak 'learn!' and 'predict' parameters * Add more GPRWrap tests * Add GPRWrap examples * Rename 'GPRWrap.jl' -> 'GPR.jl' * Move/add docs outside functions/structs * Add warning about including GPR.jl
1 parent b096536 commit 123db3e

15 files changed

+392
-0
lines changed

.travis.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Documentation: http://docs.travis-ci.com/user/languages/julia/
22
language: julia
3+
4+
env:
5+
global:
6+
- PYTHON=Conda
7+
38
os:
49
- linux
510
- osx

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ version = "0.1.0"
66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9+
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
12+
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
1013
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
14+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1115

1216
[extras]
1317
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"

examples/GPR/main.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/julia --

# Example: fit a Gaussian Process regressor to the test data set and plot the
# predicted mean together with a +/- 2 standard deviation band.

import NPZ
import PyPlot
const plt = PyPlot

# All paths are anchored at this file's directory so the example can be launched
# from any working directory.
include(joinpath(@__DIR__, "..", "..", "src", "GPR.jl"))

const gpr_data = NPZ.npzread(
  joinpath(@__DIR__, "..", "..", "test", "GPR", "data", "data_points.npy")
)

################################################################################
# main section #################################################################
################################################################################
# To avoid bloated and annoying warnings from ScikitLearn.jl, run this with
#
# julia --depwarn=no main.jl
#

# create an instance of GPR.Wrap with threshold set to -1
# by convention, negative threshold means use all data, i.e. do not subsample
gprw = GPR.Wrap(thrsh = -1)
#gprw = GPR.Wrap()

GPR.set_data!(gprw, gpr_data)

GPR.subsample!(gprw) # this form is unnecessary, `learn!` will do it anyway
#GPR.subsample!(gprw, 500) # ignore `gprw.thrsh` and subsample 500 points
#GPR.subsample!(gprw, indices = 1:10) # subsample points with indices `indices`

GPR.learn!(gprw) # fit GPR with "Const * RBF + White" kernel
#GPR.learn!(gprw, noise = 1e-8) # `noise` is *non-optimized* additional noise
#GPR.learn!(gprw, kernel = "matern") # "rbf" and "matern" are supported for now
#GPR.learn!(gprw, kernel = "matern", nu = 1) # Matern's parameter nu; 1.5 by def

# prediction mesh spanning the range of the x values (first column of the data)
x_min, x_max = extrema(@view gpr_data[:, 1])
mesh = x_min : 0.01 : x_max

# named `gp_mean`/`gp_std` so that `Statistics.mean`/`std` are never shadowed
gp_mean, gp_std = GPR.predict(gprw, mesh, return_std = true)
#gp_mean = GPR.predict(gprw, mesh) # `return_std` is false by default

################################################################################
# plot section #################################################################
################################################################################
plt.plot(gpr_data[:,1], gpr_data[:,2], "r.", ms = 6, label = "Data points")
plt.plot(mesh, gp_mean, "k", lw = 2.5, label = "GPR mean")
plt.fill_between(mesh, gp_mean - 2*gp_std, gp_mean + 2*gp_std, alpha = 0.4,
                 zorder = 10, color = "k", label = "95% interval")

plt.legend()
plt.show()
50+
51+

src/GPR.jl

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
module GPR
2+
"""
3+
For the time being, please use `include("src/GPR.jl")` and not `using Solus.GPR`
4+
since there are precompile issues with the backend (scikit-learn)
5+
"""
6+
7+
using Parameters # lets you have defaults for fields
8+
9+
using EllipsisNotation # adds '..' to refer to the rest of array
10+
import ScikitLearn
11+
import StatsBase
12+
const sklearn = ScikitLearn
13+
14+
sklearn.@sk_import gaussian_process : GaussianProcessRegressor
15+
sklearn.@sk_import gaussian_process.kernels : (RBF, Matern, WhiteKernel)
16+
17+
"""
    Wrap

Mutable container holding the state needed for Gaussian Process Regression.

Functions operating on a `GPR.Wrap` instance:
  - `set_data!`
  - `subsample!`
  - `learn!`
  - `predict`

Fields:
  - `thrsh`: default number of points `subsample!(gprw)` draws (500 by default)
  - `data`: the data installed by `set_data!`
  - `subsample`: the view into `data` selected by `subsample!`
  - `GPR`: the regressor object created by `learn!`
  - `__data_set`, `__subsample_set`: bookkeeping flags maintained by the setters

Only `thrsh` is meant to be assigned directly; change everything else through
the setter functions so that the bookkeeping flags stay consistent.
"""
@with_kw mutable struct Wrap
  thrsh::Int = 500
  data::Any = nothing          # `nothing` until `set_data!` is called
  subsample::Any = nothing     # `nothing` until `subsample!` is called
  GPR::Any = nothing           # `nothing` until `learn!` is called
  __data_set::Bool = false
  __subsample_set::Bool = false
end
37+
38+
################################################################################
39+
# GPRWrap-related functions ####################################################
40+
################################################################################
"""
    set_data!(gprw::Wrap, data::AbstractArray{<:Real})

Store `data` in `gprw.data` and reset `gprw.subsample` and `gprw.GPR`, so that
nothing derived from previously set data can be used by accident.

Parameters:
  - gprw: an instance of GPR.Wrap
  - data: input data to learn from (at least 2-dimensional)

`data` should be in the following format:
  last column: values/labels/y values
  first column(s): locations/x values

If `data` has more than two dimensions, a warning is printed and only the
slice `data[:, :, 1, 1, ...]` is kept.

Throws an `ErrorException` if `data` has fewer than two dimensions.
"""
function set_data!(gprw::Wrap, data::AbstractArray{<:Real})
  if ndims(data) < 2
    # `error` throws on its own; wrapping it in `throw` is redundant
    error("set_data!: ndims(data) < 2; cannot proceed")
  elseif ndims(data) > 2
    println(warn("set_data!"), "ndims(data) > 2; will use the first two dims")
    # keep the first two dimensions and pin every trailing index at 1
    trailing = fill(1, ndims(data) - 2)
    data = data[:, :, trailing...]
  end

  gprw.data = data
  gprw.__data_set = true

  # anything computed from the old data is stale now
  gprw.subsample = nothing
  gprw.__subsample_set = false
  gprw.GPR = nothing

  println(name("set_data!"), size(gprw.data, 1), " points")
  flush(stdout)
end
68+
"""
    subsample!(gprw::Wrap; indices = nothing)

Select the rows of `gprw.data` that will be used for learning.

Parameters:
  - gprw: an instance of GPR.Wrap
  - indices: indices (along the first dimension) used to subsample `gprw.data`;
             when omitted, `subsample!(gprw, gprw.thrsh)` is called instead

Both call forms, `subsample!(gprw)` and `subsample!(gprw, indices = ...)`, are
served by this single method on purpose: keyword arguments do not take part in
dispatch, so two separate definitions would share one signature and the later
one would overwrite the earlier one.

`gprw.subsample` is set to a *view* into `gprw.data`, not a copy.

Throws an `ErrorException` if `indices` is given while `gprw.data` is unset.
"""
function subsample!(gprw::Wrap;
                    indices::Union{Nothing, AbstractVector{<:Integer}} = nothing)
  if indices === nothing
    return subsample!(gprw, gprw.thrsh)
  end
  if gprw.data === nothing
    error("subsample!: 'data' is not set, cannot sample")
  end

  gprw.subsample = @view(gprw.data[indices, ..])
  gprw.__subsample_set = true
  println(name("subsample!"), size(gprw.subsample, 1), " subsampled")
  flush(stdout)
end

"""
    subsample!(gprw::Wrap, thrsh::Integer)

Draw at most `thrsh` subsamples from `gprw.data`

Parameters:
  - gprw: an instance of GPR.Wrap
  - thrsh: threshold for the maximum number of points used in subsampling

If `thrsh` > 0 and `thrsh` < number of `gprw.data` points:
  subsample `thrsh` points uniformly randomly (without replacement)
If `thrsh` > 0 and `thrsh` >= number of `gprw.data` points:
  no subsampling, use whole `gprw.data`
If `thrsh` < 0:
  no subsampling, use whole `gprw.data`

This function ignores `gprw.thrsh`

Throws an `ErrorException` if the data was not set with `set_data!` or if
`thrsh` is zero.
"""
function subsample!(gprw::Wrap, thrsh::Integer)
  gprw.__data_set || error("subsample!: 'data' is not set, cannot sample")
  thrsh == 0 && error("subsample!: 'thrsh' == 0, cannot sample")

  N = size(gprw.data, 1)
  # random draw only when a positive threshold is below the number of points
  inds = if 0 < thrsh < N
    StatsBase.sample(1:N, thrsh, replace = false)
  else
    1:N
  end

  subsample!(gprw, indices = inds)
end
127+
"""
    learn!(gprw::Wrap; kernel = "rbf", noise = 0.5, nu = 1.5)

Fit a GP regressor to the current subsample of `gprw.data` and store it in
`gprw.GPR`. If no subsample has been chosen yet, `subsample!(gprw)` is called
first (after printing a warning).

Parameters:
  - gprw: an instance of GPR.Wrap
  - kernel: "rbf" or "matern"; "rbf" by default. Any other value prints a
            warning and falls back to RBF
  - noise: passed as `alpha` to the regressor; a fixed noise level that is
           not optimized (in addition to the optimized white kernel)
  - nu: Matern's nu parameter (smoothness of functions)

The last column of the subsample is used as the target values; all preceding
columns are used as the input locations.
"""
function learn!(gprw::Wrap; kernel::String = "rbf", noise = 0.5, nu = 1.5)
  gprw.__subsample_set || begin
    println(warn("learn!"), "'subsample' is not set; attempting to set...")
    subsample!(gprw)
  end

  white = WhiteKernel(1, (1e-10, 10))

  use_matern = kernel == "matern"
  if !use_matern && kernel != "rbf"
    println(warn("learn!"), "Kernel '", kernel, "' is not supported; ",
            "falling back to RBF")
  end
  base = use_matern ? Matern(length_scale = 1.0, nu = nu) :
                      RBF(1.0, (1e-10, 1e+6))
  GPR_kernel = 1.0 * base + white

  # the regressor is stored before fitting, exactly as callers may expect
  gprw.GPR = GaussianProcessRegressor(
    kernel = GPR_kernel,
    n_restarts_optimizer = 7,
    alpha = noise
  )
  locations = gprw.subsample[:, 1:end-1]
  values = gprw.subsample[:, end]
  sklearn.fit!(gprw.GPR, locations, values)

  println(name("learn!"), gprw.GPR.kernel_)
  flush(stdout)
end
165+
"""
    predict(gprw::Wrap, x; return_std = false)

Return mean (and st. deviation) values of the fitted GP regressor

Parameters:
  - gprw: an instance of GPR.Wrap
  - x: data for prediction; a 1-dimensional `x` is reshaped to a single column
  - return_std: boolean flag, whether to return st. deviation

Returns:
  - mean: mean of the GP regressor at `x` locations
  - (mean, std): mean and st. deviation if `return_std` flag is true

Throws an `ErrorException` if no regressor has been fitted yet, i.e. `learn!`
has not been called since the last `set_data!`.
"""
function predict(gprw::Wrap, x; return_std = false)
  if gprw.GPR === nothing
    error("predict: GP regressor is not fitted; call `learn!` first")
  end

  # add an extra dimension to `x` if it's a vector (scikit-learn's whim)
  locations = ndims(x) == 1 ? reshape(x, (size(x)..., 1)) : x
  return gprw.GPR.predict(locations, return_std = return_std)
end
186+
187+
################################################################################
188+
# convenience functions ########################################################
189+
################################################################################
190+
# width (in characters) that message prefixes are right-padded to
const RPAD = 25

# Prefix for regular messages: "<name>:" right-padded with spaces to `RPAD`.
name(name::AbstractString) = rpad(string(name, ":"), RPAD)

# Prefix for warnings: "WARNING (<name>):" right-padded with spaces to `RPAD`.
warn(name::AbstractString) = rpad(string("WARNING (", name, "):"), RPAD)
199+
200+
end # module
201+
202+

src/Solus.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ using Distributions, Statistics, LinearAlgebra, DocStringExtensions
55
include("spaces.jl")
66
include("problems.jl")
77
include("neki.jl")
8+
include("GPR.jl")
89

910
end # module

test/GPR/data/data_points.npy

12.6 KB
Binary file not shown.

test/GPR/data/matern_05_mean.npy

7.89 KB
Binary file not shown.

test/GPR/data/matern_05_std.npy

7.89 KB
Binary file not shown.

test/GPR/data/matern_def_mean.npy

7.89 KB
Binary file not shown.

test/GPR/data/matern_def_std.npy

7.89 KB
Binary file not shown.

0 commit comments

Comments
 (0)