Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions src/discovery/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,12 @@ def cond(params):


class ArrayLikelihood:
def __init__(self, psls, *, commongp=None, globalgp=None, transform=None):
def __init__(self, psls, *, commongp=None, globalgp=None, transform=None, decenter=False):
self.psls = psls
self.commongp = commongp
self.globalgp = globalgp
self.transform = transform
self.decenter = decenter

# @functools.cached_property
# def cloglast(self):
Expand Down Expand Up @@ -531,16 +532,59 @@ def loglike(params):
else:
# merge common GPs and global GP
cgp = self.commongp if isinstance(self.commongp, list) else [self.commongp]
gplist = cgp + [self.globalgp]
commongp = matrix.VectorCompoundGP(cgp + [self.globalgp])

Ns, self.ys = zip(*[(psl.N, psl.y) for psl in self.psls])

# Both this line, and the decentering code below assumes that
# N and F are constants.
self.vsm = matrix.VectorWoodburyKernel_varP(Ns, commongp.F, commongp.Phi)

if self.decenter:
# create decentering transformation

NmFs, ldNs = zip(*[N.solve_2d(F) for N, F in zip(self.vsm.Ns, self.vsm.Fs)])
FtNmFs = [F.T @ NmF for F, NmF in zip(self.vsm.Fs, NmFs)]
NmFtys = [NmF.T @ y for NmF, y in zip(NmFs, self.ys)]
FtNmF, NmFty = matrix.jnparray(FtNmFs), matrix.jnparray(NmFtys)
def decenter_transform(params, c):
phis_invs_commongp = [gp.Phi.getN(params)**-1 for gp in (self.commongp if isinstance(self.commongp, list) else [self.commongp])]
# get diagonal piece of the globalGP (decenter using CURN)
if self.globalgp is not None:
phis_invs_globalgp = matrix.jnp.diag(self.globalgp.Phi.getN(params)**-1).reshape((len(self.psls), -1))
phis_invs = matrix.jnp.concatenate([*phis_invs_commongp, phis_invs_globalgp], axis=1)
else:
phis_invs = matrix.jnp.concatenate([*phis_invs_commongp], axis=1)
i1, i2 = matrix.jnp.diag_indices(phis_invs.shape[1], ndim=2)

# supposedly, the native batching here should be better
# than vmap.

cf = matrix.matrix_factor(FtNmF.at[:,i1,i2].add(phis_invs), lower=True)
am = matrix.jsp.linalg.solve_triangular(cf[0], c, trans=1, lower=cf[1])
mus = matrix.matrix_solve(cf, NmFty)
# jacobian of our transformation | d f^{-1} / d(xi) | = |L|. cf[0] is L^{-1}.
ldL = -matrix.jnp.logdet(cf[0][:,i1,i2])

c = am + mus

return c, ldL
decenter_transform.params = []


if hasattr(commongp, 'prior'):
self.vsm.prior = commongp.prior
if hasattr(commongp, 'index'):
self.vsm.index = commongp.index

loglike = self.vsm.make_kernelproduct_gpcomponent(self.ys, transform=self.transform)
if self.transform is not None and self.decenter is True:
raise ValueError("Can't decenter and add a transformation right now.")

if self.decenter:
loglike = self.vsm.make_kernelproduct_gpcomponent(self.ys, transform=decenter_transform)
else:
loglike = self.vsm.make_kernelproduct_gpcomponent(self.ys, transform=self.transform)

return loglike

Expand Down
44 changes: 39 additions & 5 deletions src/discovery/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,33 @@ def kernelsolve(params):

return kernelsolve

def make_kernelsolve_simple(self, y):
# for when there is only one
# GP, and it hasn't been marginalized over
P_var = self.P_var
Nvar = self.N
F = jnparray(self.F)
y = jnparray(y)
P_var_inv = P_var.make_inv()
NmF, ldN = self.N.solve_2d(self.F)
FtNm = NmF.T
FtNmy = FtNm @ y
FtNmF = F.T @ NmF

Nvar_solve_2d = Nvar.make_solve_2d()
def kernelsolve(params):
Pinv, ldP = P_var_inv(params)
Sigma = Pinv + FtNmF
ch = matrix_factor(Sigma)
b_mean = matrix_solve(ch, FtNmy)

return b_mean, Sigma

kernelsolve.params = sorted(self.N.params + P_var.params)
return kernelsolve



def make_kernelproduct_vary(self, y):
NmF, ldN = self.N.solve_2d(self.F)
FtNmF = self.F.T @ NmF
Expand Down Expand Up @@ -1869,6 +1896,8 @@ def make_kernelproduct_gpcomponent(self, ys, transform=None):
FtNmF, NmFty = jnparray(FtNmFs), jnparray(NmFtys)
ytNmy, ldN = float(sum(ytNmys)), float(sum(ldNs))

n_psr = len(FtNmFs)

if isinstance(self.index, list):
cvarsall = self.index
else:
Expand All @@ -1890,18 +1919,23 @@ def unfold(c):

def kernelproduct(params):
c = fold(params)

ldL = 0.0
if transform is not None:
c, ldL = transform(params, c)
c, tmp_ldL = transform(params, c)
ldL += tmp_ldL
params = {**params, **unfold(c)}
else:
ldL = 0.0
ldL += 0.0


logpr = P_var_prior(params)

ret = (-0.5 * ytNmy + jnp.sum(c * NmFty) - 0.5 * jnp.einsum('ij,ijk,ik', c, FtNmF, c)
-0.5 * ldN - logpr + ldL)
return (ret, c) if transform is not None else ret
-0.5 * ldN + logpr + ldL)
if transform:
return (ret, c)
else:
return ret

kernelproduct.params = sorted(set(P_var_prior.params +
sum([list(cvars) for cvars in cvarsall], []) +
Expand Down
15 changes: 12 additions & 3 deletions src/discovery/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,14 @@ def priorfunc(params):

# the jnp.dot handles the "pixel basis" case where the elements of orfmat are n-vectors
# and phidiag is an (m x n)-matrix; here n is the number of pixels and m of Fourier components
return jnp.block([[jnp.make2d(jnp.dot(phi, val)) for val in row] for row in orfmat])
#
if jnp is jax.numpy:
# this seems to speed things up, pre-compilation at least.
tmp = jnp.kron(orfmat, jnp.diag(phi))
return tmp
else:
return jnp.block([[jnp.make2d(jnp.dot(phi, val)) for val in row] for row in orfmat])

priorfunc.params = argmap
priorfunc.type = jax.Array

Expand All @@ -562,8 +569,10 @@ def invprior(params):
# log |S_ij Gamma_ab| = log (prod_i S_i^npsr) + log prod_i |Gamma_ab|
# = npsr * sum_i log S_i + nfreqs |Gamma_ab|

return (jnp.block([[jnp.make2d(val * invphi) for val in row] for row in invorf]),
phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi)
# return (jnp.block([[jnp.make2d(val * invphi) for val in row] for row in invorf]),
# phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi)
return (jnp.kron(invorf, jnp.make2d(invphi)),
phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi)
# was -orfmat.shape[0] * jnp.sum(jnp.log(invphidiag)))
invprior.params = argmap
invprior.type = jax.Array
Expand Down
Loading