Skip to content

Commit a30bac9

Browse files
committed
OS tweaks
1 parent 8f0bb5e commit a30bac9

File tree

2 files changed

+190
-18
lines changed

2 files changed

+190
-18
lines changed

src/discovery/optimal.py

Lines changed: 176 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,167 @@ def __init__(self, gbl):
4343
def params(self):
4444
return self.os_rhosigma.params
4545

46-
def sample_rhosigma_lowrank(self, params, orf=hd_orfa):
47-
# y_i^T K_i^{-1} T_i Phi_{ij} T_j^T K_j^{-1} y_j
48-
# with y_j =
46+
# TODO: share the duplicated init code between Q, opQ, and sample (opQ itself is implemented below)
47+
48+
@functools.cached_property
49+
def Q(self):
50+
Nmats, Fmats, Tmats = zip(*[(psl.N.N.N, psl.N.F, psl.gw.F) for psl in self.psls])
51+
52+
LNms = [1.0 / matrix.jnp.sqrt(Nmat) for Nmat in Nmats]
53+
Fts = [LNm[:,None] * Fmat for LNm, Fmat in zip(LNms, Fmats)]
54+
Tts = [LNm[:,None] * Tmat for LNm, Tmat in zip(LNms, Tmats)] # this is GW-only
55+
56+
FFts = [matrix.jnparray(Ft.T @ Ft) for Ft in Fts]
57+
TTts = [matrix.jnparray(Tt.T @ Tt) for Tt in Tts]
58+
FTts = [matrix.jnparray(Ft.T @ Tt) for Ft, Tt in zip(Fts, Tts)]
59+
60+
Phivar = self.psls[0].gw.Phi.getN
61+
Pvars = [psl.N.P_var.getN for psl in self.psls]
62+
63+
ngw = Tts[0].shape[1]
64+
cnt = len(self.psls) * ngw
65+
inds = [slice(i * ngw, (i + 1) * ngw) for i in range(len(self.psls))]
66+
67+
def get_Q(params, orf=hd_orfa):
68+
sPhi = matrix.jnp.sqrt(Phivar(params))
69+
70+
cs = [matrix.jsp.linalg.cho_factor(matrix.jnp.diag(1.0 / Pvar(params)) + FFt) for Pvar, FFt in zip(Pvars, FFts)]
71+
Ss = [TTt - FTt.T @ matrix.jsp.linalg.cho_solve(c, FTt) for c, TTt, FTt in zip(cs, TTts, FTts)]
72+
73+
Ss = [0.5 * (S + S.T) for S in Ss] # ensure symmetry
74+
As = [matrix.jnp.linalg.cholesky(S + (1e-10 * matrix.jnp.trace(S) / S.shape[0]) * matrix.jnp.eye(S.shape[0]))
75+
for S in Ss]
76+
77+
Ds = [sPhi[:,matrix.jnp.newaxis] * S * sPhi[matrix.jnp.newaxis,:] for S in Ss]
78+
bs = [matrix.jnp.trace(Ds[i] @ Ds[j]) for (i,j) in self.pairs]
79+
80+
orfs = orf(matrix.jnparray(self.angles))
81+
# note the factor of 2: each pair is added to both the (i,j) and (j,i) blocks of Q, so it is needed to get OS = x^T Q x
82+
denom = 2.0 * matrix.jnp.sqrt(matrix.jnp.sum(orfs**2 * matrix.jnparray(bs)))
83+
84+
Q = matrix.jnpzeros((cnt, cnt))
85+
86+
A_scaled = [sPhi[:, None] * A for A in As]
87+
88+
for w, (i, j) in zip(orfs, self.pairs):
89+
Bij = w * (A_scaled[i].T @ A_scaled[j])
90+
91+
Q = Q.at[inds[i], inds[j]].add(Bij)
92+
Q = Q.at[inds[j], inds[i]].add(Bij.T)
93+
94+
return Q / denom
95+
get_Q.params = self.os_rhosigma.params
96+
97+
return get_Q
98+
99+
@functools.cached_property
100+
def opQ(self):
101+
Nmats, Fmats, Tmats = zip(*[(psl.N.N.N, psl.N.F, psl.gw.F) for psl in self.psls])
102+
103+
LNms = [1.0 / matrix.jnp.sqrt(Nmat) for Nmat in Nmats]
104+
Fts = [LNm[:,None] * Fmat for LNm, Fmat in zip(LNms, Fmats)]
105+
Tts = [LNm[:,None] * Tmat for LNm, Tmat in zip(LNms, Tmats)] # this is GW-only
106+
107+
FFts = [matrix.jnparray(Ft.T @ Ft) for Ft in Fts]
108+
TTts = [matrix.jnparray(Tt.T @ Tt) for Tt in Tts]
109+
FTts = [matrix.jnparray(Ft.T @ Tt) for Ft, Tt in zip(Fts, Tts)]
110+
111+
Phivar = self.psls[0].gw.Phi.getN
112+
Pvars = [psl.N.P_var.getN for psl in self.psls]
113+
114+
ngw = Tts[0].shape[1]
115+
cnt = len(self.psls) * ngw
116+
inds = [slice(i * ngw, (i + 1) * ngw) for i in range(len(self.psls))]
117+
118+
def get_opQ(params, orf=hd_orfa):
119+
sPhi = matrix.jnp.sqrt(Phivar(params))
120+
121+
cs = [matrix.jsp.linalg.cho_factor(matrix.jnp.diag(1.0 / Pvar(params)) + FFt) for Pvar, FFt in zip(Pvars, FFts)]
122+
Ss = [TTt - FTt.T @ matrix.jsp.linalg.cho_solve(c, FTt) for c, TTt, FTt in zip(cs, TTts, FTts)]
123+
124+
Ss = [0.5 * (S + S.T) for S in Ss] # ensure symmetry
125+
As = [matrix.jnp.linalg.cholesky(S + (1e-10 * matrix.jnp.trace(S) / S.shape[0]) * matrix.jnp.eye(S.shape[0]))
126+
for S in Ss]
127+
128+
Ds = [sPhi[:,matrix.jnp.newaxis] * S * sPhi[matrix.jnp.newaxis,:] for S in Ss]
129+
bs = [matrix.jnp.trace(Ds[i] @ Ds[j]) for (i,j) in self.pairs]
130+
131+
orfs = orf(matrix.jnparray(self.angles))
132+
# note the factor of 2: each pair contributes to both y[i] and y[j] below, so it is needed to get OS = x^T Q x
133+
denom = 2.0 * matrix.jnp.sqrt(matrix.jnp.sum(orfs**2 * matrix.jnparray(bs)))
134+
135+
Bs = [sPhi[:, None] * A for A in As] # B_i = diag(sPhi) @ A_i
136+
137+
# currently not traceable; too bad
138+
def op(x):
139+
zs = [B @ x[ii] for B, ii in zip(Bs, inds)]
140+
141+
y = matrix.jnp.zeros_like(x)
142+
for w, (i, j) in zip(orfs, self.pairs):
143+
y = y.at[inds[i]].add((w / denom) * (Bs[i].T @ zs[j]))
144+
y = y.at[inds[j]].add((w / denom) * (Bs[j].T @ zs[i]))
49145

