@@ -43,10 +43,167 @@ def __init__(self, gbl):
4343 def params (self ):
4444 return self .os_rhosigma .params
4545
46- def sample_rhosigma_lowrank (self , params , orf = hd_orfa ):
47- # y_i^T K_i^{-1} T_i Phi_{ij} T_j^T K_j^{-1} y_j
48- # with y_j =
46+ # TO DO: make opQ, and share init code between opQ, Q, and sample
47+
48+ @functools .cached_property
49+ def Q (self ):
50+ Nmats , Fmats , Tmats = zip (* [(psl .N .N .N , psl .N .F , psl .gw .F ) for psl in self .psls ])
51+
52+ LNms = [1.0 / matrix .jnp .sqrt (Nmat ) for Nmat in Nmats ]
53+ Fts = [LNm [:,None ] * Fmat for LNm , Fmat in zip (LNms , Fmats )]
54+ Tts = [LNm [:,None ] * Tmat for LNm , Tmat in zip (LNms , Tmats )] # this is GW-only
55+
56+ FFts = [matrix .jnparray (Ft .T @ Ft ) for Ft in Fts ]
57+ TTts = [matrix .jnparray (Tt .T @ Tt ) for Tt in Tts ]
58+ FTts = [matrix .jnparray (Ft .T @ Tt ) for Ft , Tt in zip (Fts , Tts )]
59+
60+ Phivar = self .psls [0 ].gw .Phi .getN
61+ Pvars = [psl .N .P_var .getN for psl in self .psls ]
62+
63+ ngw = Tts [0 ].shape [1 ]
64+ cnt = len (self .psls ) * ngw
65+ inds = [slice (i * ngw , (i + 1 ) * ngw ) for i in range (len (self .psls ))]
66+
67+ def get_Q (params , orf = hd_orfa ):
68+ sPhi = matrix .jnp .sqrt (Phivar (params ))
69+
70+ cs = [matrix .jsp .linalg .cho_factor (matrix .jnp .diag (1.0 / Pvar (params )) + FFt ) for Pvar , FFt in zip (Pvars , FFts )]
71+ Ss = [TTt - FTt .T @ matrix .jsp .linalg .cho_solve (c , FTt ) for c , TTt , FTt in zip (cs , TTts , FTts )]
72+
73+ Ss = [0.5 * (S + S .T ) for S in Ss ] # ensure symmetry
74+ As = [matrix .jnp .linalg .cholesky (S + (1e-10 * matrix .jnp .trace (S ) / S .shape [0 ]) * matrix .jnp .eye (S .shape [0 ]))
75+ for S in Ss ]
76+
77+ Ds = [sPhi [:,matrix .jnp .newaxis ] * S * sPhi [matrix .jnp .newaxis ,:] for S in Ss ]
78+ bs = [matrix .jnp .trace (Ds [i ] @ Ds [j ]) for (i ,j ) in self .pairs ]
79+
80+ orfs = orf (matrix .jnparray (self .angles ))
81+ # note the 2 to get OS = x^T Q x
82+ denom = 2.0 * matrix .jnp .sqrt (matrix .jnp .sum (orfs ** 2 * matrix .jnparray (bs )))
83+
84+ Q = matrix .jnpzeros ((cnt , cnt ))
85+
86+ A_scaled = [sPhi [:, None ] * A for A in As ]
87+
88+ for w , (i , j ) in zip (orfs , self .pairs ):
89+ Bij = w * (A_scaled [i ].T @ A_scaled [j ])
90+
91+ Q = Q .at [inds [i ], inds [j ]].add (Bij )
92+ Q = Q .at [inds [j ], inds [i ]].add (Bij .T )
93+
94+ return Q / denom
95+ get_Q .params = self .os_rhosigma .params
96+
97+ return get_Q
98+
99+ @functools .cached_property
100+ def opQ (self ):
101+ Nmats , Fmats , Tmats = zip (* [(psl .N .N .N , psl .N .F , psl .gw .F ) for psl in self .psls ])
102+
103+ LNms = [1.0 / matrix .jnp .sqrt (Nmat ) for Nmat in Nmats ]
104+ Fts = [LNm [:,None ] * Fmat for LNm , Fmat in zip (LNms , Fmats )]
105+ Tts = [LNm [:,None ] * Tmat for LNm , Tmat in zip (LNms , Tmats )] # this is GW-only
106+
107+ FFts = [matrix .jnparray (Ft .T @ Ft ) for Ft in Fts ]
108+ TTts = [matrix .jnparray (Tt .T @ Tt ) for Tt in Tts ]
109+ FTts = [matrix .jnparray (Ft .T @ Tt ) for Ft , Tt in zip (Fts , Tts )]
110+
111+ Phivar = self .psls [0 ].gw .Phi .getN
112+ Pvars = [psl .N .P_var .getN for psl in self .psls ]
113+
114+ ngw = Tts [0 ].shape [1 ]
115+ cnt = len (self .psls ) * ngw
116+ inds = [slice (i * ngw , (i + 1 ) * ngw ) for i in range (len (self .psls ))]
117+
118+ def get_opQ (params , orf = hd_orfa ):
119+ sPhi = matrix .jnp .sqrt (Phivar (params ))
120+
121+ cs = [matrix .jsp .linalg .cho_factor (matrix .jnp .diag (1.0 / Pvar (params )) + FFt ) for Pvar , FFt in zip (Pvars , FFts )]
122+ Ss = [TTt - FTt .T @ matrix .jsp .linalg .cho_solve (c , FTt ) for c , TTt , FTt in zip (cs , TTts , FTts )]
123+
124+ Ss = [0.5 * (S + S .T ) for S in Ss ] # ensure symmetry
125+ As = [matrix .jnp .linalg .cholesky (S + (1e-10 * matrix .jnp .trace (S ) / S .shape [0 ]) * matrix .jnp .eye (S .shape [0 ]))
126+ for S in Ss ]
127+
128+ Ds = [sPhi [:,matrix .jnp .newaxis ] * S * sPhi [matrix .jnp .newaxis ,:] for S in Ss ]
129+ bs = [matrix .jnp .trace (Ds [i ] @ Ds [j ]) for (i ,j ) in self .pairs ]
130+
131+ orfs = orf (matrix .jnparray (self .angles ))
132+ # note the 2 to get OS = x^T Q x
133+ denom = 2.0 * matrix .jnp .sqrt (matrix .jnp .sum (orfs ** 2 * matrix .jnparray (bs )))
134+
135+ Bs = [sPhi [:, None ] * A for A in As ] # B_i = diag(sPhi) @ A_i
136+
137+ # currently not traceable; too bad
138+ def op (x ):
139+ zs = [B @ x [ii ] for B , ii in zip (Bs , inds )]
140+
141+ y = matrix .jnp .zeros_like (x )
142+ for w , (i , j ) in zip (orfs , self .pairs ):
143+ y = y .at [inds [i ]].add ((w / denom ) * (Bs [i ].T @ zs [j ]))
144+ y = y .at [inds [j ]].add ((w / denom ) * (Bs [j ].T @ zs [i ]))
49145
146+ return y
147+
148+ return op
149+ get_opQ .params = self .os_rhosigma .params
150+
151+ return get_opQ
152+
153+ @functools .cached_property
154+ def sample (self ):
155+ Nmats , Fmats , Tmats = zip (* [(psl .N .N .N , psl .N .F , psl .gw .F ) for psl in self .psls ])
156+
157+ LNms = [1.0 / matrix .jnp .sqrt (Nmat ) for Nmat in Nmats ]
158+ Fts = [LNm [:,None ] * Fmat for LNm , Fmat in zip (LNms , Fmats )]
159+ Tts = [LNm [:,None ] * Tmat for LNm , Tmat in zip (LNms , Tmats )] # this is GW-only
160+
161+ FFts = [matrix .jnparray (Ft .T @ Ft ) for Ft in Fts ]
162+ TTts = [matrix .jnparray (Tt .T @ Tt ) for Tt in Tts ]
163+ FTts = [matrix .jnparray (Ft .T @ Tt ) for Ft , Tt in zip (Fts , Tts )]
164+
165+ Phivar = self .psls [0 ].gw .Phi .getN
166+ Pvars = [psl .N .P_var .getN for psl in self .psls ]
167+
168+ ngw = Tts [0 ].shape [1 ]
169+ cnt = len (self .psls ) * ngw
170+ inds = [slice (i * ngw , (i + 1 ) * ngw ) for i in range (len (self .psls ))]
171+
172+ def get_sample (key , params , orf = hd_orfa ):
173+ sPhi = matrix .jnp .sqrt (Phivar (params ))
174+
175+ # TO DO: should probably close on Ft.T @ Ft, Tt.T @ Tt, and Tt.T @ Ft (and Ft.T @ Tt) rather than on Fts and Tts
176+ cs = [matrix .jsp .linalg .cho_factor (matrix .jnp .diag (1.0 / Pvar (params )) + FFt ) for Pvar , FFt in zip (Pvars , FFts )]
177+ Ss = [TTt - FTt .T @ matrix .jsp .linalg .cho_solve (c , FTt ) for c , TTt , FTt in zip (cs , TTts , FTts )]
178+
179+ Ss = [0.5 * (S + S .T ) for S in Ss ] # ensure symmetry
180+ As = [matrix .jnp .linalg .cholesky (S + (1e-10 * matrix .jnp .trace (S ) / S .shape [0 ]) * matrix .jnp .eye (S .shape [0 ]))
181+ for S in Ss ]
182+
183+ Ds = [sPhi [:,matrix .jnp .newaxis ] * S * sPhi [matrix .jnp .newaxis ,:] for S in Ss ]
184+ bs = [matrix .jnp .trace (Ds [i ] @ Ds [j ]) for (i ,j ) in self .pairs ]
185+
186+ xs = matrix .jnpnormal (key , cnt )
187+ uks = [sPhi * (A @ xs [ind ]) for A , ind in zip (As , inds )]
188+
189+ ts = matrix .jnparray ([matrix .jnp .dot (uks [i ], uks [j ].T ) for (i ,j ) in self .pairs ])
190+
191+ gwnorm = 10 ** (2.0 * params [self .gwpar ])
192+ rhos = gwnorm * (matrix .jnparray (ts ) / matrix .jnparray (bs ))
193+ sigmas = gwnorm / matrix .jnp .sqrt (matrix .jnparray (bs ))
194+
195+ orfs = orf (matrix .jnparray (self .angles ))
196+
197+ os = matrix .jnp .sum (rhos * orfs / sigmas ** 2 ) / matrix .jnp .sum (orfs ** 2 / sigmas ** 2 )
198+ os_sigma = 1.0 / matrix .jnp .sqrt (matrix .jnp .sum (orfs ** 2 / sigmas ** 2 ))
199+ snr = os / os_sigma
200+
201+ return snr
202+ get_sample .params = self .os_rhosigma .params
203+
204+ return get_sample
205+
206+ def sample_rhosigma_lowrank (self , params , orf = hd_orfa ):
50207 Phi = self .psls [0 ].gw .Phi .getN (params )
51208 sPhi = matrix .jnp .sqrt (Phi )
52209
@@ -57,13 +214,25 @@ def sample_rhosigma_lowrank(self, params, orf=hd_orfa):
57214 Tts = [LNm [:,None ] * Tmat for LNm , Tmat in zip (LNms , Tmats )] # this is GW-only
58215
59216 cs = [matrix .jsp .linalg .cho_factor (matrix .jnp .diag (1 / Pmat ) + Ft .T @ Ft ) for Pmat , Ft in zip (Pmats , Fts )]
60- Ss = [Tt .T @ (Tt - Ft @ matrix .jsp .linalg .cho_solve (c , Ft .T @ Tt )) for c , Ft , Tt in zip (cs , Fts , Tts )]
217+ Xs = [Tt - Ft @ matrix .jsp .linalg .cho_solve (c , Ft .T @ Tt ) for c , Ft , Tt in zip (cs , Fts , Tts )]
218+
219+ Ss = [Tt .T @ X for Tt , X in zip (Tts , Xs )]
61220
221+ # alternative formulation (numerically unstable?):
222+ # R = chol(Pmat^-1 + Ft.T @ Ft)
223+ # Y = R^-1 @ Ft.T @ Tt
224+ # S = Tt.T @ Tt - Y.T @ Y
225+ #
62226 # Rs = [matrix.jnp.linalg.cholesky(matrix.jnp.diag(1/Pmat) + Ft.T @ Ft, upper=True) for Pmat, Ft in zip(Pmats, Fts)]
63227 # Ys = [matrix.jsp.linalg.solve_triangular(R, Ft.T @ Tt, lower=False) for R, Ft, Tt in zip(Rs, Fts, Tts)]
64228 # Ss = [Tt.T @ Tt - Y.T @ Y for Tt, Y in zip(Tts, Ys)]
65229
66- As = [matrix .jnp .linalg .cholesky (S , upper = False ) for S in Ss ]
230+ # with ridge regularization; the simple estimate based on the trace seems fine
231+ # a more precise possibility is eps = matrix.jnp.maximum(0.0, -matrix.jnp.linalg.eigvalsh(S).min())
232+ # + 1e-10 * matrix.jnp.trace(S) / S.shape[0]
233+ Ss = [0.5 * (S + S .T ) for S in Ss ] # ensure symmetry
234+ As = [matrix .jnp .linalg .cholesky (S + (1e-10 * matrix .jnp .trace (S ) / S .shape [0 ]) * matrix .jnp .eye (S .shape [0 ]))
235+ for S in Ss ]
67236
68237 Ds = [sPhi [:,matrix .jnp .newaxis ] * S * sPhi [matrix .jnp .newaxis ,:] for S in Ss ]
69238 bs = [matrix .jnp .trace (Ds [i ] @ Ds [j ]) for (i ,j ) in self .pairs ]
@@ -120,17 +289,6 @@ def sample_rhosigma(self, params, orf=hd_orfa):
120289 iPs .append (slice (cnt , cnt + Fmat .shape [1 ]))
121290 cnt += Fmat .shape [1 ]
122291
123- # @jax.jit
124- # @jax.vmap
125- # def makets(keys):
126- # uks = [PsTt @ K1(matrix.jnp.sqrt(Nmat) * matrix.jnpnormal(key, Nmat.shape[0]))[0] +
127- # PsTtKmFsP @ matrix.jnpnormal(key, Fmat.shape[1])
128- # for key, K1, Nmat, Fmat, PsTt, PsTtKmFsP in zip(keys, K1s, Nmats, Fmats, PsTts, PsTtKmFsPs)]
129-
130- # return matrix.jnparray([matrix.jnp.dot(uks[i], uks[j].T) for (i,j) in self.pairs])
131-
132- # ts = makets(jax.random.split(key, (n, len(self.psls))))
133-
134292 def xs2snrs (xs ):
135293 uks = [PsTt @ K1 (matrix .jnp .sqrt (Nmat ) * xs [iN ])[0 ] + PsTtKmFsP @ xs [iP ]
136294 for PsTt , K1 , Nmat , iN , PsTtKmFsP , iP in zip (PsTts , K1s , Nmats , iNs , PsTtKmFsPs , iPs )]
@@ -312,6 +470,8 @@ def get_shift(params, phases, orf=hd_orfa):
312470
313471 return get_shift
314472
473+ # TODO: work in progress...
474+
315475 @functools .cached_property
316476 def gx2mat (self ):
317477 kernelsolves = [psl .N .make_kernelsolve (psl .N .F , gw .F ) for (psl , gw ) in zip (self .psls , self .gws )]
0 commit comments