@@ -172,14 +172,14 @@ def run_small_cosmo_tt(
172172 ln_evidence_hm = - ev .ln_evidence_inv
173173 err_ln_inv_evidence_hm = ev .compute_ln_inv_evidence_errors ()
174174
175- hm . logs . debug_log ("---------------------------------" )
176- hm . logs . debug_log ("Technical Details" )
177- hm . logs . debug_log ("---------------------------------" )
178- hm . logs . debug_log (f"lnargmax = { ev .lnargmax } , lnargmin = { ev .lnargmin } " )
179- hm . logs . debug_log (f"lnprobmax = { ev .lnprobmax } , lnprobmin = { ev .lnprobmin } " )
180- hm . logs . debug_log (f"lnpredictmax = { ev .lnpredictmax } , lnpredictmin = { ev .lnpredictmin } " )
181- hm . logs . debug_log ("---------------------------------" )
182- hm . logs . debug_log (f"shift = { ev .shift_value } , shift setting = { ev .shift } " )
175+ print ("---------------------------------" )
176+ print ("Technical Details" )
177+ print ("---------------------------------" )
178+ print (f"lnargmax = { ev .lnargmax } , lnargmin = { ev .lnargmin } " )
179+ print (f"lnprobmax = { ev .lnprobmax } , lnprobmin = { ev .lnprobmin } " )
180+ print (f"lnpredictmax = { ev .lnpredictmax } , lnpredictmin = { ev .lnpredictmin } " )
181+ print ("---------------------------------" )
182+ print (f"shift = { ev .shift_value } , shift setting = { ev .shift } " )
183183
184184 print (f"ln_inv_evidence (harmonic)= { ev .ln_evidence_inv } +/- { err_ln_inv_evidence_hm } " )
185185 print (f"ln evidence = { - ev .ln_evidence_inv } +/- { - err_ln_inv_evidence_hm [1 ]} { - err_ln_inv_evidence_hm [0 ]} " )
@@ -207,7 +207,22 @@ def run_small_cosmo_tt(
207207
208208 tt_evidence = True
209209 if tt_evidence :
210- def neglog_posterior_torch (theta_torch , lower , upper ):
210+ theta_ref = 0.5 * (lower_tt + upper_tt )
211+ lnpost_ref = float (ln_posterior (theta_ref , lower_tt , upper_tt ))
212+
213+ def neglog_posterior_torch_exact (theta_torch , lower , upper ):
214+ theta_np = theta_torch .detach ().cpu ().numpy ()
215+ lnps = np .array (
216+ [float (ln_posterior (t , lower , upper )) for t in theta_np ],
217+ dtype = np .float64 ,
218+ )
219+
220+ # True target for correction/evidence: no clipping, no arbitrary shift.
221+ lnps = np .where (np .isfinite (lnps ), lnps , - 1e30 )
222+ neglogps = - lnps
223+ return torch .tensor (neglogps , dtype = theta_torch .dtype , device = theta_torch .device )
224+
225+ def neglog_posterior_torch_tt (theta_torch , lower , upper ):
211226 theta_np = theta_torch .detach ().cpu ().numpy ()
212227 lnps = np .array (
213228 [float (ln_posterior (t , lower , upper )) for t in theta_np ],
@@ -217,8 +232,9 @@ def neglog_posterior_torch(theta_torch, lower, upper):
217232 # Deterministic penalty for invalid values
218233 lnps = np .where (np .isfinite (lnps ), lnps , - 1e30 )
219234
220- # DIRT expects negative log target
221- neglogps = - lnps
235+ # DIRT expects a negative log target. Shift by a fixed reference point
236+ # and clamp deterministically to keep TT fitting numerically stable.
237+ neglogps = np .clip (lnpost_ref - lnps , a_min = 0.0 , a_max = 80.0 )
222238 return torch .tensor (neglogps , dtype = theta_torch .dtype , device = theta_torch .device )
223239
224240 # ===========================================================================
@@ -228,32 +244,30 @@ def neglog_posterior_torch(theta_torch, lower, upper):
228244 approximation_domain = torch .tensor (limits , dtype = torch .float64 )
229245
230246 # Create a partial function with lower and upper bounds pre-specified
231- neglog_posterior_torch_partial = partial (neglog_posterior_torch , lower = lower_tt , upper = upper_tt )
232- target_func = dt .TargetFunc (neglog_posterior_torch_partial )
247+ neglog_posterior_torch_tt_partial = partial (neglog_posterior_torch_tt , lower = lower_tt , upper = upper_tt )
248+ neglog_posterior_torch_exact_partial = partial (neglog_posterior_torch_exact , lower = lower_tt , upper = upper_tt )
249+ target_func = dt .TargetFunc (neglog_posterior_torch_tt_partial )
233250
234251 reference = dt .UniformReference ()
235252 preconditioner = dt .UniformMapping (approximation_domain , reference )
236253
237- # TT setup: since it's 5D, we can afford higher rank and more elements
238- tt_options = dt .TTOptions (max_als = 3 , init_rank = 1 , tt_method = "fixed_rank" )
239- basis = dt .Lagrange1 (num_elems = 29 )
254+ # More robust TT setup for the 5D cosmology posterior.
255+ tt_options = dt .TTOptions (max_als = 6 , init_rank = 6 , tt_method = "fixed_rank" )
256+ basis = dt .Lagrange1 (num_elems = 19 )
240257 bases = dt .ApproxBases (basis , ndim )
241258
242259 tt = dt .TT (tt_options )
243260 ftt = dt .FTT (bases , tt )
244261 bridge = dt .SingleLayer ()
245262 dirt = dt .DIRT (target_func , preconditioner , ftt , bridge )
246263
247- # ===========================================================================
248- # Generate Samples via Independence Sampler
249- # ===========================================================================
250264 hm .logs .info_log ("Generating independent samples from TT..." )
251265 num_sampl = nchains * (samples_per_chain - nburn )
252266 rs = reference .random (n = num_sampl , d = ndim )
253267
254268 startTime = time .time ()
255269 xs , neglogfxs_sirt = dirt .eval_irt (rs )
256- neglogfxs_exact = target_func (xs )
270+ neglogfxs_exact = neglog_posterior_torch_exact_partial (xs )
257271 res = dt .run_independence_sampler (xs , neglogfxs_sirt , neglogfxs_exact )
258272
259273 print (f'Time to generate { num_sampl } samples: { (time .time ()- startTime ):.2f} s' )
@@ -289,6 +303,26 @@ def neglog_posterior_torch(theta_torch, lower, upper):
289303
290304 print (f"Harmonic + tt posterior samples ln_evidence: { ln_evidence_hm_tt } +/- { - err_ln_inv_evidence_hm_tt [1 ]} { - err_ln_inv_evidence_hm_tt [0 ]} " )
291305
306+ print ("---------------------------------" )
307+ print ("Technical Details" )
308+ print ("---------------------------------" )
309+ print (f"lnargmax = { ev .lnargmax } , lnargmin = { ev .lnargmin } " )
310+ print (f"lnprobmax = { ev .lnprobmax } , lnprobmin = { ev .lnprobmin } " )
311+ print (f"lnpredictmax = { ev .lnpredictmax } , lnpredictmin = { ev .lnpredictmin } " )
312+ print ("---------------------------------" )
313+ print (f"shift = { ev .shift_value } , shift setting = { ev .shift } " )
314+
315+ print (f"ln_inv_evidence (harmonic)= { ev .ln_evidence_inv } +/- { err_ln_inv_evidence_hm_tt } " )
316+ print (f"ln evidence = { - ev .ln_evidence_inv } +/- { - err_ln_inv_evidence_hm_tt [1 ]} { - err_ln_inv_evidence_hm_tt [0 ]} " )
317+ print (f"kurtosis = { ev .kurtosis } (Aim for ~3)" )
318+
319+ check = np .exp (0.5 * ev .ln_evidence_inv_var_var - ev .ln_evidence_inv_var )
320+ n_eff_limit = np .sqrt (2.0 / (ev .n_eff - 1 ))
321+ print (f"Standardized Variance Check: { check } " )
322+ print (f"Aim for sqrt( 2/(n_eff-1) ) = { n_eff_limit } " )
323+ print (f"sqrt(evidence_inv_var_var) / evidence_inv_var = { check } " )
324+
325+
292326 if plot_corner :
293327 #Plot samples from tt
294328 hm .utils .plot_getdist (samples_np , labels = labels )
@@ -328,7 +362,7 @@ def estimate_evidence(
328362 num_samples_tt = nchains * (samples_per_chain - nburn )
329363
330364 # Estimate evidence
331- evidence = estimate_evidence (neglog_posterior_torch_partial , dirt , num_samples_tt )
365+ evidence = estimate_evidence (neglog_posterior_torch_exact_partial , dirt , num_samples_tt )
332366 print (f"TT importance sampling evidence: { evidence .item ():.4e} " )
333367 clock = time .process_time () - clock
334368 print (f"TT importance sampling evidence estimation completed in { clock :.2f} seconds" )
0 commit comments