@@ -157,7 +157,7 @@ def run_example(
157157 save_name_start = "examples/plots/" + flow_type + "_" + str (ndim ) + "D_" + str (n_components ) + "gmm"
158158
159159 temperature = 0.9
160- standardize = True
160+ standardize = False
161161 verbose = True
162162
163163 # Spline params
@@ -186,7 +186,7 @@ def run_example(
186186 )
187187
188188
189- # ===== TRAINING PHASE: Generate training samples =====
189+ # ===== TRAINING PHASE: Generate training samples =====
190190 if use_emcee :
191191 # EMCEE sampling for training
192192 print ("Using emcee for sampling training data..." )
@@ -229,24 +229,48 @@ def log_prob_emcee(x):
229229 print ("Using direct sampling for training data..." )
230230 key = jax .random .PRNGKey (i_realisation )
231231 training_steps = int (samples_per_chain * training_proportion )
232- total_train_samples = nchains * training_steps
233- num_samples_per = (total_train_samples + n_components - 1 ) // n_components
234232
235- samples_train , lnprob_train = sample_mixture ( key , means , covs , num_samples_per , ndim )
236- samples_train = samples_train [: total_train_samples ]
237- lnprob_train = lnprob_train [: total_train_samples ]
233+ # Sample training data in chunks
234+ training_samples_per_batch = 20 # Samples per chain per batch
235+ n_train_batches = ( training_steps + training_samples_per_batch - 1 ) // training_samples_per_batch
238236
239- key , shuffle_key = jax .random .split (key )
240- perm = jax .random .permutation (shuffle_key , total_train_samples )
241- samples_train = samples_train [perm ]
242- lnprob_train = lnprob_train [perm ]
237+ samples_train_list = []
238+ lnprob_train_list = []
243239
244- samples_train = jnp .reshape (samples_train , (nchains , training_steps , ndim ))
245- lnprob_train = jnp .reshape (lnprob_train , (nchains , training_steps ))
240+ for i_train_batch in range (n_train_batches ):
241+ actual_train_batch_size = min (training_samples_per_batch , training_steps - i_train_batch * training_samples_per_batch )
242+ total_batch_train_samples = nchains * actual_train_batch_size
243+ num_samples_per = (total_batch_train_samples + n_components - 1 ) // n_components
244+
245+ hm .logs .info_log (f"Generating training batch { i_train_batch + 1 } /{ n_train_batches } ({ actual_train_batch_size } samples per chain)..." )
246+ key , subkey = jax .random .split (key )
247+
248+ samples_batch , lnprob_batch = sample_mixture (subkey , means , covs , num_samples_per , ndim )
249+ samples_batch = samples_batch [:total_batch_train_samples ]
250+ lnprob_batch = lnprob_batch [:total_batch_train_samples ]
251+
252+ key , shuffle_key = jax .random .split (key )
253+ perm = jax .random .permutation (shuffle_key , total_batch_train_samples )
254+ samples_batch = samples_batch [perm ]
255+ lnprob_batch = lnprob_batch [perm ]
256+
257+ samples_batch = jnp .reshape (samples_batch , (nchains , actual_train_batch_size , ndim ))
258+ lnprob_batch = jnp .reshape (lnprob_batch , (nchains , actual_train_batch_size ))
259+
260+ samples_train_list .append (samples_batch )
261+ lnprob_train_list .append (lnprob_batch )
262+
263+ del samples_batch , lnprob_batch
264+
265+ # Concatenate all training batches along the time dimension (axis=1)
266+ samples_train = jnp .concatenate (samples_train_list , axis = 1 )
267+ lnprob_train = jnp .concatenate (lnprob_train_list , axis = 1 )
268+
269+ del samples_train_list , lnprob_train_list
246270
247271 # Set up training chains
248- chains_train = hm .Chains (ndim )
249- chains_train .add_chains_3d (samples_train , lnprob_train )
272+ # chains_train = hm.Chains(ndim)
273+ # chains_train.add_chains_3d(samples_train, lnprob_train)
250274
251275 # =======================================================================
252276 # Fit model
@@ -276,7 +300,10 @@ def log_prob_emcee(x):
276300 standardize = standardize ,
277301 temperature = temperature ,
278302 )
279- model .fit (chains_train .samples , epochs = epochs_num , verbose = verbose , batch_size = 4096 )
303+
304+ samples_train_flat = samples_train .reshape (- 1 , ndim )
305+
306+ model .fit (samples_train_flat , epochs = epochs_num , verbose = verbose , batch_size = 4096 )
280307
281308 losses = np .array (model .loss_values )
282309 ema = None
@@ -297,7 +324,7 @@ def log_prob_emcee(x):
297324 plt .show ()
298325 plt .close ()
299326
300- # =======================================================================
327+ # =======================================================================
301328 # EVIDENCE COMPUTATION: Sample incrementally
302329 # =======================================================================
303330 hm .logs .info_log ("Compute evidence with incremental sampling..." )
@@ -424,10 +451,10 @@ def log_prob_emcee(x):
424451 save_name = save_name_start + "_getdist.png"
425452 plt .savefig (save_name , bbox_inches = "tight" )
426453
427- num_samp = chains_train . samples .shape [0 ]
454+ num_samp = samples_train_flat .shape [0 ]
428455 samps_compressed = model .sample (num_samp )
429456
430- hm .utils .plot_getdist_compare (chains_train . samples , samps_compressed )
457+ hm .utils .plot_getdist_compare (samples_train_flat , samps_compressed )
431458 plt .ticklabel_format (style = "sci" , axis = "y" , scilimits = (0 , 0 ))
432459
433460 if savefigs :
@@ -488,9 +515,9 @@ def log_prob_emcee(x):
488515
489516 # Define parameters.
490517 n_components = 1
491- ndim = 50
492- nchains = 200
493- samples_per_chain = 500
518+ ndim = 500
519+ nchains = 2000
520+ samples_per_chain = 1000
494521 burnin = 1000
495522 #flow_str = "RealNVP"
496523 #flow_str = "RQSpline"
0 commit comments