|
19 | 19 |
|
20 | 20 |
|
21 | 21 | def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp, hp_str): |
22 | | - model = Tier(hp=hp, |
23 | | - freq=hp.audio.n_mels // f_div[hp.model.tier+1] * f_div[args.tier], |
24 | | - layers=hp.model.layers[args.tier-1], |
25 | | - tierN=args.tier).cuda() |
| 22 | + model = Tier( |
| 23 | + hp=hp, |
| 24 | + freq=hp.audio.n_mels // f_div[hp.model.tier+1] * f_div[args.tier], |
| 25 | + layers=hp.model.layers[args.tier-1], |
| 26 | + tierN=args.tier |
| 27 | + ).cuda() |
26 | 28 | melgen = MelGen(hp) |
27 | 29 | tierutil = TierUtil(hp) |
28 | 30 | criterion = GMMLoss() |
@@ -74,27 +76,30 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp, |
74 | 76 | torch.backends.cudnn.benchmark = True |
75 | 77 | try: |
76 | 78 | model.train() |
| 79 | + optimizer.zero_grad() |
| 80 | + loss_sum = 0 |
77 | 81 | for epoch in itertools.count(init_epoch+1): |
78 | | - trainloader.tier = args.tier |
79 | 82 | loader = tqdm(trainloader, desc='Train data loader') |
80 | 83 | for source, target in loader: |
81 | 84 | mu, std, pi = model(source.cuda()) |
82 | 85 | loss = criterion(target.cuda(), mu, std, pi) |
83 | | - |
84 | | - optimizer.zero_grad() |
85 | | - loss.backward() |
86 | | - optimizer.step() |
87 | 86 | step += 1 |
| 87 | + (loss / hp.train.update_interval).backward() |
| 88 | + loss_sum += loss.item() / hp.train.update_interval |
| 89 | + |
| 90 | + if step % hp.train.update_interval == 0: |
| 91 | + optimizer.step() |
| 92 | + optimizer.zero_grad() |
| 93 | + if step % hp.log.summary_interval == 0: |
| 94 | + writer.log_training(loss_sum, mu, std, pi, step) |
| 95 | + loader.set_description("Loss %.04f at step %d" % (loss_sum, step)) |
| 96 | + loss_sum = 0 |
88 | 97 |
|
89 | 98 | loss = loss.item() |
90 | 99 | if loss > 1e8 or math.isnan(loss): |
91 | 100 | logger.error("Loss exploded to %.04f at step %d!" % (loss, step)) |
92 | 101 | raise Exception("Loss exploded") |
93 | 102 |
|
94 | | - if step % hp.log.summary_interval == 0: |
95 | | - writer.log_training(loss, mu, std, pi, step) |
96 | | - loader.set_description("Loss %.04f at step %d" % (loss, step)) |
97 | | - |
98 | 103 | save_path = os.path.join(pt_dir, '%s_%s_tier%d_%03d.pt' |
99 | 104 | % (args.name, githash, args.tier, epoch)) |
100 | 105 | torch.save({ |
|
0 commit comments