|
| 1 | +module GPR |
| 2 | +""" |
| 3 | +For the time being, please use `include("src/GPR.jl")` and not `using Solus.GPR` |
| 4 | +since there are precompile issues with the backend (scikit-learn) |
| 5 | +""" |
| 6 | + |
| 7 | +using Parameters # lets you have defaults for fields |
| 8 | + |
| 9 | +using EllipsisNotation # adds '..' to refer to the rest of array |
| 10 | +import ScikitLearn |
| 11 | +import StatsBase |
| 12 | +const sklearn = ScikitLearn |
| 13 | + |
| 14 | +sklearn.@sk_import gaussian_process : GaussianProcessRegressor |
| 15 | +sklearn.@sk_import gaussian_process.kernels : (RBF, Matern, WhiteKernel) |
| 16 | + |
| 17 | + |
"""
    Wrap

Mutable container holding everything needed for Gaussian Process Regression:
the data, the rows selected for training and the regressor itself.

Fields:
  - thrsh:     threshold used by `subsample!(gprw)` when none is given (500)
  - data:      `nothing` until `set_data!` stores the data
  - subsample: `nothing` until `subsample!` selects rows of `data`
  - GPR:       `nothing` until `learn!` creates the regressor
  - __data_set, __subsample_set: bookkeeping flags maintained by the setters

Functions that operate on a GPR.Wrap:
  - set_data!
  - subsample!
  - learn!
  - predict

`thrsh` is the only field meant to be assigned directly; everything else
should be changed through the setter functions.
"""
@with_kw mutable struct Wrap
    thrsh::Int = 500                # default subsampling threshold
    data = nothing                  # assigned by `set_data!`
    subsample = nothing             # assigned by `subsample!`
    GPR = nothing                   # assigned by `learn!`
    __data_set::Bool = false        # flipped to true by `set_data!`
    __subsample_set::Bool = false   # flipped to true by `subsample!`
end
| 37 | + |
| 38 | +################################################################################ |
| 39 | +# GPR.Wrap-related functions ################################################### |
| 40 | +################################################################################ |
"""
    set_data!(gprw::Wrap, data::Array{<:Real})

Store `data` in `gprw.data` and invalidate whatever was derived from the
previous data: `gprw.subsample` and `gprw.GPR` are reset to `nothing` and the
subsample flag is cleared.

Parameters:
  - gprw: an instance of GPR.Wrap
  - data: input data to learn from (at least 2-dimensional)

`data` should be in the following format:
  last column:     values/labels/y values
  first column(s): locations/x values

If `data` has more than two dimensions a warning is printed and only the slice
at index 1 of every trailing dimension is kept. An `ErrorException` is raised
if `data` has fewer than two dimensions.
"""
function set_data!(gprw::Wrap, data::Array{<:Real})
    rank = ndims(data)
    # `error` raises on its own; no surrounding `throw` is needed
    rank < 2 && error("set_data!: ndims(data) < 2; cannot proceed")

    if rank > 2
        println(warn("set_data!"), "ndims(data) > 2; will use the first two dims")
        # index 1 along each trailing dimension leaves a plain 2-d array
        trailing = ntuple(_ -> 1, rank - 2)
        data = data[:, :, trailing...]
    end

    gprw.data = data
    gprw.__data_set = true

    # anything computed from the old data is stale now
    gprw.subsample = nothing
    gprw.__subsample_set = false
    gprw.GPR = nothing

    println(name("set_data!"), size(data, 1), " points")
    flush(stdout)
end
| 68 | + |
"""
    subsample!(gprw::Wrap; indices = nothing)

Select the rows of `gprw.data` that are stored (as a view) in `gprw.subsample`.

Parameters:
  - gprw:    an instance of GPR.Wrap
  - indices: indices that will be used to subsample `gprw.data`; when omitted
             (or `nothing`) the call is forwarded to
             `subsample!(gprw, gprw.thrsh)`

Raises an `ErrorException` if `gprw.data` has not been set.

Keyword arguments take no part in dispatch, so `subsample!(gprw; indices)` and
`subsample!(gprw)` are one and the same method signature. They are therefore
implemented together here; defining them separately would make the later
definition overwrite the earlier one.
"""
function subsample!(gprw::Wrap;
        indices::Union{Nothing, AbstractVector{<:Integer}} = nothing)
    # no explicit indices: fall back to threshold-based random subsampling
    if indices === nothing
        return subsample!(gprw, gprw.thrsh)
    end
    # fail with a readable message instead of a MethodError on `nothing`
    if gprw.data === nothing
        error("subsample!: 'data' is not set, cannot sample")
    end

    # `..` (EllipsisNotation) stands for all remaining dimensions
    gprw.subsample = @view(gprw.data[indices, ..])
    gprw.__subsample_set = true
    println(name("subsample!"), size(gprw.subsample, 1), " subsampled")
    flush(stdout)
end

