Skip to content

Commit ebc6e3f

Browse files
committed
fix: respect hop_size settings in RefineGAN downsample/upsample rates
1 parent c700be4 commit ebc6e3f

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

modules/refinegan/generator.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,35 @@ def __init__(
369369

370370
self.sampling_rate = sampling_rate
371371
self.hop_length = hop_length
372-
self.downsample_rates = downsample_rates
373-
self.upsample_rates = upsample_rates
374372
self.leaky_relu_slope = leaky_relu_slope
375373

376-
assert np.prod(downsample_rates) == np.prod(upsample_rates) == hop_length
374+
def _scale_last(rates, factor):
375+
rates = list(rates)
376+
rates[-1] = rates[-1] * factor
377+
return tuple(rates)
378+
379+
total_down = np.prod(downsample_rates)
380+
total_up = np.prod(upsample_rates)
381+
if total_down != hop_length:
382+
if hop_length % total_down != 0:
383+
raise ValueError(
384+
f"RefineGAN: hop_length {hop_length} not divisible by prod(downsample_rates) {total_down}"
385+
)
386+
scale = hop_length // total_down
387+
downsample_rates = _scale_last(downsample_rates, scale)
388+
print(f"| adjust RefineGAN downsample_rates -> {downsample_rates} to match hop_length {hop_length}")
389+
if total_up != hop_length:
390+
if hop_length % total_up != 0:
391+
raise ValueError(
392+
f"RefineGAN: hop_length {hop_length} not divisible by prod(upsample_rates) {total_up}"
393+
)
394+
scale = hop_length // total_up
395+
upsample_rates = _scale_last(upsample_rates, scale)
396+
print(f"| adjust RefineGAN upsample_rates -> {upsample_rates} to match hop_length {hop_length}")
397+
398+
# Keep the possibly-adjusted rates
399+
self.downsample_rates = tuple(downsample_rates)
400+
self.upsample_rates = tuple(upsample_rates)
377401

378402
self.template_type = template_generator
379403
if template_generator == "comb":
@@ -524,4 +548,3 @@ def forward(self, mel: torch.Tensor, f0: torch.Tensor) -> torch.Tensor:
524548
x = torch.tanh(x)
525549

526550
return x
527-

modules/vocoders/refinegan.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ def _select_config_file(path: pathlib.Path) -> Optional[pathlib.Path]:
5555

5656
def _extract_gen_kwargs(cfg: Dict[str, Any]) -> Dict[str, Any]:
5757
model_args = cfg.get("model_args") or cfg.get("generator") or {}
58+
def pick(key: str, default: Any):
59+
if key in model_args:
60+
return model_args[key]
61+
if key in cfg:
62+
return cfg[key]
63+
return default
64+
5865
return {
5966
"sampling_rate": cfg.get("audio_sample_rate")
6067
or cfg.get("sampling_rate")
@@ -68,11 +75,11 @@ def _extract_gen_kwargs(cfg: Dict[str, Any]) -> Dict[str, Any]:
6875
or cfg.get("hop_length")
6976
or model_args.get("hop_length")
7077
or hparams["hop_size"],
71-
"downsample_rates": tuple(model_args.get("downsample_rates", (2, 2, 8, 8))),
72-
"upsample_rates": tuple(model_args.get("upsample_rates", (8, 8, 2, 2))),
73-
"leaky_relu_slope": float(model_args.get("leaky_relu_slope", 0.2)),
74-
"start_channels": int(model_args.get("start_channels", 16)),
75-
"template_generator": model_args.get("template_generator", "comb"),
78+
"downsample_rates": tuple(pick("downsample_rates", (2, 2, 8, 8))),
79+
"upsample_rates": tuple(pick("upsample_rates", (8, 8, 2, 2))),
80+
"leaky_relu_slope": float(pick("leaky_relu_slope", 0.2)),
81+
"start_channels": int(pick("start_channels", 16)),
82+
"template_generator": pick("template_generator", "comb"),
7683
}
7784

7885

@@ -216,4 +223,3 @@ def spec2wav(self, mel, **kwargs):
216223
with torch.no_grad():
217224
wav = self.spec2wav_torch(mel_np, f0=f0_t)
218225
return wav.cpu().numpy()
219-

0 commit comments

Comments (0)