146+
return y
147+
148+
return op
149+
get_opQ.params = self.os_rhosigma.params
150+
151+
return get_opQ
152+
153+
@functools.cached_property
154+
def sample(self):
155+
Nmats, Fmats, Tmats = zip(*[(psl.N.N.N, psl.N.F, psl.gw.F) for psl in self.psls])
156+
157+
LNms = [1.0 / matrix.jnp.sqrt(Nmat) for Nmat in Nmats]
158+
Fts = [LNm[:,None] * Fmat for LNm, Fmat in zip(LNms, Fmats)]
159+
Tts = [LNm[:,None] * Tmat for LNm, Tmat in zip(LNms, Tmats)] # this is GW-only
160+
161+
FFts = [matrix.jnparray(Ft.T @ Ft) for Ft in Fts]
162+
TTts = [matrix.jnparray(Tt.T @ Tt) for Tt in Tts]
163+
FTts = [matrix.jnparray(Ft.T @ Tt) for Ft, Tt in zip(Fts, Tts)]
164+
165+
Phivar = self.psls[0].gw.Phi.getN
166+
Pvars = [psl.N.P_var.getN for psl in self.psls]
167+
168+
ngw = Tts[0].shape[1]
169+
cnt = len(self.psls) * ngw
170+
inds = [slice(i * ngw, (i + 1) * ngw) for i in range(len(self.psls))]
171+
172+
def get_sample(key, params, orf=hd_orfa):
173+
sPhi = matrix.jnp.sqrt(Phivar(params))
174+
175+
# NOTE: this closes on the precomputed products Ft.T @ Ft, Tt.T @ Tt, and Ft.T @ Tt (FFts, TTts, FTts) rather than on Fts and Tts; Tt.T @ Ft is obtained as FTt.T
176+
cs = [matrix.jsp.linalg.cho_factor(matrix.jnp.diag(1.0 / Pvar(params)) + FFt) for Pvar, FFt in zip(Pvars, FFts)]
177+
Ss = [TTt - FTt.T @ matrix.jsp.linalg.cho_solve(c, FTt) for c, TTt, FTt in zip(cs, TTts, FTts)]
178+
179+
Ss = [0.5 * (S + S.T) for S in Ss] # ensure symmetry
180+
As = [matrix.jnp.linalg.cholesky(S + (1e-10 * matrix.jnp.trace(S) / S.shape[0]) * matrix.jnp.eye(S.shape[0]))
181+
for S in Ss]
182+
183+
Ds = [sPhi[:,matrix.jnp.newaxis] * S * sPhi[matrix.jnp.newaxis,:] for S in Ss]
184+
bs = [matrix.jnp.trace(Ds[i] @ Ds[j]) for (i,j) in self.pairs]
185+
186+
xs = matrix.jnpnormal(key, cnt)
187+
uks = [sPhi * (A @ xs[ind]) for A, ind in zip(As, inds)]
188+
189+
ts = matrix.jnparray([matrix.jnp.dot(uks[i], uks[j].T) for (i,j) in self.pairs])
190+
191+
gwnorm = 10**(2.0 * params[self.gwpar])
192+
rhos = gwnorm * (matrix.jnparray(ts) / matrix.jnparray(bs))
193+
sigmas = gwnorm / matrix.jnp.sqrt(matrix.jnparray(bs))
194+
195+
orfs = orf(matrix.jnparray(self.angles))
196+
197+
os = matrix.jnp.sum(rhos * orfs / sigmas**2) / matrix.jnp.sum(orfs**2 / sigmas**2)
198+
os_sigma = 1.0 / matrix.jnp.sqrt(matrix.jnp.sum(orfs**2 / sigmas**2))
199+
snr = os / os_sigma
200+
201+
return snr
202+
get_sample.params = self.os_rhosigma.params
203+
204+
return get_sample
205+
206+
def sample_rhosigma_lowrank(self, params, orf=hd_orfa):
50207
Phi = self.psls[0].gw.Phi.getN(params)
51208
sPhi = matrix.jnp.sqrt(Phi)
52209

