|
import numpy as np
import theano.tensor as tt

import pymc3 as pm
from pymc3.gp.gp import Base
from pymc3.gp.cov import Covariance
from pymc3.gp.mean import Constant
from pymc3.gp.util import (
    conditioned_vars,
    infer_shape,
    stabilize,
    cholesky,
    solve,
    solve_lower,
    solve_upper,
)
| 8 | + |
| 9 | + |
# Names exported by ``from <module> import *``; only Grid2DLatent is listed.
__all__ = ["Grid2DLatent"]
| 11 | + |
def vec(A):
    """Return ``A`` flattened into a single column.

    ``A`` is transposed, flattened, and given a trailing axis of length one,
    so the result has one column and as many rows as ``A`` has elements.
    """
    flattened = tt.flatten(tt.transpose(A))
    return flattened[:, None]
| 14 | + |
| 15 | + |
def devec(x, r, c):
    """Reshape ``x`` into an ``r`` x ``c`` matrix, filling it column by column.

    ``x`` is first reshaped to ``(c, r)`` and then transposed, so consecutive
    runs of ``r`` entries of ``x`` become the columns of the result.
    """
    columns_as_rows = x.reshape((c, r))
    return columns_as_rows.T
| 19 | + |
| 20 | + |
def grid_to_full(X1, X2, n1, n2):
    """Expand two grid axes into the matrix of all grid points.

    Parameters
    ----------
    X1, X2 : array-like or tensor
        Coordinates along the first and second grid axis, one point per row.
    n1, n2 : int or None
        Number of rows in ``X1`` and ``X2``. When ``None`` the size is
        inferred from the corresponding input via ``infer_shape``.

    Returns
    -------
    tensor
        Every row of ``X1`` paired with every row of ``X2``. Each row of
        ``X1`` is repeated ``n2`` times and ``X2`` is tiled ``n1`` times, so
        the ``X2`` coordinate varies fastest. The columns of ``X1`` come
        first, followed by the columns of ``X2``.
    """
    # Honour the sizes supplied by the caller and only fall back to
    # inference when they are missing (same pattern as the
    # ``infer_shape(X[0], n[0])`` calls elsewhere in this file).
    n1 = infer_shape(X1, n1)
    n2 = infer_shape(X2, n2)
    X1 = tt.as_tensor_variable(X1)
    X2 = tt.as_tensor_variable(X2)
    xx1 = tt.repeat(X1, n2, 0)
    xx2 = tt.tile(X2, (n1, 1))
    return tt.concatenate((xx1, xx2), 1)
| 29 | + |
| 30 | + |
def kronprod(A1, A2, v, veced=True):
    """Multiply by the Kronecker product ``A1 (x) A2`` without forming it.

    Uses the identity ``(A1 (x) A2) vec(V) = vec(A2 V A1^T)`` for a
    column-stacking ``vec``.

    Parameters
    ----------
    A1, A2 : tensor
        Square factors of the Kronecker product.
    v : tensor
        ``vec(V)`` when ``veced`` is true, otherwise the matrix ``V`` itself
        with ``A2.shape[0]`` rows and ``A1.shape[0]`` columns.
    veced : bool
        Whether ``v`` must be reshaped into matrix form first.

    Returns
    -------
    tensor
        The matrix ``A2 V A1^T``. It is returned in matrix form, not
        re-vectorised.
    """
    if veced:
        v = devec(v, A2.shape[0], A1.shape[0])
    # A1 has to be transposed: the factors passed in are Cholesky factors,
    # which are not symmetric, so ``V A1`` and ``V A1^T`` differ.
    return tt.dot(A2, tt.dot(v, tt.transpose(A1)))
