@@ -23,6 +23,8 @@ Functions that operate on GPR.Wrap struct:
23
23
- subsample! (3 methods)
24
24
- learn! (1 method)
25
25
- predict (1 method)
26
+ - mmstd (1 method)
27
+ - plot_fit (1 method)
26
28
27
29
Do *not* set Wrap's variables except for `thrsh`; use setter functions!
28
30
"""
@@ -184,6 +186,144 @@ function predict(gprw::Wrap, x; return_std = false)
184
186
end
185
187
end
186
188
189
+ """
190
+ Return mesh, mean and st. deviation over the whole data range
191
+
192
+ Computes min and max of `gprw.data` x-range and returns equispaced mesh with
193
+ `mesh_n` number of points, mean and st. deviation computed over that mesh
194
+
195
+ Parameters:
196
+ - gprw: an instance of GPR.Wrap
197
+ - mesh_n: number of mesh points (1001 by default)
198
+
199
+ Returns:
200
+ - (m, m, std): mesh, mean and st. deviation
201
+ """
202
+ function mmstd (gprw:: Wrap ; mesh_n = 1001 )
203
+ mesh = range (minimum (gprw. data, dims= 1 )[1 ],
204
+ maximum (gprw. data, dims= 1 )[1 ],
205
+ length = mesh_n)
206
+ return (mesh, predict (gprw, mesh, return_std = true )... )
207
+ end
208
+
209
+ """
210
+ Plot mean (and 95% interval) along with data and subsample
211
+
212
+ The flag `plot_95` controls whether to plot 95% interval; `label` may provide
213
+ labels in the following order:
214
+ data points, subsample, GP mean, 95% interval (if requested)
215
+
216
+ If you're using Plots.jl and running it from a file rather than REPL, you need
217
+ to wrap the call:
218
+ `display(GPR.plot_fit(gprw, Plots))`
219
+
220
+ Parameters:
221
+ - gprw: an instance of GPR.Wrap
222
+ - plt: a module used for plotting (only PyPlot & Plots supported)
223
+ - plot_95: boolean flag, whether to plot 95% confidence interval
224
+ - label: a 3- or 4-tuple or vector of strings (no label by default)
225
+ """
226
+ function plot_fit (gprw:: Wrap , plt; plot_95 = false , label = nothing )
227
+ if ! gprw. __data_set
228
+ println (warn (" plot_fit" ), " data is not set, nothing to plot" )
229
+ return
230
+ end
231
+ is_pyplot = (Symbol (plt) == :PyPlot )
232
+ is_plots = (Symbol (plt) == :Plots )
233
+ if ! is_pyplot && ! is_plots
234
+ println (warn (" plot_fit" ), " only PyPlot & Plots are supported; not plotting" )
235
+ return
236
+ end
237
+
238
+ # set `cols` Dict with colors of the plots
239
+ alpha_95 = 0.3 # alpha channel for shaded region, i.e. 95% interval
240
+ cols = Dict {String, Any} ()
241
+ cols[" mean" ] = " black"
242
+ if is_pyplot
243
+ cols[" data" ] = " tab:gray"
244
+ cols[" sub" ] = " tab:red"
245
+ cols[" shade" ] = (0 , 0 , 0 , alpha_95)
246
+ elseif is_plots
247
+ cols[" data" ] = " #7f7f7f" # tab10 gray
248
+ cols[" sub" ] = " #d62728" # tab10 red
249
+ end
250
+
251
+ # set keyword argument dictionaries for plotting functions
252
+ kwargs_data = Dict {Symbol, Any} ()
253
+ kwargs_sub = Dict {Symbol, Any} ()
254
+ kwargs_mean = Dict {Symbol, Any} ()
255
+ kwargs_95 = Dict {Symbol, Any} ()
256
+ kwargs_aux = Dict {Symbol, Any} ()
257
+
258
+ kwargs_data[:color ] = cols[" data" ]
259
+ kwargs_sub[:color ] = cols[" sub" ]
260
+ kwargs_mean[:color ] = cols[" mean" ]
261
+ kwargs_mean[:lw ] = 2.5
262
+ if is_pyplot
263
+ kwargs_data[:ms ] = 4
264
+ kwargs_sub[:ms ] = 4
265
+ kwargs_95[:facecolor ] = cols[" shade" ]
266
+ kwargs_95[:edgecolor ] = cols[" mean" ]
267
+ kwargs_95[:lw ] = 0.5
268
+ kwargs_95[:zorder ] = 10
269
+ elseif is_plots
270
+ kwargs_data[:ms ] = 2
271
+ kwargs_sub[:ms ] = 2
272
+ kwargs_95[:color ] = cols[" mean" ]
273
+ kwargs_95[:fillalpha ] = alpha_95
274
+ kwargs_95[:lw ] = 2.5
275
+ kwargs_95[:z ] = 10
276
+ kwargs_aux[:color ] = cols[" mean" ]
277
+ kwargs_aux[:lw ] = 0.5
278
+ kwargs_aux[:label ] = " "
279
+ end
280
+
281
+ if label != nothing
282
+ kwargs_data[:label ] = label[1 ]
283
+ kwargs_sub[:label ] = label[2 ]
284
+ kwargs_mean[:label ] = label[3 ]
285
+ if is_pyplot
286
+ kwargs_95[:label ] = label[4 ]
287
+ elseif is_plots
288
+ kwargs_95[:label ] = label[3 ]
289
+ end
290
+ elseif is_plots
291
+ kwargs_data[:label ] = " "
292
+ kwargs_sub[:label ] = " "
293
+ kwargs_mean[:label ] = " "
294
+ kwargs_95[:label ] = " "
295
+ end
296
+
297
+
298
+ mesh, mean, std = mmstd (gprw)
299
+
300
+ # plot data, subsample and mean
301
+ if is_pyplot
302
+ plt. plot (gprw. data[:,1 ], gprw. data[:,2 ], " ." ; kwargs_data... )
303
+ plt. plot (gprw. subsample[:,1 ], gprw. subsample[:,2 ], " ." ; kwargs_sub... )
304
+ if plot_95
305
+ plt. fill_between (mesh,
306
+ mean - 1.96 * std,
307
+ mean + 1.96 * std;
308
+ kwargs_95... )
309
+ end
310
+ plt. plot (mesh, mean; kwargs_mean... )
311
+ elseif is_plots
312
+ plt. scatter! (gprw. data[:,1 ], gprw. data[:,2 ]; kwargs_data... )
313
+ plt. scatter! (gprw. subsample[:,1 ], gprw. subsample[:,2 ]; kwargs_sub... )
314
+ if plot_95
315
+ plt. plot! (mesh,
316
+ mean,
317
+ ribbon = (1.96 * std, 1.96 * std);
318
+ kwargs_95... )
319
+ plt. plot! (mesh, mean - 1.96 * std; kwargs_aux... )
320
+ plt. plot! (mesh, mean + 1.96 * std; kwargs_aux... )
321
+ else
322
+ plt. plot! (mesh, mean; kwargs_mean... )
323
+ end
324
+ end
325
+ end
326
+
187
327
# ###############################################################################
188
328
# convenience functions ########################################################
189
329
# ###############################################################################
0 commit comments