@@ -57,13 +214,25 @@ def sample_rhosigma_lowrank(self, params, orf=hd_orfa):
57214
Tts = [LNm[:,None] * Tmat for LNm, Tmat in zip(LNms, Tmats)] # this is GW-only
58215

59216
cs = [matrix.jsp.linalg.cho_factor(matrix.jnp.diag(1/Pmat) + Ft.T @ Ft) for Pmat, Ft in zip(Pmats, Fts)]
60-
Ss = [Tt.T @ (Tt - Ft @ matrix.jsp.linalg.cho_solve(c, Ft.T @ Tt)) for c, Ft, Tt in zip(cs, Fts, Tts)]
217+
Xs = [Tt - Ft @ matrix.jsp.linalg.cho_solve(c, Ft.T @ Tt) for c, Ft, Tt in zip(cs, Fts, Tts)]
218+
219+
Ss = [Tt.T @ X for Tt, X in zip(Tts, Xs)]
61220

221+
# alternative formulation (numerically unstable?):
222+
# R = chol(Pmat^-1 + Ft.T @ Ft)
223+
# Y = R^-1 @ Ft.T @ Tt
224+
# S = Tt.T @ Tt - Y.T @ Y
225+
#
62226
# Rs = [matrix.jnp.linalg.cholesky(matrix.jnp.diag(1/Pmat) + Ft.T @ Ft, upper=True) for Pmat, Ft in zip(Pmats, Fts)]
63227
# Ys = [matrix.jsp.linalg.solve_triangular(R, Ft.T @ Tt, lower=False) for R, Ft, Tt in zip(Rs, Fts, Tts)]
64228
# Ss = [Tt.T @ Tt - Y.T @ Y for Tt, Y in zip(Tts, Ys)]
65229

