Skip to content

Commit 6e63890

Browse files
authored
Merge pull request #9 from climate-machine/gpr
Implement 'mmstd', 'plot_fit'; add tests
2 parents b841f08 + 8906d52 commit 6e63890

File tree

9 files changed

+160
-9
lines changed

9 files changed

+160
-9
lines changed

examples/GPR/main.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,21 @@ GPR.learn!(gprw) # fit GPR with "Const * RBF + White" kernel
3232
#GPR.learn!(gprw, kernel = "matern") # "rbf" and "matern" are supported for now
3333
#GPR.learn!(gprw, kernel = "matern", nu = 1) # Matern's parameter nu; 1.5 by def
3434

35-
mesh = minimum(gpr_data, dims=1)[1] : 0.01 : maximum(gpr_data, dims=1)[1]
35+
mesh, mean, std = GPR.mmstd(gprw) # equispaced mesh; mean and std at mesh points
36+
#mesh, mean, std = GPR.mmstd(gprw, mesh_n = 11) # by default, `mesh_n` is 1001
3637

37-
mean, std = GPR.predict(gprw, mesh, return_std = true)
38+
#mean, std = GPR.predict(gprw, mesh, return_std = true) # your own mesh
3839
#mean = GPR.predict(gprw, mesh) # `return_std` is false by default
3940

4041
################################################################################
4142
# plot section #################################################################
4243
################################################################################
43-
plt.plot(gpr_data[:,1], gpr_data[:,2], "r.", ms = 6, label = "Data points")
44-
plt.plot(mesh, mean, "k", lw = 2.5, label = "GPR mean")
45-
plt.fill_between(mesh, mean - 2*std, mean + 2*std, alpha = 0.4, zorder = 10,
46-
color = "k", label = "95% interval")
44+
#GPR.plot_fit(gprw, plt, plot_95 = true) # by default, `plot_95` is false
45+
46+
# no legend by default, but you can specify yours in the following order:
47+
# data points, subsample, mean, 95% interval
48+
GPR.plot_fit(gprw, plt, plot_95 = true,
49+
label = ["Points", "Training", "Mean", "95% interval"])
4750

4851
plt.legend()
4952
plt.show()

