@@ -142,17 +142,20 @@ def generate_and_save_intermediate_result(self, batch):
142142
143143 mels_before , mels_after , * _ = outputs
144144 mel_gts = batch ["mel_gts" ]
145+ utt_ids = batch ["utt_ids" ]
145146
146147 # convert to tensor.
147148 # here we just take a sample at first replica.
148149 try :
149150 mels_before = mels_before .values [0 ].numpy ()
150151 mels_after = mels_after .values [0 ].numpy ()
151152 mel_gts = mel_gts .values [0 ].numpy ()
153+ utt_ids = utt_ids .values [0 ].numpy ()
152154 except Exception :
153155 mels_before = mels_before .numpy ()
154156 mels_after = mels_after .numpy ()
155157 mel_gts = mel_gts .numpy ()
158+ utt_ids = utt_ids .numpy ()
156159
157160 # check directory
158161 if self .use_griffin :
@@ -167,22 +170,25 @@ def generate_and_save_intermediate_result(self, batch):
167170 for idx , (mel_gt , mel_before , mel_after ) in enumerate (
168171 zip (mel_gts , mels_before , mels_after ), 1
169172 ):
170-
173+
174+
171175 if self .use_griffin :
176+ utt_id = utt_ids [idx ]
172177 grif_before = self .griffin_lim_tf (tf .reshape (mel_before , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
173178 grif_after = self .griffin_lim_tf (tf .reshape (mel_after , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
174179 grif_gt = self .griffin_lim_tf (tf .reshape (mel_gt , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
175- self .griffin_lim_tf .save_wav (grif_before , griff_dir_name , f"{ idx } _before" )
176- self .griffin_lim_tf .save_wav (grif_after , griff_dir_name , f"{ idx } _after" )
177- self .griffin_lim_tf .save_wav (grif_gt , griff_dir_name , f"{ idx } _gt" )
178-
180+ self .griffin_lim_tf .save_wav (grif_before , griff_dir_name , f"{ utt_id } _before" )
181+ self .griffin_lim_tf .save_wav (grif_after , griff_dir_name , f"{ utt_id } _after" )
182+ self .griffin_lim_tf .save_wav (grif_gt , griff_dir_name , f"{ utt_id } _gt" )
183+
184+ utt_id = utt_ids [idx ]
179185 mel_gt = tf .reshape (mel_gt , (- 1 , 80 )).numpy () # [length, 80]
180186 mel_before = tf .reshape (mel_before , (- 1 , 80 )).numpy () # [length, 80]
181187 mel_after = tf .reshape (mel_after , (- 1 , 80 )).numpy () # [length, 80]
182188
183189
184190 # plit figure and save it
185- figname = os .path .join (dirname , f"{ idx } .png" )
191+ figname = os .path .join (dirname , f"{ utt_id } .png" )
186192 fig = plt .figure (figsize = (10 , 8 ))
187193 ax1 = fig .add_subplot (311 )
188194 ax2 = fig .add_subplot (312 )
0 commit comments