| 37 | + |
| 38 | + |
def kronsolve(A1, A2, v, veced=True, chol=False):
    """Solve ``(A1 (x) A2) x = vec(V)`` without forming the Kronecker product.

    Uses ``(A1 (x) A2)^{-1} vec(V) = vec(A2^{-1} V A1^{-T})`` for a
    column-stacking ``vec``.

    Parameters
    ----------
    A1, A2 : tensor
        Square factors of the Kronecker product.
    v : tensor
        ``vec(V)`` when ``veced`` is true, otherwise the matrix ``V`` itself
        with ``A2.shape[0]`` rows and ``A1.shape[0]`` columns.
    veced : bool
        Whether ``v`` must be reshaped into matrix form first.
    chol : bool
        If true, ``A1`` and ``A2`` are assumed to be Cholesky factors
        (lower triangular), and a triangular solve is used for ``A1``.

    Returns
    -------
    tensor
        The solution as a single column (the result of ``vec``).
    """
    if veced:
        v = devec(v, A2.shape[0], A1.shape[0])
    # V A1^{-T} is obtained as (A1^{-1} V^T)^T, i.e. by solving against A1
    # itself rather than against its transpose.
    solve_first = solve_lower if chol else solve
    right_solved = tt.transpose(solve_first(A1, tt.transpose(v)))
    return vec(solve(A2, right_solved))
| 47 | + |
| 48 | + |
def make_gridcov_func(self, cov_func1, cov_func2):
    """Build a product covariance function over full (non-gridded) inputs.

    Parameters
    ----------
    self : object
        Unused; kept so the existing signature does not change.
    cov_func1, cov_func2 : callable
        Covariance functions called as ``cov(X)`` or ``cov(X, Xnew)``.
        ``cov_func1.input_dim`` gives the number of leading columns that
        belong to the first factor; the remaining columns go to
        ``cov_func2``.

    Returns
    -------
    callable
        ``gridcov_func(X, Xnew=None)`` returning the elementwise product of
        ``cov_func1`` on the first ``input_dim`` columns and ``cov_func2``
        on the remaining columns.
    """
    ndim1 = cov_func1.input_dim

    def gridcov_func(X, Xnew=None):
        # Each factor must see the same column range of X and of Xnew.
        if Xnew is None:
            return cov_func1(X[:, :ndim1]) * cov_func2(X[:, ndim1:])
        return (cov_func1(X[:, :ndim1], Xnew[:, :ndim1])
                * cov_func2(X[:, ndim1:], Xnew[:, ndim1:]))

    return gridcov_func
| 57 | + |
| 58 | + |
class Grid(Base):
    """Base class holding one covariance function per grid axis and a mean.

    Example shapes noted by the original author: K1 is 5x5, K2 is 3x3 and
    Y is 3x5.
    """

    def __init__(self, cov_funcs, mean_func):
        # ``cov_funcs`` must unpack into exactly two covariance functions.
        first_cov, second_cov = cov_funcs
        self.cov_func1 = first_cov
        self.cov_func2 = second_cov
        self.mean_func = mean_func