src/GPR.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Functions that operate on GPR.Wrap struct:
2323
- subsample! (3 methods)
2424
- learn! (1 method)
2525
- predict (1 method)
26+
- mmstd (1 method)
27+
- plot_fit (1 method)
2628
2729
Do *not* set Wrap's variables except for `thrsh`; use setter functions!
2830
"""
@@ -184,6 +186,144 @@ function predict(gprw::Wrap, x; return_std = false)
184186
end
185187
end
186188

189+
"""
190+
Return mesh, mean and st. deviation over the whole data range
191+
192+
Computes min and max of `gprw.data` x-range and returns equispaced mesh with
193+
`mesh_n` number of points, mean and st. deviation computed over that mesh
194+
195+
Parameters:
196+
- gprw: an instance of GPR.Wrap
197+
- mesh_n: number of mesh points (1001 by default)
198+
199+
Returns:
200+
- (mesh, mean, std): mesh, mean and st. deviation, in that order
201+
"""
202+
function mmstd(gprw::Wrap; mesh_n = 1001)
    # Column-wise extremes of the data; the first entry of each is the bound
    # used for the mesh.
    lower = first(minimum(gprw.data, dims = 1))
    upper = first(maximum(gprw.data, dims = 1))

    # Equispaced grid of `mesh_n` points covering [lower, upper].
    grid = range(lower, upper, length = mesh_n)

    # Evaluate the fitted model on the grid, requesting the st. deviation too,
    # and prepend the grid to whatever `predict` returns.
    prediction = predict(gprw, grid, return_std = true)
    return (grid, prediction...)
end
208+
209+
"""
210+
Plot mean (and 95% interval) along with data and subsample
211+
212+
The flag `plot_95` controls whether to plot 95% interval; `label` may provide
213+
labels in the following order:
214+
data points, subsample, GP mean, 95% interval (if requested)
215+
216+
If you're using Plots.jl and running it from a file rather than REPL, you need
217+
to wrap the call:
218+
`display(GPR.plot_fit(gprw, Plots))`
219+
220+
Parameters:
221+
- gprw: an instance of GPR.Wrap
222+
- plt: a module used for plotting (only PyPlot & Plots supported)
223+
- plot_95: boolean flag, whether to plot 95% confidence interval
224+
- label: a tuple or vector of 3 or 4 strings, the 4th being the 95% interval label (no labels by default)
225+
"""
226+
function plot_fit(gprw::Wrap, plt; plot_95 = false, label = nothing)
    # Nothing can be drawn without data: warn and return `nothing`.
    if !gprw.__data_set
        println(warn("plot_fit"), "data is not set, nothing to plot")
        return
    end

    # The plotting backend is recognised by the name of the module passed in.
    backend = Symbol(plt)
    is_pyplot = (backend == :PyPlot)
    is_plots = (backend == :Plots)
    if !is_pyplot && !is_plots
        println(warn("plot_fit"), "only PyPlot & Plots are supported; not plotting")
        return
    end

    alpha_95 = 0.3        # alpha channel of the shaded region, i.e. 95% interval
    mean_color = "black"  # also used for the outline of the shaded region
    marker_size = is_pyplot ? 4 : 2

    # Keyword arguments for every plotted series. The data and subsample colors
    # are tab10 gray and tab10 red, spelled the way each backend expects.
    kwargs_data = Dict{Symbol, Any}(
        :color => (is_pyplot ? "tab:gray" : "#7f7f7f"),
        :ms => marker_size,
    )
    kwargs_sub = Dict{Symbol, Any}(
        :color => (is_pyplot ? "tab:red" : "#d62728"),
        :ms => marker_size,
    )
    kwargs_mean = Dict{Symbol, Any}(:color => mean_color, :lw => 2.5)
    kwargs_aux = Dict{Symbol, Any}()  # only used by the Plots backend
    if is_pyplot
        kwargs_95 = Dict{Symbol, Any}(
            :facecolor => (0, 0, 0, alpha_95),
            :edgecolor => mean_color,
            :lw => 0.5,
            :zorder => 10,
        )
    else
        # With Plots the interval is a ribbon attached to the mean line, so this
        # series carries the mean's line width; the thin band edges are drawn
        # separately with `kwargs_aux`.
        kwargs_95 = Dict{Symbol, Any}(
            :color => mean_color,
            :fillalpha => alpha_95,
            :lw => 2.5,
            :z => 10,
        )
        kwargs_aux[:color] = mean_color
        kwargs_aux[:lw] = 0.5
        kwargs_aux[:label] = ""
    end

    if label !== nothing
        kwargs_data[:label] = label[1]
        kwargs_sub[:label] = label[2]
        kwargs_mean[:label] = label[3]
        if is_plots
            # The ribbon series also draws the mean line, hence the mean label.
            kwargs_95[:label] = label[3]
        elseif length(label) >= 4
            # A 3-element `label` is allowed; only read the interval label when
            # it was actually supplied instead of failing with a BoundsError.
            kwargs_95[:label] = label[4]
        end
    elseif is_plots
        # Empty labels are set explicitly for Plots when none were requested.
        for kwargs in (kwargs_data, kwargs_sub, kwargs_mean, kwargs_95)
            kwargs[:label] = ""
        end
    end

    mesh, mu, sigma = mmstd(gprw)
    half_width = 1.96 * sigma  # half-width of the 95% interval around the mean

    # Plot data, subsample, (optionally) the 95% interval and the mean. The
    # value of the last plotting call is returned, as in the documented
    # `display(GPR.plot_fit(gprw, Plots))` usage.
    if is_pyplot
        plt.plot(gprw.data[:, 1], gprw.data[:, 2], "."; kwargs_data...)
        plt.plot(gprw.subsample[:, 1], gprw.subsample[:, 2], "."; kwargs_sub...)
        if plot_95
            plt.fill_between(mesh, mu - half_width, mu + half_width; kwargs_95...)
        end
        return plt.plot(mesh, mu; kwargs_mean...)
    end

    plt.scatter!(gprw.data[:, 1], gprw.data[:, 2]; kwargs_data...)
    plt.scatter!(gprw.subsample[:, 1], gprw.subsample[:, 2]; kwargs_sub...)
    if plot_95
        plt.plot!(mesh, mu, ribbon = (half_width, half_width); kwargs_95...)
        # thin lines along the lower and upper edges of the band
        plt.plot!(mesh, mu - half_width; kwargs_aux...)
        return plt.plot!(mesh, mu + half_width; kwargs_aux...)
    end
    return plt.plot!(mesh, mu; kwargs_mean...)
end
326+
187327
################################################################################
188328
# convenience functions ########################################################
189329
################################################################################

test/GPR/data/matern_05_mesh.npy

7.94 KB
Binary file not shown.

test/GPR/data/matern_def_mean.npy

8 Bytes
Binary file not shown.

test/GPR/data/matern_def_std.npy

8 Bytes
Binary file not shown.

test/GPR/data/mesh.npy

-40 Bytes
Binary file not shown.

test/GPR/data/rbf_mean.npy

8 Bytes
Binary file not shown.

test/GPR/data/rbf_std.npy

8 Bytes
Binary file not shown.

test/GPR/runtests.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const rbf_mean = NPZ.npzread(joinpath(data_dir, "rbf_mean.npy"))
1111
const rbf_std = NPZ.npzread(joinpath(data_dir, "rbf_std.npy"))
1212
const matern_def_mean = NPZ.npzread(joinpath(data_dir, "matern_def_mean.npy"))
1313
const matern_def_std = NPZ.npzread(joinpath(data_dir, "matern_def_std.npy"))
14+
const matern_05_mesh = NPZ.npzread(joinpath(data_dir, "matern_05_mesh.npy"))
1415
const matern_05_mean = NPZ.npzread(joinpath(data_dir, "matern_05_mean.npy"))
1516
const matern_05_std = NPZ.npzread(joinpath(data_dir, "matern_05_std.npy"))
1617

@@ -88,9 +89,15 @@ thrsh = gprw.thrsh
8889

8990
ytrue = xmesh.^2
9091
ypred = GPR.predict(gprw, xmesh)
91-
@test isapprox(ytrue, ypred, atol=1e-5, norm=inf_norm)
92+
@test isapprox(ytrue, ypred, atol=1e-4, norm=inf_norm)
93+
94+
mesh_, mean_, std_ = GPR.mmstd(gprw, mesh_n=101)
95+
@test isapprox(mesh_, xmesh, atol=1e-8, norm=inf_norm)
96+
@test isapprox(ypred, mean_, atol=1e-8, norm=inf_norm)
9297

9398
mean, std = GPR.predict(gprw, xmesh, return_std = true)
99+
@test isapprox(mean, mean_, atol=1e-8, norm=inf_norm)
100+
@test isapprox(std, std_, atol=1e-8, norm=inf_norm)
94101
@test ndims(mean) == 1
95102
@test ndims(std) == 1
96103
@test size(std,1) == size(mean,1)
@@ -107,9 +114,10 @@ GPR.set_data!(gprw, gpr_data)
107114
gprw.thrsh = -1
108115
@testset "non-synthetic testing" begin
109116
GPR.learn!(gprw)
110-
mean, std = GPR.predict(gprw, gpr_mesh, return_std = true)
117+
mesh, mean, std = GPR.mmstd(gprw)
111118
@test size(gprw.data,1) == 800
112119
@test size(gprw.subsample,1) == 800
120+
@test isapprox(mesh, gpr_mesh, atol=1e-8, norm=inf_norm)
113121
@test isapprox(mean, rbf_mean, atol=1e-3, norm=inf_norm)
114122
@test isapprox(std, rbf_std, atol=1e-3, norm=inf_norm)
115123

@@ -119,7 +127,7 @@ gprw.thrsh = -1
119127
@test isapprox(std, matern_def_std, atol=1e-3, norm=inf_norm)
120128

121129
GPR.learn!(gprw, kernel = "matern", nu = 0.5)
122-
mean, std = GPR.predict(gprw, gpr_mesh, return_std = true)
130+
mean, std = GPR.predict(gprw, matern_05_mesh, return_std = true)
123131
@test isapprox(mean, matern_05_mean, atol=1e-3, norm=inf_norm)
124132
@test isapprox(std, matern_05_std, atol=1e-3, norm=inf_norm)
125133
end

0 commit comments

Comments
 (0)