
Commit f50cd2c

Author: William J. Wilkinson
Merge pull request #15 from AaltoML/develop
2 parents af3f625 + f583909, commit f50cd2c

File tree: 4 files changed, +43 −6 lines

README.md

Lines changed: 34 additions & 2 deletions
@@ -1,18 +1,50 @@
 # Bayes-Newton
 
-Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in [JAX](https://github.com/google/jax) (with [objax](https://github.com/google/objax)), built and actively maintained by [Will Wilkinson](https://wil-j-wil.github.io/).
+Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in [JAX](https://github.com/google/jax) (with [objax](https://github.com/google/objax)), built and maintained by [Will Wilkinson](https://wil-j-wil.github.io/).
 
 Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the methods implemented scroll down to the bottom of this page.
 
 The methodology is outlined in the following article:
 * W.J. Wilkinson, S. Särkkä, and A. Solin (2021): **Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees**. [*arXiv preprint arXiv:2111.01721*](https://arxiv.org/abs/2111.01721).
 
 ## Installation
+
+Install the latest stable release from PyPI:
 ```bash
 pip install bayesnewton
 ```
 
-## Example
+For *development*, you may want to use the latest source from GitHub: in a checkout of the develop branch of the Bayes-Newton GitHub repository, run
+```bash
+pip install -e .
+```
+
+### Step-by-step: Getting started with the examples
+
+To run the demos or experiments in this repository, or to build on top of it, first create and activate a virtual environment:
+```bash
+python3 -m venv venv
+source venv/bin/activate
+```
+
+Install all required dependencies for the examples:
+```bash
+python -m pip install -r requirements.txt
+python -m pip install -e .
+```
+
+Running the tests additionally requires a specific version of GPflow to test against:
+```bash
+python -m pip install pytest
+python -m pip install tensorflow==2.10 tensorflow-probability==0.18.0 gpflow==2.5.2
+```
+
+Run the tests:
+```bash
+cd tests; pytest
+```
+
+## Simple Example
 Given some inputs `x` and some data `y`, you can construct a Bayes-Newton model as follows,
 ```python
 kern = bayesnewton.kernels.Matern52()
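As context for the truncated README example above, here is a minimal, self-contained NumPy sketch of the computation such a model ultimately performs: GP regression with a Matérn-5/2 kernel under a Gaussian likelihood. This is an illustrative stand-in, not the Bayes-Newton API — all names and values here are made up for the sketch; only the Matérn-5/2 formula and the standard GP posterior-mean expression are taken as given.

```python
import numpy as np

def matern52(x1, x2, variance=1.0, lengthscale=1.0):
    # Standard Matern-5/2 kernel: sigma^2 (1 + s + s^2/3) exp(-s), s = sqrt(5) r / l.
    r = np.abs(x1[:, None] - x2[None, :]) / lengthscale
    s = np.sqrt(5.0) * r
    return variance * (1.0 + s + s**2 / 3.0) * np.exp(-s)

rng = np.random.default_rng(0)
x = np.linspace(0.0, 5.0, 20)            # training inputs
y = np.sin(x) + 0.1 * rng.standard_normal(20)  # noisy observations
x_test = np.linspace(0.0, 5.0, 50)       # test inputs

noise_var = 0.1**2
K = matern52(x, x) + noise_var * np.eye(20)    # train covariance + noise
K_star = matern52(x_test, x)                   # test-train cross-covariance
mean = K_star @ np.linalg.solve(K, y)          # GP posterior mean at x_test
```

Bayes-Newton replaces this exact, dense computation with the approximate (and, for Markov GPs, linear-time) inference schemes listed at the bottom of the README.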

bayesnewton/basemodels.py

Lines changed: 1 addition & 1 deletion
@@ -813,7 +813,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
 
         # if np.squeeze(test_var).ndim > 2:  # deal with spatio-temporal case (discard spatial covariance)
         if self.spatio_temporal:  # deal with spatio-temporal case (discard spatial covariance)
-            test_var = diag(np.squeeze(test_var))
+            test_var = diag(test_var)
         return np.squeeze(test_mean), np.squeeze(test_var)
 
     def filter_energy(self):
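The change above stops squeezing `test_var` before taking its diagonal in the spatio-temporal case. A hedged NumPy illustration of the underlying issue (this is not the library's `diag` helper, just a sketch of the operation it needs to support): `np.diag` is only defined for 1-D/2-D input, so discarding the spatial covariance of a batched array of per-time-step covariance matrices requires a batch-aware diagonal over the last two axes.

```python
import numpy as np

# Stack of per-time-step spatial covariance matrices, shape (time, space, space).
test_var = np.stack([np.eye(3) * (t + 1.0) for t in range(4)])

# np.diag(test_var) would fail on this 3-D array; a batch-aware diagonal
# keeps one marginal variance per spatial point per time step instead:
marginal_var = np.diagonal(test_var, axis1=-2, axis2=-1)  # shape (4, 3)
```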

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 jax==0.2.9
 jaxlib==0.1.60
 objax==1.3.1
+numba
 numpy
 matplotlib
 scipy

tests/test_vs_gpflow_shutters.py

Lines changed: 7 additions & 3 deletions
@@ -168,9 +168,13 @@ def test_gradient_step(var_f, len_f, var_y):
     loss_fn = gpflow_model.training_loss_closure(data)
     adam_vars = gpflow_model.trainable_variables
     adam_opt.minimize(loss_fn, adam_vars)
-    gpflow_hypers = np.array([gpflow_model.kernel.lengthscales.numpy()[0],
-                              gpflow_model.kernel.lengthscales.numpy()[1],
-                              gpflow_model.kernel.variance.numpy(),
+    # gpflow_hypers = np.array([gpflow_model.kernel.lengthscales.numpy()[0],
+    #                           gpflow_model.kernel.lengthscales.numpy()[1],
+    #                           gpflow_model.kernel.variance.numpy(),
+    #                           gpflow_model.likelihood.variance.numpy()])
+    gpflow_hypers = np.array([gpflow_model.kernel.parameters[0].numpy(),
+                              gpflow_model.kernel.parameters[2].numpy(),
+                              gpflow_model.kernel.parameters[1].numpy(),
                               gpflow_model.likelihood.variance.numpy()])
     print(gpflow_hypers)
     print(gpflow_grads)
