Skip to content

Commit 6ad76d1

Browse files
Remove dependencies for mvdr tutorial (#4003)
* Remove dependencies for mvdr tutorial
* Add back `evaluate` function to mvdr tutorial

---------

Co-authored-by: Sam Anklesaria <[email protected]>
1 parent bf305f5 commit 6ad76d1

File tree

1 file changed

+8
-29
lines changed

1 file changed

+8
-29
lines changed

examples/tutorials/mvdr_tutorial.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838

3939

4040
import matplotlib.pyplot as plt
41-
import mir_eval
4241
from IPython.display import Audio
4342

4443
######################################################################
@@ -48,21 +47,8 @@
4847

4948
######################################################################
5049
# 2.1. Import the packages
51-
# ~~~~~~~~~~~~~~~~~~~~~~~~
5250
#
53-
# First, we install and import the necessary packages.
54-
#
55-
# ``mir_eval``, ``pesq``, and ``pystoi`` packages are required for
56-
# evaluating the speech enhancement performance.
57-
#
58-
59-
# When running this example in notebook, install the following packages.
60-
# !pip3 install mir_eval
61-
# !pip3 install pesq
62-
# !pip3 install pystoi
6351

64-
from pesq import pesq
65-
from pystoi import stoi
6652
from torchaudio.utils import download_asset
6753

6854
######################################################################
@@ -142,8 +128,14 @@ def generate_mixture(waveform_clean, waveform_noise, target_snr):
142128
waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)
143129
return waveform_clean + waveform_noise
144130

145-
131+
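# The `evaluate` function below imports `mir_eval`, `pesq`, and `pystoi` inside its body, so these optional packages are
# needed only if you call it. With `mir_eval` installed, you can evaluate the separation quality of the estimated sources.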
132+
# You can also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility (STOI) metric
133+
# available in the `pystoi` package, and its perceptual quality with the Perceptual Evaluation of Speech Quality (PESQ) metric available in the `pesq` package.
146134
def evaluate(estimate, reference):
135+
from pesq import pesq
136+
from pystoi import stoi
137+
import mir_eval
138+
147139
si_snr_score = si_snr(estimate, reference)
148140
(
149141
sdr,
@@ -158,7 +150,6 @@ def evaluate(estimate, reference):
158150
print(f"PESQ score: {pesq_mix}")
159151
print(f"STOI score: {stoi_mix}")
160152

161-
162153
######################################################################
163154
# 3. Generate Ideal Ratio Masks (IRMs)
164155
# ------------------------------------
@@ -211,18 +202,9 @@ def evaluate(estimate, reference):
211202
# 3.2.1. Visualize mixture speech
212203
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
213204
#
214-
# We evaluate the quality of the mixture speech or the enhanced speech
215-
# using the following three metrics:
216-
#
217-
# - signal-to-distortion ratio (SDR)
218-
# - scale-invariant signal-to-noise ratio (Si-SNR, or Si-SDR in some papers)
219-
# - Perceptual Evaluation of Speech Quality (PESQ)
220-
#
221-
# We also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility
222-
# (STOI) metric.
205+
223206

224207
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
225-
evaluate(waveform_mix[0:1], waveform_clean[0:1])
226208
Audio(waveform_mix[0], rate=SAMPLE_RATE)
227209

228210

@@ -335,7 +317,6 @@ def get_irms(stft_clean, stft_noise):
335317

336318
plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
337319
waveform_souden = waveform_souden.reshape(1, -1)
338-
evaluate(waveform_souden, waveform_clean[0:1])
339320
Audio(waveform_souden, rate=SAMPLE_RATE)
340321

341322

@@ -393,7 +374,6 @@ def get_irms(stft_clean, stft_noise):
393374

394375
plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
395376
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
396-
evaluate(waveform_rtf_evd, waveform_clean[0:1])
397377
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)
398378

399379

@@ -404,5 +384,4 @@ def get_irms(stft_clean, stft_noise):
404384

405385
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
406386
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
407-
evaluate(waveform_rtf_power, waveform_clean[0:1])
408387
Audio(waveform_rtf_power, rate=SAMPLE_RATE)

0 commit comments

Comments
 (0)