Skip to content

Commit 3a9efa4

Browse files
authored
Fix to latent sampling in the training loop. May improve learning. Thanks Yangkang Zhang!
1 parent d832541 commit 3a9efa4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_version/training/training_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def fetch_data(dataset, dataset_iter, input_shape, drange_net, device, batches_n
5555
real_c = real_c.to(device).split(batch_gpu)
5656

5757
gen_zs = torch.randn([batches_num * batch_size, *input_shape[1:]], device = device)
58-
gen_zs = [gen_zs.split(batch_gpu) for gen_z in gen_zs.split(batch_size)]
58+
gen_zs = [gen_z.split(batch_gpu) for gen_z in gen_zs.split(batch_size)]
5959

6060
gen_cs = [dataset.get_label(np.random.randint(len(dataset))) for _ in range(batches_num * batch_size)]
6161
gen_cs = torch.from_numpy(np.stack(gen_cs)).pin_memory().to(device)

0 commit comments

Comments
 (0)