diff --git a/src/discovery/likelihood.py b/src/discovery/likelihood.py index faba0bc..9fd2e66 100644 --- a/src/discovery/likelihood.py +++ b/src/discovery/likelihood.py @@ -492,11 +492,12 @@ def cond(params): class ArrayLikelihood: - def __init__(self, psls, *, commongp=None, globalgp=None, transform=None): + def __init__(self, psls, *, commongp=None, globalgp=None, transform=None, decenter=False): self.psls = psls self.commongp = commongp self.globalgp = globalgp self.transform = transform + self.decenter = decenter # @functools.cached_property # def cloglast(self): @@ -531,16 +532,59 @@ def loglike(params): else: # merge common GPs and global GP cgp = self.commongp if isinstance(self.commongp, list) else [self.commongp] + gplist = cgp + [self.globalgp] commongp = matrix.VectorCompoundGP(cgp + [self.globalgp]) Ns, self.ys = zip(*[(psl.N, psl.y) for psl in self.psls]) + + # Both this line, and the decentering code below assumes that + # N and F are constants. self.vsm = matrix.VectorWoodburyKernel_varP(Ns, commongp.F, commongp.Phi) + + if self.decenter: + # create decentering transformation + + NmFs, ldNs = zip(*[N.solve_2d(F) for N, F in zip(self.vsm.Ns, self.vsm.Fs)]) + FtNmFs = [F.T @ NmF for F, NmF in zip(self.vsm.Fs, NmFs)] + NmFtys = [NmF.T @ y for NmF, y in zip(NmFs, self.ys)] + FtNmF, NmFty = matrix.jnparray(FtNmFs), matrix.jnparray(NmFtys) + def decenter_transform(params, c): + phis_invs_commongp = [gp.Phi.getN(params)**-1 for gp in (self.commongp if isinstance(self.commongp, list) else [self.commongp])] + # get diagonal piece of the globalGP (decenter using CURN) + if self.globalgp is not None: + phis_invs_globalgp = matrix.jnp.diag(self.globalgp.Phi.getN(params)**-1).reshape((len(self.psls), -1)) + phis_invs = matrix.jnp.concatenate([*phis_invs_commongp, phis_invs_globalgp], axis=1) + else: + phis_invs = matrix.jnp.concatenate([*phis_invs_commongp], axis=1) + i1, i2 = matrix.jnp.diag_indices(phis_invs.shape[1], ndim=2) + + # supposedly, the native batching here should be better + # than vmap. + + cf = matrix.matrix_factor(FtNmF.at[:,i1,i2].add(phis_invs), lower=True) + am = matrix.jsp.linalg.solve_triangular(cf[0], c, trans=1, lower=cf[1]) + mus = matrix.matrix_solve(cf, NmFty) + # jacobian of our transformation | d f^{-1} / d(xi) | = |L|. cf[0] is L^{-1}. + ldL = -matrix.jnp.logdet(cf[0][:,i1,i2]) + + c = am + mus + + return c, ldL + decenter_transform.params = [] + + if hasattr(commongp, 'prior'): self.vsm.prior = commongp.prior if hasattr(commongp, 'index'): self.vsm.index = commongp.index - loglike = self.vsm.make_kernelproduct_gpcomponent(self.ys, transform=self.transform) + if self.transform is not None and self.decenter is True: + raise ValueError("Can't decenter and add a transformation right now.") + + if self.decenter: + loglike = self.vsm.make_kernelproduct_gpcomponent(self.ys, transform=decenter_transform) + else: + loglike = self.vsm.make_kernelproduct_gpcomponent(self.ys, transform=self.transform) return loglike diff --git a/src/discovery/matrix.py b/src/discovery/matrix.py index b6b0e38..9f2e1fe 100644 --- a/src/discovery/matrix.py +++ b/src/discovery/matrix.py @@ -1312,6 +1312,33 @@ def kernelsolve(params): return kernelsolve + def make_kernelsolve_simple(self, y): + # for when there is only one + # GP, and it hasn't been marginalized over + P_var = self.P_var + Nvar = self.N + F = jnparray(self.F) + y = jnparray(y) + P_var_inv = P_var.make_inv() + NmF, ldN = self.N.solve_2d(self.F) + FtNm = NmF.T + FtNmy = FtNm @ y + FtNmF = F.T @ NmF + + Nvar_solve_2d = Nvar.make_solve_2d() + def kernelsolve(params): + Pinv, ldP = P_var_inv(params) + Sigma = Pinv + FtNmF + ch = matrix_factor(Sigma) + b_mean = matrix_solve(ch, FtNmy) + + return b_mean, Sigma + + kernelsolve.params = sorted(self.N.params + P_var.params) + return kernelsolve + + + def make_kernelproduct_vary(self, y): NmF, ldN = self.N.solve_2d(self.F) FtNmF = self.F.T @ NmF @@ -1869,6 +1896,8 @@ def make_kernelproduct_gpcomponent(self, ys, transform=None): FtNmF, NmFty = jnparray(FtNmFs), jnparray(NmFtys) ytNmy, ldN = float(sum(ytNmys)), float(sum(ldNs)) + n_psr = len(FtNmFs) + if isinstance(self.index, list): cvarsall = self.index else: @@ -1890,18 +1919,23 @@ def unfold(c): def kernelproduct(params): c = fold(params) - + ldL = 0.0 if transform is not None: - c, ldL = transform(params, c) + c, tmp_ldL = transform(params, c) + ldL += tmp_ldL params = {**params, **unfold(c)} else: - ldL = 0.0 + ldL += 0.0 + logpr = P_var_prior(params) ret = (-0.5 * ytNmy + jnp.sum(c * NmFty) - 0.5 * jnp.einsum('ij,ijk,ik', c, FtNmF, c) - -0.5 * ldN - logpr + ldL) - return (ret, c) if transform is not None else ret + -0.5 * ldN + logpr + ldL) + if transform: + return (ret, c) + else: + return ret kernelproduct.params = sorted(set(P_var_prior.params + sum([list(cvars) for cvars in cvarsall], []) + diff --git a/src/discovery/signals.py b/src/discovery/signals.py index dfd4933..54b5076 100644 --- a/src/discovery/signals.py +++ b/src/discovery/signals.py @@ -546,7 +546,14 @@ def priorfunc(params): # the jnp.dot handles the "pixel basis" case where the elements of orfmat are n-vectors # and phidiag is an (m x n)-matrix; here n is the number of pixels and m of Fourier components - return jnp.block([[jnp.make2d(jnp.dot(phi, val)) for val in row] for row in orfmat]) + # + if jnp is jax.numpy: + # this seems to speed things up, pre-compilation at least. + tmp = jnp.kron(orfmat, jnp.diag(phi)) + return tmp + else: + return jnp.block([[jnp.make2d(jnp.dot(phi, val)) for val in row] for row in orfmat]) + priorfunc.params = argmap priorfunc.type = jax.Array @@ -562,8 +569,10 @@ def invprior(params): # log |S_ij Gamma_ab| = log (prod_i S_i^npsr) + log prod_i |Gamma_ab| # = npsr * sum_i log S_i + nfreqs |Gamma_ab| - return (jnp.block([[jnp.make2d(val * invphi) for val in row] for row in invorf]), - phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi) + # return (jnp.block([[jnp.make2d(val * invphi) for val in row] for row in invorf]), + # phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi) + return (jnp.kron(invorf, jnp.make2d(invphi)), + phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi) # was -orfmat.shape[0] * jnp.sum(jnp.log(invphidiag))) invprior.params = argmap invprior.type = jax.Array