diff --git a/clearvoice/config/inference/FRCRN_SE_16K.yaml b/clearvoice/config/inference/FRCRN_SE_16K.yaml
index 0bec93c3..66904265 100644
--- a/clearvoice/config/inference/FRCRN_SE_16K.yaml
+++ b/clearvoice/config/inference/FRCRN_SE_16K.yaml
@@ -15,6 +15,6 @@ decode_window: 1 #one-pass decoding length
 #
 # FFT parameters
 win_type: 'hanning'
-win_len: 640
-win_inc: 320
+win_len: 320
+win_inc: 160
 fft_len: 640
diff --git a/clearvoice/demo.py b/clearvoice/demo.py
index 70f9a916..55389d9f 100644
--- a/clearvoice/demo.py
+++ b/clearvoice/demo.py
@@ -65,3 +65,17 @@
 
     #2nd calling method: process video files listed in .scp file, and write outputs to 'path_to_output_videos_tse_scp/'
    myClearVoice(input_path='samples/scp/video_samples.scp', online_write=True, output_path='samples/path_to_output_videos_tse_scp')
+
+##-----Demo Six: use FRCRN_SE_16K model for real-time processing -----------------
+if False:
+    myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])
+
+    ##1st calling method: process an input waveform in real-time and return output waveform, then write to output_FRCRN_SE_16K_realtime.wav
+    output_wav = myClearVoice(input_path='samples/input_realtime.wav', online_write=False)
+    myClearVoice.write(output_wav, output_path='samples/output_FRCRN_SE_16K_realtime.wav')
+
+    ##2nd calling method: process all wav files in 'path_to_input_wavs_realtime/' in real-time and write outputs to 'path_to_output_wavs_realtime/'
+    myClearVoice(input_path='samples/path_to_input_wavs_realtime', online_write=True, output_path='samples/path_to_output_wavs_realtime')
+
+    ##3rd calling method: process wav files listed in .scp file in real-time, and write outputs to 'path_to_output_wavs_realtime_scp/'
+    myClearVoice(input_path='samples/scp/audio_samples_realtime.scp', online_write=True, output_path='samples/path_to_output_wavs_realtime_scp')
diff --git a/clearvoice/demo_with_more_comments.py b/clearvoice/demo_with_more_comments.py
index 9e845344..be382801 100644
--- a/clearvoice/demo_with_more_comments.py
+++ b/clearvoice/demo_with_more_comments.py
@@ -57,3 +57,31 @@
         # - online_write (bool): Set to True to enable saving the enhanced output during processing
         # - output_path (str): Path to the directory to save the enhanced output files
         myClearVoice(input_path='samples/scp/audio_samples.scp', online_write=True, output_path='samples/path_to_output_wavs_scp')
+
+    ## ---------------- Demo Three: Real-Time Processing -----------------------
+    if False: # This block demonstrates how to use the FRCRN_SE_16K model for real-time speech enhancement
+        # Initialize ClearVoice for the task of speech enhancement using the FRCRN_SE_16K model
+        myClearVoice = ClearVoice(task='speech_enhancement', model_names=['FRCRN_SE_16K'])
+
+        # 1st calling method:
+        # Process an input waveform in real-time and return the enhanced output waveform
+        # - input_path (str): Path to the input noisy audio file (input_realtime.wav)
+        # - output_wav (dict or ndarray): The enhanced output waveform
+        output_wav = myClearVoice(input_path='samples/input_realtime.wav', online_write=False)
+        # Write the processed waveform to an output file
+        # - output_path (str): Path to save the enhanced audio file (output_FRCRN_SE_16K_realtime.wav)
+        myClearVoice.write(output_wav, output_path='samples/output_FRCRN_SE_16K_realtime.wav')
+
+        # 2nd calling method:
+        # Process and write audio files directly in real-time
+        # - input_path (str): Path to the directory of input noisy audio files
+        # - online_write (bool): Set to True to enable saving the enhanced audio directly to files during processing
+        # - output_path (str): Path to the directory to save the enhanced output files
+        myClearVoice(input_path='samples/path_to_input_wavs_realtime', online_write=True, output_path='samples/path_to_output_wavs_realtime')
+
+        # 3rd calling method:
+        # Use an .scp file to specify input audio paths for real-time processing
+        # - input_path (str): Path to a .scp file listing multiple audio file paths
+        # - online_write (bool): Set to True to enable saving the enhanced audio directly to files during processing
+        # - output_path (str): Path to the directory to save the enhanced output files
+        myClearVoice(input_path='samples/scp/audio_samples_realtime.scp', online_write=True, output_path='samples/path_to_output_wavs_realtime_scp')
diff --git a/clearvoice/models/frcrn_se/frcrn.py b/clearvoice/models/frcrn_se/frcrn.py
index bc195c8a..bedc3586 100644
--- a/clearvoice/models/frcrn_se/frcrn.py
+++ b/clearvoice/models/frcrn_se/frcrn.py
@@ -71,18 +71,35 @@ def __init__(self, args):
             win_type=args.win_type
         )
 
-    def forward(self, x):
+    def forward(self, x, real_time=False):
         """
         Forward pass of the model.
 
         Args:
             x (torch.Tensor): Input tensor representing audio signals.
+            real_time (bool): Flag to indicate real-time processing.
 
         Returns:
             torch.Tensor: Processed output tensor after applying the model.
         """
-        output = self.model(x)
-        return output[1][0] # Return estimated waveform
+        if real_time:
+            return self.real_time_process(x)
+        else:
+            output = self.model(x)
+            return output[1][0] # Return estimated waveform
+
+    def real_time_process(self, x):
+        """
+        Real-time processing method for the FRCRN model.
+
+        Args:
+            x (torch.Tensor): Input tensor representing audio signals.
+
+        Returns:
+            list: [est_spec, est_wav, est_mask] estimated by the underlying model.
+        """
+        output = self.model.real_time_process(x)
+        return output
 
 
 class DCCRN(nn.Module):
@@ -249,3 +266,41 @@ def get_params(self, weight_decay=0.0):
         }]
 
         return params
+
+    def real_time_process(self, inputs):
+        """
+        Real-time processing method for the DCCRN model.
+
+        Args:
+            inputs (torch.Tensor): Input tensor representing audio signals.
+
+        Returns:
+            list: [est_spec, est_wav, est_mask] estimated for the current input.
+        """
+        out_list = []
+        # Compute the complex spectrogram using STFT
+        cmp_spec = self.stft(inputs) # [B, D*2, T]
+        cmp_spec = torch.unsqueeze(cmp_spec, 1) # [B, 1, D*2, T]
+
+        # Split into real and imaginary parts
+        cmp_spec = torch.cat([
+            cmp_spec[:, :, :self.feat_dim, :], # Real part
+            cmp_spec[:, :, self.feat_dim:, :], # Imaginary part
+        ], 1) # [B, 2, D, T]
+
+        cmp_spec = torch.unsqueeze(cmp_spec, 4) # [B, 2, D, T, 1]
+        cmp_spec = torch.transpose(cmp_spec, 1, 4) # [B, 1, D, T, 2]
+
+        # Pass through the UNet to estimate masks
+        unet1_out = self.unet(cmp_spec) # First UNet output
+        cmp_mask1 = torch.tanh(unet1_out) # First mask
+
+        unet2_out = self.unet2(unet1_out) # Second UNet output
+        cmp_mask2 = torch.tanh(unet2_out) # Second mask
+        cmp_mask2 = cmp_mask2 + cmp_mask1 # Combine masks
+
+        # Apply the estimated mask to the complex spectrogram
+        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
+        out_list.append(est_spec)
+        out_list.append(est_wav)
+        out_list.append(est_mask)
+        return out_list
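
Taken together, the patch halves the FRCRN_SE_16K analysis window to 320 samples (20 ms at 16 kHz) with a 160-sample (10 ms) hop while keeping fft_len at 640, and exposes a real_time flag on the model's forward pass. Below is a minimal usage sketch of how that flag might be exercised on fixed-size chunks; the helper name enhance_in_chunks, the 4800-sample chunk length, and the already-constructed frcrn_model object are illustrative assumptions, not part of this patch or the repo API.

# Usage sketch only, not part of the patch: the helper name, chunk size, and the
# pre-built `frcrn_model` (an instance of the FRCRN wrapper in frcrn.py with
# weights already loaded) are assumptions for illustration.
import torch

def enhance_in_chunks(frcrn_model, noisy, chunk_samples=4800):
    """Feed 300 ms chunks (at 16 kHz) through the new real-time branch one at a time."""
    frcrn_model.eval()
    enhanced = []
    with torch.no_grad():
        for start in range(0, noisy.shape[-1], chunk_samples):
            chunk = noisy[..., start:start + chunk_samples]  # [B, chunk_samples]
            # forward(x, real_time=True) returns [est_spec, est_wav, est_mask]
            est_spec, est_wav, est_mask = frcrn_model(chunk, real_time=True)
            enhanced.append(est_wav)
    return torch.cat(enhanced, dim=-1)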