Skip to content

Commit a1bee2c

Browse files
Shift
1 parent 6866a9e commit a1bee2c

File tree

1 file changed

+55
-21
lines changed

1 file changed

+55
-21
lines changed

examples/cosmo_small_tt.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,14 @@ def run_small_cosmo_tt(
172172
ln_evidence_hm = -ev.ln_evidence_inv
173173
err_ln_inv_evidence_hm = ev.compute_ln_inv_evidence_errors()
174174

175-
hm.logs.debug_log("---------------------------------")
176-
hm.logs.debug_log("Technical Details")
177-
hm.logs.debug_log("---------------------------------")
178-
hm.logs.debug_log(f"lnargmax = {ev.lnargmax}, lnargmin = {ev.lnargmin}")
179-
hm.logs.debug_log(f"lnprobmax = {ev.lnprobmax}, lnprobmin = {ev.lnprobmin}")
180-
hm.logs.debug_log(f"lnpredictmax = {ev.lnpredictmax}, lnpredictmin = {ev.lnpredictmin}")
181-
hm.logs.debug_log("---------------------------------")
182-
hm.logs.debug_log(f"shift = {ev.shift_value}, shift setting = {ev.shift}")
175+
print("---------------------------------")
176+
print("Technical Details")
177+
print("---------------------------------")
178+
print(f"lnargmax = {ev.lnargmax}, lnargmin = {ev.lnargmin}")
179+
print(f"lnprobmax = {ev.lnprobmax}, lnprobmin = {ev.lnprobmin}")
180+
print(f"lnpredictmax = {ev.lnpredictmax}, lnpredictmin = {ev.lnpredictmin}")
181+
print("---------------------------------")
182+
print(f"shift = {ev.shift_value}, shift setting = {ev.shift}")
183183

184184
print(f"ln_inv_evidence (harmonic)= {ev.ln_evidence_inv} +/- {err_ln_inv_evidence_hm}")
185185
print(f"ln evidence = {-ev.ln_evidence_inv} +/- {-err_ln_inv_evidence_hm[1]} {-err_ln_inv_evidence_hm[0]}")
@@ -207,7 +207,22 @@ def run_small_cosmo_tt(
207207

208208
tt_evidence = True
209209
if tt_evidence:
210-
def neglog_posterior_torch(theta_torch, lower, upper):
210+
theta_ref = 0.5 * (lower_tt + upper_tt)
211+
lnpost_ref = float(ln_posterior(theta_ref, lower_tt, upper_tt))
212+
213+
def neglog_posterior_torch_exact(theta_torch, lower, upper):
214+
theta_np = theta_torch.detach().cpu().numpy()
215+
lnps = np.array(
216+
[float(ln_posterior(t, lower, upper)) for t in theta_np],
217+
dtype=np.float64,
218+
)
219+
220+
# True target for correction/evidence: no clipping, no arbitrary shift.
221+
lnps = np.where(np.isfinite(lnps), lnps, -1e30)
222+
neglogps = -lnps
223+
return torch.tensor(neglogps, dtype=theta_torch.dtype, device=theta_torch.device)
224+
225+
def neglog_posterior_torch_tt(theta_torch, lower, upper):
211226
theta_np = theta_torch.detach().cpu().numpy()
212227
lnps = np.array(
213228
[float(ln_posterior(t, lower, upper)) for t in theta_np],
@@ -217,8 +232,9 @@ def neglog_posterior_torch(theta_torch, lower, upper):
217232
# Deterministic penalty for invalid values
218233
lnps = np.where(np.isfinite(lnps), lnps, -1e30)
219234

220-
# DIRT expects negative log target
221-
neglogps = -lnps
235+
# DIRT expects a negative log target. Shift by a fixed reference point
236+
# and clamp deterministically to keep TT fitting numerically stable.
237+
neglogps = np.clip(lnpost_ref - lnps, a_min=0.0, a_max=80.0)
222238
return torch.tensor(neglogps, dtype=theta_torch.dtype, device=theta_torch.device)
223239

224240
# ===========================================================================
@@ -228,32 +244,30 @@ def neglog_posterior_torch(theta_torch, lower, upper):
228244
approximation_domain = torch.tensor(limits, dtype=torch.float64)
229245

230246
# Create a partial function with lower and upper bounds pre-specified
231-
neglog_posterior_torch_partial = partial(neglog_posterior_torch, lower=lower_tt, upper=upper_tt)
232-
target_func = dt.TargetFunc(neglog_posterior_torch_partial)
247+
neglog_posterior_torch_tt_partial = partial(neglog_posterior_torch_tt, lower=lower_tt, upper=upper_tt)
248+
neglog_posterior_torch_exact_partial = partial(neglog_posterior_torch_exact, lower=lower_tt, upper=upper_tt)
249+
target_func = dt.TargetFunc(neglog_posterior_torch_tt_partial)
233250

234251
reference = dt.UniformReference()
235252
preconditioner = dt.UniformMapping(approximation_domain, reference)
236253

237-
# TT setup: since it's 5D, we can afford higher rank and more elements
238-
tt_options = dt.TTOptions(max_als=3, init_rank=1, tt_method="fixed_rank")
239-
basis = dt.Lagrange1(num_elems=29)
254+
# More robust TT setup for the 5D cosmology posterior.
255+
tt_options = dt.TTOptions(max_als=6, init_rank=6, tt_method="fixed_rank")
256+
basis = dt.Lagrange1(num_elems=19)
240257
bases = dt.ApproxBases(basis, ndim)
241258

242259
tt = dt.TT(tt_options)
243260
ftt = dt.FTT(bases, tt)
244261
bridge = dt.SingleLayer()
245262
dirt = dt.DIRT(target_func, preconditioner, ftt, bridge)
246263

247-
# ===========================================================================
248-
# Generate Samples via Independence Sampler
249-
# ===========================================================================
250264
hm.logs.info_log("Generating independent samples from TT...")
251265
num_sampl = nchains * (samples_per_chain-nburn)
252266
rs = reference.random(n=num_sampl, d=ndim)
253267

254268
startTime = time.time()
255269
xs, neglogfxs_sirt = dirt.eval_irt(rs)
256-
neglogfxs_exact = target_func(xs)
270+
neglogfxs_exact = neglog_posterior_torch_exact_partial(xs)
257271
res = dt.run_independence_sampler(xs, neglogfxs_sirt, neglogfxs_exact)
258272

259273
print(f'Time to generate {num_sampl} samples: {(time.time()-startTime):.2f}s')
@@ -289,6 +303,26 @@ def neglog_posterior_torch(theta_torch, lower, upper):
289303

290304
print(f"Harmonic + tt posterior samples ln_evidence: {ln_evidence_hm_tt} +/- {-err_ln_inv_evidence_hm_tt[1]} {-err_ln_inv_evidence_hm_tt[0]}")
291305

306+
print("---------------------------------")
307+
print("Technical Details")
308+
print("---------------------------------")
309+
print(f"lnargmax = {ev.lnargmax}, lnargmin = {ev.lnargmin}")
310+
print(f"lnprobmax = {ev.lnprobmax}, lnprobmin = {ev.lnprobmin}")
311+
print(f"lnpredictmax = {ev.lnpredictmax}, lnpredictmin = {ev.lnpredictmin}")
312+
print("---------------------------------")
313+
print(f"shift = {ev.shift_value}, shift setting = {ev.shift}")
314+
315+
print(f"ln_inv_evidence (harmonic)= {ev.ln_evidence_inv} +/- {err_ln_inv_evidence_hm_tt}")
316+
print(f"ln evidence = {-ev.ln_evidence_inv} +/- {-err_ln_inv_evidence_hm_tt[1]} {-err_ln_inv_evidence_hm_tt[0]}")
317+
print(f"kurtosis = {ev.kurtosis} (Aim for ~3)")
318+
319+
check = np.exp(0.5 * ev.ln_evidence_inv_var_var - ev.ln_evidence_inv_var)
320+
n_eff_limit = np.sqrt(2.0 / (ev.n_eff - 1))
321+
print(f"Standardized Variance Check: {check}")
322+
print(f"Aim for sqrt( 2/(n_eff-1) ) = {n_eff_limit}")
323+
print(f"sqrt(evidence_inv_var_var) / evidence_inv_var = {check}")
324+
325+
292326
if plot_corner:
293327
#Plot samples from tt
294328
hm.utils.plot_getdist(samples_np, labels=labels)
@@ -328,7 +362,7 @@ def estimate_evidence(
328362
num_samples_tt = nchains * (samples_per_chain-nburn)
329363

330364
# Estimate evidence
331-
evidence = estimate_evidence(neglog_posterior_torch_partial, dirt, num_samples_tt)
365+
evidence = estimate_evidence(neglog_posterior_torch_exact_partial, dirt, num_samples_tt)
332366
print(f"TT importance sampling evidence: {evidence.item():.4e}")
333367
clock = time.process_time() - clock
334368
print(f"TT importance sampling evidence estimation completed in {clock:.2f} seconds")

0 commit comments

Comments
 (0)