| 66 | + |
@conditioned_vars(["X", "n", "f"])
class Grid2DLatent(Grid):
    """Zero-mean latent GP on a 2-D grid with covariance ``K1 (x) K2``.

    ``X`` and ``n`` are pairs with one entry per grid axis. The latent
    values ``f`` built by ``prior`` form a matrix with ``n[0]`` rows and
    ``n[1]`` columns.
    """

    def _build_prior(self, name, X, n, reparameterize=True):
        """Create the latent variable from whitened standard normals.

        Only the reparameterised (non-centred) form is implemented.
        """
        if not reparameterize:
            raise NotImplementedError
        L1 = cholesky(stabilize(self.cov_func1(X[0])))
        L2 = cholesky(stabilize(self.cov_func2(X[1])))
        v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n[0] * n[1])
        # kronprod returns an n[1] x n[0] matrix; devec turns it into the
        # n[0] x n[1] layout stored as ``f``.
        return pm.Deterministic(
            name, devec(kronprod(L1, L2, v[:, None]), n[0], n[1])
        )

    def prior(self, name, X, n=None, reparameterize=True):
        """Add the GP prior over the grid to the model and return it.

        Parameters
        ----------
        name : str
            Name of the latent variable.
        X : sequence of length 2
            Inputs along each grid axis.
        n : sequence of length 2, optional
            Number of points along each axis; inferred from ``X`` when an
            entry is ``None``.
        reparameterize : bool
            Must be true; the centred form raises ``NotImplementedError``.

        Raises
        ------
        ValueError
            If ``X`` or ``n`` does not have exactly two entries.
        """
        if n is None:
            n = (None, None)
        if len(X) != 2 or len(n) != 2:
            raise ValueError("2d implemented")
        n = (infer_shape(X[0], n[0]), infer_shape(X[1], n[1]))
        f = self._build_prior(name, X, n, reparameterize)
        self.X = (tt.as_tensor_variable(X[0]), tt.as_tensor_variable(X[1]))
        self.n = n
        self.f = f
        return f

    def _get_given_vals(self, **given):
        """Resolve the values to condition on.

        Returns ``(X, n, f, cov_total, mean_total)`` where ``cov_total`` is
        the pair of per-axis covariance functions. They are taken from
        ``given['gp']`` when present, otherwise from this object.
        """
        # Grid objects carry cov_func1/cov_func2 (never a single cov_func),
        # so the total covariance is represented as that pair.
        # NOTE(review): assumes given['gp'] is another Grid-style object.
        gp = given.get("gp", self)
        cov_total = (gp.cov_func1, gp.cov_func2)
        mean_total = gp.mean_func
        if all(val in given for val in ["X", "n", "f"]):
            X, n, f = given["X"], given["n"], given["f"]
            # Both X and n are indexed per axis below, so each must be a
            # tuple.
            if not isinstance(X, tuple) or not isinstance(n, tuple):
                raise ValueError("must provide tuple with each element for dimension")
        else:
            X, n, f = self.X, self.n, self.f
        return X, n, f, cov_total, mean_total

    def conditional(self, name, Xnew, n_points=None, given=None):
        """Add the conditional distribution at new inputs to the model.

        Parameters
        ----------
        name : str
            Name of the new random variable.
        Xnew : tuple of length 2, or a full input matrix
            A tuple is read as one set of inputs per grid axis (predictions
            on a new grid). Anything else is read as a matrix of full input
            points, with the first axis's columns first.
        n_points : int or tuple of length 2, optional
            Number of new points (per axis when ``Xnew`` is a tuple);
            inferred when omitted.
        given : dict, optional
            May contain ``X``, ``n``, ``f`` and/or ``gp`` to condition on
            values other than the ones stored by ``prior``.
        """
        # ``given`` defaults to None, which cannot be **-unpacked.
        givens = self._get_given_vals(**(given or {}))
        mu, cov, n_points = self._build_conditional(Xnew, n_points, *givens)
        chol = cholesky(stabilize(cov))
        return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)

    def _build_conditional(self, Xnew, n_points, X, n, f, cov_total, mean_total):
        """Return ``(mu, cov, n_points)`` of the conditional at ``Xnew``.

        The prior built by this class has zero mean, so ``mean_total`` is
        currently not used.
        """
        cov_total1, cov_total2 = cov_total
        L1 = cholesky(stabilize(cov_total1(X[0])))
        L2 = cholesky(stabilize(cov_total2(X[1])))
        if isinstance(Xnew, tuple):
            # New inputs given per axis: the covariances keep their
            # Kronecker structure.
            if n_points is None:
                n_points = (None, None)
            m1 = infer_shape(Xnew[0], n_points[0])
            m2 = infer_shape(Xnew[1], n_points[1])
            Kxs = tt.slinalg.kron(self.cov_func1(X[0], Xnew[0]),
                                  self.cov_func2(X[1], Xnew[1]))
            Kss = tt.slinalg.kron(self.cov_func1(Xnew[0]),
                                  self.cov_func2(Xnew[1]))
            n_points = m1 * m2
        else:
            # New inputs given as full points: expand the training grid
            # and use the product kernel. The helper's first parameter is
            # unused, but its signature requires it.
            gridcov_func = make_gridcov_func(self, self.cov_func1, self.cov_func2)
            X_full = grid_to_full(X[0], X[1], n[0], n[1])
            Kxs = gridcov_func(X_full, Xnew)
            Kss = gridcov_func(Xnew)
            n_points = infer_shape(Xnew, n_points)
        # The Kronecker product of two lower-triangular factors is itself
        # lower triangular and is the Cholesky factor of K1 (x) K2.
        # TODO: replace the dense factor with Kronecker-structured solves;
        # the dense form is used here because it handles the rectangular
        # cross-covariance directly.
        L = tt.slinalg.kron(L1, L2)
        A = solve_lower(L, Kxs)
        # Row-major flattening of the n[0] x n[1] matrix ``f`` matches the
        # ordering of K1 (x) K2 (second axis varies fastest).
        # NOTE(review): assumes a user-supplied ``f`` uses the same layout
        # as the one produced by ``prior``.
        v = solve_lower(L, tt.flatten(f)[:, None])
        mu = tt.flatten(tt.dot(tt.transpose(A), v))
        cov = Kss - tt.dot(tt.transpose(A), A)
        return mu, cov, n_points
