Skip to content

Commit add4642

Browse files
committed
fix(train): parameter issue
1 parent 1410bd4 commit add4642

File tree

4 files changed

+20
-13
lines changed

4 files changed

+20
-13
lines changed

gui.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,6 @@ def __init__(self) -> None:
144144
self.input_devices_indices = None
145145
self.output_devices_indices = None
146146
self.stream = None
147-
if not self.config.nocheck:
148-
self.check_assets()
149147
self.update_devices()
150148
self.launcher()
151149

infer/lib/train/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
import sys
7+
from copy import deepcopy
78

89
import codecs
910
import numpy as np
@@ -444,6 +445,9 @@ def items(self):
444445

445446
def values(self):
446447
return self.__dict__.values()
448+
449+
def copy(self):
450+
return deepcopy(self)
447451

448452
def __len__(self):
449453
return len(self.__dict__)

infer/modules/train/train.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from torch.utils.data import DataLoader
4747
from torch.utils.tensorboard import SummaryWriter
4848

49-
from rvc.layers import utils
5049
from infer.lib.train.data_utils import (
5150
DistributedBucketSampler,
5251
TextAudioCollate,
@@ -77,6 +76,11 @@
7776
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
7877
from infer.lib.train.process_ckpt import save_small_model
7978

79+
from rvc.layers.utils import (
80+
slice_on_last_dim,
81+
total_grad_norm,
82+
)
83+
8084
global_step = 0
8185

8286

@@ -118,7 +122,7 @@ def main():
118122
children[i].join()
119123

120124

121-
def run(rank, n_gpus, hps, logger: logging.Logger):
125+
def run(rank, n_gpus, hps: utils.HParams, logger: logging.Logger):
122126
global global_step
123127
if rank == 0:
124128
# logger = utils.get_logger(hps.model_dir)
@@ -163,20 +167,20 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
163167
persistent_workers=True,
164168
prefetch_factor=8,
165169
)
170+
mdl = hps.copy().model
171+
del mdl.use_spectral_norm
166172
if hps.if_f0 == 1:
167173
net_g = RVC_Model_f0(
168174
hps.data.filter_length // 2 + 1,
169175
hps.train.segment_size // hps.data.hop_length,
170-
**hps.model,
171-
is_half=hps.train.fp16_run,
176+
**mdl,
172177
sr=hps.sample_rate,
173178
)
174179
else:
175180
net_g = RVC_Model_nof0(
176181
hps.data.filter_length // 2 + 1,
177182
hps.train.segment_size // hps.data.hop_length,
178-
**hps.model,
179-
is_half=hps.train.fp16_run,
183+
**mdl,
180184
)
181185
if torch.cuda.is_available():
182186
net_g = net_g.cuda(rank)
@@ -459,7 +463,7 @@ def train_and_evaluate(
459463
hps.data.mel_fmin,
460464
hps.data.mel_fmax,
461465
)
462-
y_mel = utils.slice_on_last_dim(
466+
y_mel = slice_on_last_dim(
463467
mel, ids_slice, hps.train.segment_size // hps.data.hop_length
464468
)
465469
with autocast(enabled=False):
@@ -475,7 +479,7 @@ def train_and_evaluate(
475479
)
476480
if hps.train.fp16_run == True:
477481
y_hat_mel = y_hat_mel.half()
478-
wave = utils.slice_on_last_dim(
482+
wave = slice_on_last_dim(
479483
wave, ids_slice * hps.data.hop_length, hps.train.segment_size
480484
) # slice
481485

@@ -488,7 +492,7 @@ def train_and_evaluate(
488492
optim_d.zero_grad()
489493
scaler.scale(loss_disc).backward()
490494
scaler.unscale_(optim_d)
491-
grad_norm_d = utils.total_grad_norm(net_d.parameters())
495+
grad_norm_d = total_grad_norm(net_d.parameters())
492496
scaler.step(optim_d)
493497

494498
with autocast(enabled=hps.train.fp16_run):
@@ -503,7 +507,7 @@ def train_and_evaluate(
503507
optim_g.zero_grad()
504508
scaler.scale(loss_gen_all).backward()
505509
scaler.unscale_(optim_g)
506-
grad_norm_g = utils.total_grad_norm(net_g.parameters())
510+
grad_norm_g = total_grad_norm(net_g.parameters())
507511
scaler.step(optim_g)
508512
scaler.update()
509513

web.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import numpy as np
2525
import gradio as gr
2626
import faiss
27-
import fairseq
2827
import pathlib
2928
import json
3029
from time import sleep
@@ -72,7 +71,9 @@ def forward_dml(ctx, x, scale):
7271
res = x.clone().detach()
7372
return res
7473

74+
import fairseq
7575
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
76+
7677
i18n = I18nAuto()
7778
logger.info(i18n)
7879
# 判断是否有能用来训练和加速推理的N卡

0 commit comments

Comments
 (0)