"""
    subsample!(gprw::Wrap, thrsh::Int)

Draw at most `thrsh` rows from `gprw.data`.

Parameters:
  - gprw:  an instance of GPR.Wrap
  - thrsh: threshold for the maximum number of points used in subsampling

If `thrsh` > 0 and `thrsh` < number of `gprw.data` points:
  subsample `thrsh` points uniformly randomly (without replacement)
If `thrsh` > 0 and `thrsh` >= number of `gprw.data` points:
  no subsampling, use whole `gprw.data`
If `thrsh` < 0:
  no subsampling, use whole `gprw.data`
If `thrsh` == 0, or if the data has not been set:
  an `ErrorException` is raised

This function ignores `gprw.thrsh`
"""
function subsample!(gprw::Wrap, thrsh::Int)
    if !gprw.__data_set
        error("subsample!: 'data' is not set, cannot sample")
    end
    if thrsh == 0
        error("subsample!: 'thrsh' == 0, cannot sample")
    end

    N = size(gprw.data, 1)
    # a negative threshold means "no limit"; so does one that covers all rows
    if thrsh < 0 || thrsh >= N
        inds = 1:N
    else
        inds = StatsBase.sample(1:N, thrsh, replace = false)
    end

    subsample!(gprw, indices = inds)
end
| 127 | + |
"""
    learn!(gprw::Wrap; kernel = "rbf", noise = 0.5, nu = 1.5,
           n_restarts_optimizer = 7)

Fit a GP regressor to `gprw.subsample` and store it in `gprw.GPR`.
If no subsample has been selected yet, a warning is printed and
`subsample!(gprw)` is called first.

Parameters:
  - gprw:   an instance of GPR.Wrap
  - kernel: "rbf" or "matern"; "rbf" by default. Any other value prints a
            warning and falls back to RBF
  - noise:  non-optimized noise level, passed to the regressor as `alpha`
            for either kernel (in addition to the optimized WhiteKernel one)
  - nu:     Matern's nu parameter (smoothness of functions); unused for RBF
  - n_restarts_optimizer: value forwarded to the regressor's
            `n_restarts_optimizer` argument; 7 by default

The last column of `gprw.subsample` is used as the target values, all other
columns as the input locations. The fitted kernel is printed afterwards.
"""
function learn!(gprw::Wrap; kernel::String = "rbf", noise = 0.5, nu = 1.5,
        n_restarts_optimizer::Int = 7)
    if !gprw.__subsample_set
        println(warn("learn!"), "'subsample' is not set; attempting to set...")
        subsample!(gprw)
    end

    # noise term whose level is optimized during fitting
    WK = WhiteKernel(1, (1e-10, 10))
    if kernel == "matern"
        GPR_kernel = 1.0 * Matern(length_scale = 1.0, nu = nu) + WK
    else # including "rbf", which is the default
        if kernel != "rbf"
            println(warn("learn!"), "Kernel '", kernel, "' is not supported; ",
                    "falling back to RBF")
        end
        GPR_kernel = 1.0 * RBF(1.0, (1e-10, 1e+6)) + WK
    end

    gprw.GPR = GaussianProcessRegressor(
        kernel = GPR_kernel,
        n_restarts_optimizer = n_restarts_optimizer,
        alpha = noise
    )
    # inputs: every column but the last; targets: the last column
    sklearn.fit!(gprw.GPR, gprw.subsample[:, 1:end-1], gprw.subsample[:, end])

    println(name("learn!"), gprw.GPR.kernel_)
    flush(stdout)
end
| 165 | + |
"""
    predict(gprw::Wrap, x; return_std = false)

Return mean (and st. deviation) values

Parameters:
  - gprw:       an instance of GPR.Wrap
  - x:          data for prediction
  - return_std: boolean flag, whether to return st. deviation

Returns:
  - mean:        mean of the GP regressor at `x` locations
  - (mean, std): mean and st. deviation if `return_std` flag is true

Raises an `ErrorException` if no regressor has been fitted yet, i.e. `learn!`
has not been called since the data was last set.
"""
function predict(gprw::Wrap, x; return_std = false)
    # fail early with a readable message instead of
    # "type Nothing has no field predict"
    if gprw.GPR === nothing
        error("predict: GP regressor is not trained; call `learn!` first")
    end

    # add an extra dimension to `x` if it's a vector (scikit-learn's whim)
    X = ndims(x) == 1 ? reshape(x, (size(x)..., 1)) : x
    return gprw.GPR.predict(X, return_std = return_std)
end
| 186 | + |
| 187 | +################################################################################ |
| 188 | +# convenience functions ######################################################## |
| 189 | +################################################################################ |
# width (in characters) to which the message prefixes below are right-padded
const RPAD = 25

"""
    name(name::AbstractString)

Return `name` followed by a colon, right-padded with spaces to `RPAD`
characters. Strings already longer than that are returned unpadded.
"""
name(name::AbstractString) = rpad(string(name, ":"), RPAD)
| 195 | + |
"""
    warn(name::AbstractString)

Return the prefix `WARNING (<name>):`, right-padded with spaces to `RPAD`
characters.
"""
warn(name::AbstractString) = rpad("WARNING ($(name)):", RPAD)
| 199 | + |
| 200 | +end # module |
| 201 | + |
| 202 | + |
0 commit comments