66-
As = [matrix.jnp.linalg.cholesky(S, upper=False) for S in Ss]
230+
# Cholesky with ridge regularization: add (1e-10 * trace(S) / n) * I to S; this simple estimate based on the trace seems fine
231+
# a more precise possibility is eps = matrix.jnp.maximum(0.0, -matrix.jnp.linalg.eigvalsh(S).min())
232+
# + 1e-10 * matrix.jnp.trace(S) / S.shape[0]
233+
Ss = [0.5 * (S + S.T) for S in Ss] # ensure symmetry
234+
As = [matrix.jnp.linalg.cholesky(S + (1e-10 * matrix.jnp.trace(S) / S.shape[0]) * matrix.jnp.eye(S.shape[0]))
235+
for S in Ss]
67236

68237
Ds = [sPhi[:,matrix.jnp.newaxis] * S * sPhi[matrix.jnp.newaxis,:] for S in Ss]
69238
bs = [matrix.jnp.trace(Ds[i] @ Ds[j]) for (i,j) in self.pairs]
@@ -120,17 +289,6 @@ def sample_rhosigma(self, params, orf=hd_orfa):
120289
iPs.append(slice(cnt, cnt + Fmat.shape[1]))
121290
cnt += Fmat.shape[1]
122291

123-
# @jax.jit
124-
# @jax.vmap
125-
# def makets(keys):
126-
# uks = [PsTt @ K1(matrix.jnp.sqrt(Nmat) * matrix.jnpnormal(key, Nmat.shape[0]))[0] +
127-
# PsTtKmFsP @ matrix.jnpnormal(key, Fmat.shape[1])
128-
# for key, K1, Nmat, Fmat, PsTt, PsTtKmFsP in zip(keys, K1s, Nmats, Fmats, PsTts, PsTtKmFsPs)]
129-
130-
# return matrix.jnparray([matrix.jnp.dot(uks[i], uks[j].T) for (i,j) in self.pairs])
131-
132-
# ts = makets(jax.random.split(key, (n, len(self.psls))))
133-
134292
def xs2snrs(xs):
135293
uks = [PsTt @ K1(matrix.jnp.sqrt(Nmat) * xs[iN])[0] + PsTtKmFsP @ xs[iP]
136294
for PsTt, K1, Nmat, iN, PsTtKmFsP, iP in zip(PsTts, K1s, Nmats, iNs, PsTtKmFsPs, iPs)]
@@ -312,6 +470,8 @@ def get_shift(params, phases, orf=hd_orfa):
312470

313471
return get_shift
314472

473+
# TODO: work in progress...
474+
315475
@functools.cached_property
316476
def gx2mat(self):
317477
kernelsolves = [psl.N.make_kernelsolve(psl.N.F, gw.F) for (psl, gw) in zip(self.psls, self.gws)]

src/discovery/signals.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def getphi(params):
156156
return matrix.VariableGP(matrix.NoiseMatrix1D_var(getphi), Umat)
157157

158158
# nanograv backends
159-
def makegp_ecorr(psr, noisedict={}, enterprise=False, scale=1.0, selection=selection_backend_flags, name='ecorrGP'):
159+
def makegp_ecorr(psr, noisedict={}, enterprise=False, scale=1.0, selection=selection_backend_flags, variable=False, name='ecorrGP'):
160160
log10_ecorrs, Umats = [], []
161161

162162
backend_flags = selection(psr)
@@ -189,7 +189,19 @@ def makegp_ecorr(psr, noisedict={}, enterprise=False, scale=1.0, selection=selec
189189
if all(par in noisedict for par in params):
190190
phi = sum(10.0**(2 * (logscale + noisedict[log10_ecorr])) * pmask for (log10_ecorr, pmask) in zip(log10_ecorrs, pmasks))
191191

192-
return matrix.ConstantGP(matrix.NoiseMatrix1D_novar(phi), Umatall)
192+
if variable:
193+
def getphi(params):
194+
return phi
195+
getphi.params = []
196+
197+
gp = matrix.VariableGP(matrix.NoiseMatrix1D_var(getphi), Umatall)
198+
gp.index = {f'{psr.name}_{name}_coefficients({Umatall.shape[1]})': slice(0,Umatall.shape[1])} # better for cosine
199+
gp.name, gp.pos = psr.name, psr.pos
200+
gp.gpname, gp.gpcommon = name, []
201+
202+
return gp
203+
else:
204+
return matrix.ConstantGP(matrix.NoiseMatrix1D_novar(phi), Umatall)
193205
else:
194206
pmasks = [matrix.jnparray(pmask) for pmask in pmasks]
195207
def getphi(params):

0 commit comments

Comments
 (0)