@@ -55,6 +55,13 @@ def _select_config_file(path: pathlib.Path) -> Optional[pathlib.Path]:
5555
5656def _extract_gen_kwargs (cfg : Dict [str , Any ]) -> Dict [str , Any ]:
5757 model_args = cfg .get ("model_args" ) or cfg .get ("generator" ) or {}
58+ def pick (key : str , default : Any ):
59+ if key in model_args :
60+ return model_args [key ]
61+ if key in cfg :
62+ return cfg [key ]
63+ return default
64+
5865 return {
5966 "sampling_rate" : cfg .get ("audio_sample_rate" )
6067 or cfg .get ("sampling_rate" )
@@ -68,11 +75,11 @@ def _extract_gen_kwargs(cfg: Dict[str, Any]) -> Dict[str, Any]:
6875 or cfg .get ("hop_length" )
6976 or model_args .get ("hop_length" )
7077 or hparams ["hop_size" ],
71- "downsample_rates" : tuple (model_args . get ("downsample_rates" , (2 , 2 , 8 , 8 ))),
72- "upsample_rates" : tuple (model_args . get ("upsample_rates" , (8 , 8 , 2 , 2 ))),
73- "leaky_relu_slope" : float (model_args . get ("leaky_relu_slope" , 0.2 )),
74- "start_channels" : int (model_args . get ("start_channels" , 16 )),
75- "template_generator" : model_args . get ("template_generator" , "comb" ),
78+ "downsample_rates" : tuple (pick ("downsample_rates" , (2 , 2 , 8 , 8 ))),
79+ "upsample_rates" : tuple (pick ("upsample_rates" , (8 , 8 , 2 , 2 ))),
80+ "leaky_relu_slope" : float (pick ("leaky_relu_slope" , 0.2 )),
81+ "start_channels" : int (pick ("start_channels" , 16 )),
82+ "template_generator" : pick ("template_generator" , "comb" ),
7683 }
7784
7885
@@ -216,4 +223,3 @@ def spec2wav(self, mel, **kwargs):
216223 with torch .no_grad ():
217224 wav = self .spec2wav_torch (mel_np , f0 = f0_t )
218225 return wav .cpu ().numpy ()
219-
0 commit comments