| 139 | + |
| 140 | + |
@conditioned_vars(["X", "n", "f"])
class Grid2DMarginal(Grid):
    """Zero-mean marginal GP on a 2-D grid with covariance ``K1 (x) K2``.

    Observations ``Y`` are laid out with ``n[1]`` rows and ``n[0]``
    columns (second axis along the rows).
    """

    def _build_marginal_likelihood_logp(self, X, Y, sigma, n):
        """Log density of ``vec(Y)`` under ``N(0, K1 (x) K2 + sigma^2 I)``.

        Parameters
        ----------
        X : sequence of length 2
            Inputs along each grid axis.
        Y : array-like or tensor
            Observed values with ``n[1]`` rows and ``n[0]`` columns.
        sigma : scalar or None
            Noise standard deviation. ``None`` means noise free, in which
            case Cholesky factors are used; otherwise eigendecompositions
            are used so that the noise can be added to the eigenvalues.
        n : sequence of length 2
            Number of points along each axis.
        """
        Y = tt.as_tensor_variable(Y)
        K1 = stabilize(self.cov_func1(X[0]))
        K2 = stabilize(self.cov_func2(X[1]))
        const = n[0] * n[1] * tt.log(2.0 * np.pi)
        if sigma is None:
            L1 = cholesky(K1)
            L2 = cholesky(K2)
            # log|K1 (x) K2| = n[1] * log|K1| + n[0] * log|K2|
            logdet = (n[1] * 2.0 * tt.sum(tt.log(tt.diag(L1)))
                      + n[0] * 2.0 * tt.sum(tt.log(tt.diag(L2))))
            # (L1 (x) L2)^{-1} vec(Y) in matrix form: L2^{-1} Y L1^{-T}
            whitened = solve_lower(
                L2, tt.transpose(solve_lower(L1, tt.transpose(Y)))
            )
            quad = tt.sum(tt.square(whitened))
        else:
            S1, Q1 = tt.nlinalg.eigh(K1)
            S2, Q2 = tt.nlinalg.eigh(K2)
            # Eigenvalues of K1 (x) K2 + sigma^2 I, arranged to match the
            # layout of Q2^T Y Q1 (rows follow S2, columns follow S1).
            W = tt.outer(S2, S1) + tt.square(sigma)
            # Q1 and Q2 are orthogonal, so their inverses are transposes.
            rotated = tt.dot(tt.transpose(Q2), tt.dot(Y, Q1))
            logdet = tt.sum(tt.log(W))
            quad = tt.sum(tt.square(rotated) / W)
        return -0.5 * (logdet + quad + const)

    def marginal_likelihood(self, name, X, sigma=None, n=None, Y=None):
        """Add the marginal likelihood of the observations to the model.

        Parameters
        ----------
        name : str
            Name of the observed variable.
        X : sequence of length 2
            Inputs along each grid axis.
        sigma : scalar, optional
            Noise standard deviation; ``None`` means noise free.
        n : sequence of length 2, optional
            Number of points along each axis; inferred from ``X`` when an
            entry is ``None``.
        Y : array-like
            Observed values with ``n[1]`` rows and ``n[0]`` columns.

        Raises
        ------
        ValueError
            If ``X`` or ``n`` does not have exactly two entries, or if
            ``Y`` is not provided.
        """
        if n is None:
            n = (None, None)
        if len(X) != 2 or len(n) != 2:
            raise ValueError("2d implemented")
        if Y is None:
            raise ValueError("Y, the observed values on the grid, must be provided")
        n = (infer_shape(X[0], n[0]), infer_shape(X[1], n[1]))
        self.X = (tt.as_tensor_variable(X[0]), tt.as_tensor_variable(X[1]))
        self.n = n

        def logp(value):
            return self._build_marginal_likelihood_logp(X, value, sigma, n)

        return pm.DensityDist(name, logp, observed=Y)
| 178 | + |
| 179 | + |
| 180 | + |
| 181 | + |
| 182 | + |
| 183 | + |
| 184 | + |
| 185 | + |
| 186 | + |
| 187 | + |
| 188 | + |
| 189 | + |
0 commit comments