Skip to content

Commit 84f70b6

Browse files
Generate training data in batches
1 parent 42c5d5c commit 84f70b6

File tree

1 file changed

+49
-22
lines changed

examples/gaussian_mixture.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def run_example(
157157
save_name_start = "examples/plots/" + flow_type + "_" + str(ndim) + "D_" + str(n_components) + "gmm"
158158

159159
temperature = 0.9
160-
standardize = True
160+
standardize = False
161161
verbose = True
162162

163163
# Spline params
@@ -186,7 +186,7 @@ def run_example(
186186
)
187187

188188

189-
# ===== TRAINING PHASE: Generate training samples =====
189+
# ===== TRAINING PHASE: Generate training samples =====
190190
if use_emcee:
191191
# EMCEE sampling for training
192192
print("Using emcee for sampling training data...")
@@ -229,24 +229,48 @@ def log_prob_emcee(x):
229229
print("Using direct sampling for training data...")
230230
key = jax.random.PRNGKey(i_realisation)
231231
training_steps = int(samples_per_chain * training_proportion)
232-
total_train_samples = nchains * training_steps
233-
num_samples_per = (total_train_samples + n_components - 1) // n_components
234232

235-
samples_train, lnprob_train = sample_mixture(key, means, covs, num_samples_per, ndim)
236-
samples_train = samples_train[:total_train_samples]
237-
lnprob_train = lnprob_train[:total_train_samples]
233+
# Sample training data in chunks
234+
training_samples_per_batch = 20 # Samples per chain per batch
235+
n_train_batches = (training_steps + training_samples_per_batch - 1) // training_samples_per_batch
238236

239-
key, shuffle_key = jax.random.split(key)
240-
perm = jax.random.permutation(shuffle_key, total_train_samples)
241-
samples_train = samples_train[perm]
242-
lnprob_train = lnprob_train[perm]
237+
samples_train_list = []
238+
lnprob_train_list = []
243239

244-
samples_train = jnp.reshape(samples_train, (nchains, training_steps, ndim))
245-
lnprob_train = jnp.reshape(lnprob_train, (nchains, training_steps))
240+
for i_train_batch in range(n_train_batches):
241+
actual_train_batch_size = min(training_samples_per_batch, training_steps - i_train_batch * training_samples_per_batch)
242+
total_batch_train_samples = nchains * actual_train_batch_size
243+
num_samples_per = (total_batch_train_samples + n_components - 1) // n_components
244+
245+
hm.logs.info_log(f"Generating training batch {i_train_batch + 1}/{n_train_batches} ({actual_train_batch_size} samples per chain)...")
246+
key, subkey = jax.random.split(key)
247+
248+
samples_batch, lnprob_batch = sample_mixture(subkey, means, covs, num_samples_per, ndim)
249+
samples_batch = samples_batch[:total_batch_train_samples]
250+
lnprob_batch = lnprob_batch[:total_batch_train_samples]
251+
252+
key, shuffle_key = jax.random.split(key)
253+
perm = jax.random.permutation(shuffle_key, total_batch_train_samples)
254+
samples_batch = samples_batch[perm]
255+
lnprob_batch = lnprob_batch[perm]
256+
257+
samples_batch = jnp.reshape(samples_batch, (nchains, actual_train_batch_size, ndim))
258+
lnprob_batch = jnp.reshape(lnprob_batch, (nchains, actual_train_batch_size))
259+
260+
samples_train_list.append(samples_batch)
261+
lnprob_train_list.append(lnprob_batch)
262+
263+
del samples_batch, lnprob_batch
264+
265+
# Concatenate all training batches along the time dimension (axis=1)
266+
samples_train = jnp.concatenate(samples_train_list, axis=1)
267+
lnprob_train = jnp.concatenate(lnprob_train_list, axis=1)
268+
269+
del samples_train_list, lnprob_train_list
246270

247271
# Set up training chains
248-
chains_train = hm.Chains(ndim)
249-
chains_train.add_chains_3d(samples_train, lnprob_train)
272+
#chains_train = hm.Chains(ndim)
273+
#chains_train.add_chains_3d(samples_train, lnprob_train)
250274

251275
# =======================================================================
252276
# Fit model
@@ -276,7 +300,10 @@ def log_prob_emcee(x):
276300
standardize=standardize,
277301
temperature=temperature,
278302
)
279-
model.fit(chains_train.samples, epochs=epochs_num, verbose=verbose, batch_size=4096)
303+
304+
samples_train_flat = samples_train.reshape(-1, ndim)
305+
306+
model.fit(samples_train_flat, epochs=epochs_num, verbose=verbose, batch_size=4096)
280307

281308
losses = np.array(model.loss_values)
282309
ema = None
@@ -297,7 +324,7 @@ def log_prob_emcee(x):
297324
plt.show()
298325
plt.close()
299326

300-
# =======================================================================
327+
# =======================================================================
301328
# EVIDENCE COMPUTATION: Sample incrementally
302329
# =======================================================================
303330
hm.logs.info_log("Compute evidence with incremental sampling...")
@@ -424,10 +451,10 @@ def log_prob_emcee(x):
424451
save_name = save_name_start + "_getdist.png"
425452
plt.savefig(save_name, bbox_inches="tight")
426453

427-
num_samp = chains_train.samples.shape[0]
454+
num_samp = samples_train.shape[0]
428455
samps_compressed = model.sample(num_samp)
429456

430-
hm.utils.plot_getdist_compare(chains_train.samples, samps_compressed)
457+
hm.utils.plot_getdist_compare(samples_train, samps_compressed)
431458
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
432459

433460
if savefigs:
@@ -488,9 +515,9 @@ def log_prob_emcee(x):
488515

489516
# Define parameters.
490517
n_components = 1
491-
ndim = 50
492-
nchains = 200
493-
samples_per_chain = 500
518+
ndim = 500
519+
nchains = 2000
520+
samples_per_chain = 1000
494521
burnin = 1000
495522
#flow_str = "RealNVP"
496523
#flow_str = "RQSpline"

0 commit comments

Comments (0)