Example with two nodes and two datasets
We now move to a more complex GPRN. If you understood the previous examples, this one is only an extension of them. Again, let us start by importing the necessary packages
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
from gpyrn import meanfield, covfunc, meanfunc
We will again use two sine waves in our example.
time = np.linspace(0, 100, 25)
y1 = 20*np.sin(2*np.pi*time / 31)
y1err = np.random.rand(25)
y2 = 25*np.sin(2*np.pi*time / 31 + 0.5*np.pi)
y2err = np.random.rand(25)
As before, we can plot the data
plt.figure()
plt.errorbar(time, y1, y1err, fmt='ob', markersize=7, label='y1')
plt.errorbar(time, y2, y2err, fmt='or', markersize=7, label='y2')
plt.xlabel('Time (days)')
plt.ylabel('Measurements')
plt.legend(loc='upper right', facecolor='white', framealpha=1, edgecolor='black')
plt.grid(which='major', alpha=0.5)
And we get something like this

Now we create our GPRN and all the components it needs
gprn = meanfield.inference(2, time, y1, y1err, y2, y2err)
nodes = [covfunc.Periodic(1, 31, 0.5), covfunc.Matern52(1, 100)]
weight = [covfunc.SquaredExponential(20, 5), covfunc.SquaredExponential(0.1, 10),
covfunc.SquaredExponential(10, 5), covfunc.SquaredExponential(1, 10)]
means = [meanfunc.Constant(0), meanfunc.Constant(0)]
jitter = [0.5, 0.5]
As the number of nodes and datasets increases, things become trickier. Still, the package is easy to use once you understand how the pieces fit together.
First, the first argument of meanfield.inference is now 2 because we want to use two nodes. With three nodes we would pass 3, and so on.
Since we want two nodes, we need to define both of them. In this case the first node is a GP with a periodic kernel and the second node is a GP with a Matérn 5/2 kernel.
The weights might be a bit confusing to someone not used to this package. Since we have two nodes and two datasets, we need four weights to connect everything. For everything to be connected correctly, the order is node-dataset.
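To make the two node kernels more concrete, here is a minimal NumPy sketch of the standard periodic and Matérn 5/2 covariance formulas. The function names and the keyword arguments (amplitude, period, lengthscale) are assumptions made for this illustration and are not gpyrn's API. In particular, the sketch says nothing about the argument order expected by covfunc.Periodic or covfunc.Matern52, so check the package source for that.

```python
import numpy as np

def periodic_kernel(t1, t2, amplitude, period, lengthscale):
    """Standard periodic (exp-sine-squared) covariance between two 1D time arrays."""
    r = np.abs(t1[:, None] - t2[None, :])  # pairwise time differences
    return amplitude**2 * np.exp(-2.0 * np.sin(np.pi * r / period)**2 / lengthscale**2)

def matern52_kernel(t1, t2, amplitude, lengthscale):
    """Standard Matern 5/2 covariance between two 1D time arrays."""
    r = np.abs(t1[:, None] - t2[None, :])
    s = np.sqrt(5.0) * r / lengthscale
    return amplitude**2 * (1.0 + s + s**2 / 3.0) * np.exp(-s)

# Hypothetical hyperparameter values, chosen only for illustration
time = np.linspace(0, 100, 25)
K_periodic = periodic_kernel(time, time, amplitude=1.0, period=31.0, lengthscale=0.5)
K_matern = matern52_kernel(time, time, amplitude=1.0, lengthscale=100.0)
print(K_periodic.shape, K_matern.shape)
```

The periodic kernel repeats exactly every period, which suits the sine waves used here, while the Matérn 5/2 kernel with a long length scale describes a slowly varying, non-periodic component.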
What do I mean by that? In the list weight, the first entry (covfunc.SquaredExponential(20, 5)) connects the first node to the first dataset. The second weight (covfunc.SquaredExponential(0.1, 10)) connects the first node to the second dataset. The third weight (covfunc.SquaredExponential(10, 5)) connects the second node to the first dataset. The fourth and last weight (covfunc.SquaredExponential(1, 10)) connects the second node to the second dataset. Think of it this way: you first connect one node to all the datasets, then you connect the next node to all the datasets, and so on.
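The node-dataset ordering described above can be sketched as a small index calculation. The helper weight_index is hypothetical (it is not part of gpyrn) and uses 0-based node and dataset numbers.

```python
def weight_index(node, dataset, n_datasets):
    """Position in the flat weight list of the weight linking `node` to `dataset`.

    Both `node` and `dataset` are 0-based. The list is ordered node by node:
    all datasets for the first node, then all datasets for the second node, etc.
    """
    return node * n_datasets + dataset

n_nodes, n_datasets = 2, 2
for node in range(n_nodes):
    for dataset in range(n_datasets):
        idx = weight_index(node, dataset, n_datasets)
        print(f"weight[{idx}] connects node {node + 1} to dataset {dataset + 1}")

# weight[0] connects node 1 to dataset 1
# weight[1] connects node 1 to dataset 2
# weight[2] connects node 2 to dataset 1
# weight[3] connects node 2 to dataset 2
```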
If it helps, visually it would be something like this

Of course, we also need to define two mean functions and two jitter terms, one of each per dataset.
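A quick way to catch a mismatched list before running anything is to check the lengths against the number of nodes and datasets. The function check_component_counts below is a hypothetical helper written for this page, not something gpyrn provides, and the strings stand in for the real kernel and mean objects.

```python
def check_component_counts(n_nodes, n_datasets, nodes, weight, means, jitter):
    """Raise ValueError if any list has the wrong length for this network size."""
    expected = {
        "nodes": n_nodes,                # one kernel per node
        "weight": n_nodes * n_datasets,  # one weight per node-dataset pair
        "means": n_datasets,             # one mean function per dataset
        "jitter": n_datasets,            # one jitter term per dataset
    }
    actual = {"nodes": len(nodes), "weight": len(weight),
              "means": len(means), "jitter": len(jitter)}
    for name, want in expected.items():
        if actual[name] != want:
            raise ValueError(f"{name}: expected {want} entries, got {actual[name]}")
    return True

# Placeholder strings instead of covfunc/meanfunc objects, for illustration only
nodes = ["periodic", "matern52"]
weight = ["w11", "w12", "w21", "w22"]
means = ["constant", "constant"]
jitter = [0.5, 0.5]
print(check_component_counts(2, 2, nodes, weight, means, jitter))
```

With three nodes and two datasets, the same check would require three node kernels, six weights, two mean functions and two jitter terms.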
Now let us plot everything and see what we get. First we calculate the predictive mean, separating each GPRN term (separate=True) so that all the GPRN components can be shown in one big plot.
tstar = np.linspace(time.min(), time.max(), 1000)
# m and v are not defined on this page; they are assumed to be available
# already from the inference step, as in the previous examples
a, _, _, b = gprn.Prediction(nodes, weight, means, jitter, tstar, m, v, separate=True)
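To see how the separated terms relate to the predictive mean, here is a sketch of the usual GPRN combination, in which each dataset is the sum over nodes of weight times node. It is written under assumptions and does not reproduce gpyrn's internals: the function combine_terms and the array shapes are hypothetical, the weights follow the node-dataset order described earlier, and the mean functions are treated as simple constant offsets.

```python
import numpy as np

def combine_terms(node_means, weight_means, n_datasets, offsets=None):
    """Sum weight * node over the nodes for every dataset.

    node_means:   array of shape (n_nodes, n_points)
    weight_means: array of shape (n_nodes * n_datasets, n_points), node-dataset order
    offsets:      optional constant mean per dataset, shape (n_datasets,)
    Returns an array of shape (n_points, n_datasets).
    """
    node_means = np.asarray(node_means, dtype=float)
    weight_means = np.asarray(weight_means, dtype=float)
    n_nodes, n_points = node_means.shape
    prediction = np.zeros((n_points, n_datasets))
    for q in range(n_nodes):
        for d in range(n_datasets):
            prediction[:, d] += weight_means[q * n_datasets + d] * node_means[q]
    if offsets is not None:
        prediction += np.asarray(offsets, dtype=float)
    return prediction

# Synthetic stand-ins for the separated terms (not output from gpyrn)
t = np.linspace(0, 100, 200)
node_means = np.vstack([np.sin(2 * np.pi * t / 31), np.ones_like(t)])
weight_means = np.vstack([20 * np.ones_like(t),   # node 1 -> dataset 1
                          np.zeros_like(t),       # node 1 -> dataset 2
                          np.zeros_like(t),       # node 2 -> dataset 1
                          3 * np.ones_like(t)])   # node 2 -> dataset 2
prediction = combine_terms(node_means, weight_means, n_datasets=2)
print(prediction.shape)
```

In this toy case the first dataset is just 20 times the first node and the second dataset is a constant 3, which makes it easy to see which weight feeds which output.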
Then it is just a matter of plotting everything together
fig = plt.figure(constrained_layout=True, figsize=(7, 10))
axs = fig.subplot_mosaic([['predictive 1', 'node 1'],
                          ['predictive 1', 'node 2'],
                          ['predictive 1', 'weight 1'],
                          ['predictive 2', 'weight 2'],
                          ['predictive 2', 'weight 3'],
                          ['predictive 2', 'weight 4']])
axs['predictive 1'].set(xlabel='', ylabel='y1')
axs['predictive 1'].errorbar(time, y1, y1err, fmt='.k')
axs['predictive 1'].plot(tstar, a[:,0].T, '-r')
axs['predictive 1'].xaxis.set_minor_locator(AutoMinorLocator(5))
axs['predictive 1'].yaxis.set_minor_locator(AutoMinorLocator(5))
axs['predictive 1'].grid(which='major', alpha=0.5)
axs['predictive 1'].grid(which='minor', alpha=0.2)
axs['predictive 2'].set(xlabel='', ylabel='y2')
axs['predictive 2'].errorbar(time, y2, y2err, fmt='.k')
axs['predictive 2'].plot(tstar, a[:,1].T, '-r')
axs['predictive 2'].xaxis.set_minor_locator(AutoMinorLocator(5))
axs['predictive 2'].yaxis.set_minor_locator(AutoMinorLocator(5))
axs['predictive 2'].grid(which='major', alpha=0.5)
axs['predictive 2'].grid(which='minor', alpha=0.2)
axs['node 1'].set(xlabel='', ylabel='1st Node')
axs['node 1'].plot(tstar, b[0][0].T, '-b')
axs['node 2'].set(xlabel='', ylabel='2nd Node')
axs['node 2'].plot(tstar, b[0][1].T, '-b')
axs['weight 1'].set(xlabel='', ylabel='1st weight')
axs['weight 1'].plot(tstar, b[1][0].T, '-b')
axs['weight 2'].set(xlabel='', ylabel='2nd weight')
axs['weight 2'].plot(tstar, b[1][1].T, '-b')
axs['weight 3'].set(xlabel='', ylabel='3rd weight')
axs['weight 3'].plot(tstar, b[1][2].T, '-b')
axs['weight 4'].set(xlabel='', ylabel='4th weight')
axs['weight 4'].plot(tstar, b[1][3].T, '-b')
In the end we get something like this

And now congratulations, you probably know how to use gpyrn!
2018-2021 João Camacho