Skip to content

Commit 66ffb38

Browse files
authored
[Fix] Fix the notebook errors on multispeaker data simulation and end to end diarization training (#15149)
* Fixed the notebook errors

  Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

  Signed-off-by: tango4j <[email protected]>

---------

Signed-off-by: taejinp <[email protected]>
Signed-off-by: tango4j <[email protected]>
Co-authored-by: tango4j <[email protected]>
1 parent 7fda26c commit 66ffb38

File tree

3 files changed

+33
-29
lines changed

3 files changed

+33
-29
lines changed

nemo/collections/asr/data/data_simulation.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,7 +1148,7 @@ def _generate_session(
11481148
if self._params.data_simulator.background_noise.add_bg:
11491149
if len(self._noise_samples) > 0:
11501150
avg_power_array = torch.mean(array[is_speech == 1] ** 2)
1151-
bg, snr = get_background_noise(
1151+
bg, snr, _ = get_background_noise(
11521152
len_array=len(array),
11531153
power_array=avg_power_array,
11541154
noise_samples=self._noise_samples,
@@ -1466,6 +1466,8 @@ def _generate_rir_pyroomacoustics(self) -> Tuple[torch.Tensor, int]:
14661466
if self._params.data_simulator.rir_generation.mic_config.mic_pattern == 'omni':
14671467
mic_pattern = DirectivityPattern.OMNI
14681468
dir_vec = DirectionVector(azimuth=0, colatitude=90, degrees=True)
1469+
else:
1470+
raise Exception("Currently, microphone pattern must be omni. Aborting RIR generation.")
14691471
dir_obj = CardioidFamily(
14701472
orientation=dir_vec,
14711473
pattern_enum=mic_pattern,
@@ -1509,6 +1511,8 @@ def _convolve_rir(self, input, speaker_turn: int, RIR: torch.Tensor) -> Tuple[li
15091511
out_channel = convolve(input, RIR[speaker_turn, channel, : len(input)]).tolist()
15101512
elif self._params.data_simulator.rir_generation.toolkit == 'pyroomacoustics':
15111513
out_channel = convolve(input, RIR[channel][speaker_turn][: len(input)]).tolist()
1514+
else:
1515+
raise Exception("Toolkit must be pyroomacoustics or gpuRIR. Aborting RIR convolution.")
15121516
if len(out_channel) > length:
15131517
length = len(out_channel)
15141518
output_sound.append(torch.tensor(out_channel))
@@ -1644,7 +1648,11 @@ def _generate_session(
16441648
self.annotator.annote_lists['json'].append(new_json_entry)
16451649

16461650
new_ctm_entries, _ = self.annotator.create_new_ctm_entry(
1647-
filename, speaker_ids[speaker_turn], start / self._params.data_simulator.sr
1651+
words=self._text,
1652+
alignments=self._alignments,
1653+
session_name=filename,
1654+
speaker_id=speaker_ids[speaker_turn],
1655+
start=start / self._params.data_simulator.sr,
16481656
)
16491657
self.annotator.annote_lists['ctm'].extend(new_ctm_entries)
16501658

@@ -1659,23 +1667,21 @@ def _generate_session(
16591667
array = perturb_audio(array, self._params.data_simulator.sr, self.session_augmentor)
16601668

16611669
# Step 7-2: Additive background noise from noise manifest files
1662-
if self._params.data_simulator.background_noise.add_bg:
1663-
if len(self._noise_samples) > 0:
1664-
avg_power_array = torch.mean(array[is_speech == 1] ** 2)
1665-
bg, snr = get_background_noise(
1666-
len_array=len(array),
1667-
power_array=avg_power_array,
1668-
noise_samples=self._noise_samples,
1669-
audio_read_buffer_dict=self._audio_read_buffer_dict,
1670-
snr_min=self._params.data_simulator.background_noise.snr_min,
1671-
snr_max=self._params.data_simulator.background_noise.snr_max,
1672-
background_noise_snr=self._params.data_simulator.background_noise.snr,
1673-
seed=(random_seed + idx),
1674-
device=self._device,
1675-
)
1676-
array += bg
1670+
if self._params.data_simulator.background_noise.add_bg and len(self._noise_samples) > 0:
1671+
avg_power_array = torch.mean(array[is_speech == 1] ** 2)
1672+
bg, snr, _ = get_background_noise(
1673+
len_array=len(array),
1674+
power_array=avg_power_array,
1675+
noise_samples=self._noise_samples,
1676+
audio_read_buffer_dict=self._audio_read_buffer_dict,
1677+
snr_min=self._params.data_simulator.background_noise.snr_min,
1678+
snr_max=self._params.data_simulator.background_noise.snr_max,
1679+
background_noise_snr=self._params.data_simulator.background_noise.snr,
1680+
seed=(random_seed + idx),
1681+
device=self._device,
1682+
)
1683+
array += bg
16771684
length = array.shape[0]
1678-
bg, snr = self._get_background(length, avg_power_array)
16791685
augmented_bg, _ = self._convolve_rir(bg, -1, RIR)
16801686
for channel in range(self._params.data_simulator.rir_generation.mic_config.num_channels):
16811687
array[:, channel] += augmented_bg[channel][:length]

tutorials/speaker_tasks/End_to_End_Diarization_Training.ipynb

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@
153153
"cell_type": "markdown",
154154
"metadata": {},
155155
"source": [
156-
"In mathmatical terms, Sort-Loss can be expressed as follows:\n",
156+
"In mathematical terms, Sort-Loss can be expressed as follows:\n",
157157
"\n",
158158
"* **Arrival Time Sorting Function with $\\Psi$ function** \n",
159159
"\n",
@@ -200,7 +200,7 @@
200200
"cell_type": "markdown",
201201
"metadata": {},
202202
"source": [
203-
"Now that we learn the concept of Sort Loss and Sortformer, we can now calculate Sort Loss based target matrix and PIL-based target matrix to compare the difference in target-value setting atrix and loss calculation.\n",
203+
"Now that we have learned the concepts of Sort Loss and Sortformer, we can calculate the Sort Loss-based target matrix and the PIL-based target matrix to compare the differences in how the target-matrix values are set and how the loss is calculated.\n",
204204
"\n",
205205
"- raw target matrix $\\mathbf{Y}$: `raw_targets`\n",
206206
"- prediction matrix $\\mathbf{P}$: `preds`\n",
@@ -297,7 +297,6 @@
297297
"from nemo.collections.asr.losses.bce_loss import BCELoss \n",
298298
"\n",
299299
"bce_loss = BCELoss()\n",
300-
"# reduction='mean', class_normalization=False)\n",
301300
"\n",
302301
"def plot_diarout(preds, title_text, cmap_str):\n",
303302
"\n",
@@ -825,7 +824,6 @@
825824
"source": [
826825
"curr_dir = os.getcwd() + \"/\"\n",
827826
"config.model.train_ds.manifest_filepath = f'{curr_dir}simulated_train/sortformer_train.json'\n",
828-
"# config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
829827
"config.model.validation_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n",
830828
"config.trainer.strategy = \"ddp_notebook\"\n",
831829
"config.batch_size = 3\n",

tutorials/tools/Multispeaker_Simulator.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@
122122
"metadata": {},
123123
"outputs": [],
124124
"source": [
125-
"!python NeMo/scripts/speaker_tasks/create_alignment_manifest.py \\\n",
125+
"!python {NEMO_DIR_PATH}/scripts/speaker_tasks/create_alignment_manifest.py \\\n",
126126
" --input_manifest_filepath LibriSpeech/dev_clean.json \\\n",
127127
" --base_alignment_path LibriSpeech_Alignments \\\n",
128128
" --output_manifest_filepath ./dev-clean-align.json \\\n",
@@ -218,7 +218,7 @@
218218
"source": [
219219
"# Step 5: Generate Simulated Audio Session\n",
220220
"\n",
221-
"A single 4-speaker session of 60 seconds is generated as an example. "
221+
"A single 4-speaker session of 30 seconds is generated as an example. "
222222
]
223223
},
224224
{
@@ -250,7 +250,7 @@
250250
"cell_type": "markdown",
251251
"metadata": {},
252252
"source": [
253-
"# Step 5: Listen to and Visualize Session\n",
253+
"# Step 6: Listen to and Visualize Session\n",
254254
"\n",
255255
"Listen to the audio and visualize the corresponding speaker timestamps (recorded in a RTTM file for each session)"
256256
]
@@ -264,7 +264,6 @@
264264
"outputs": [],
265265
"source": [
266266
"import os\n",
267-
"import wget\n",
268267
"import IPython\n",
269268
"import matplotlib.pyplot as plt\n",
270269
"import numpy as np\n",
@@ -316,7 +315,7 @@
316315
"cell_type": "markdown",
317316
"metadata": {},
318317
"source": [
319-
"# Step 6: Get Simulated Data Statistics "
318+
"# Step 7: Get Simulated Data Statistics "
320319
]
321320
},
322321
{
@@ -325,6 +324,7 @@
325324
"metadata": {},
326325
"outputs": [],
327326
"source": [
327+
"import wget\n",
328328
"if not os.path.exists(\"multispeaker_data_analysis.py\"):\n",
329329
" !wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/speaker_tasks/multispeaker_data_analysis.py\n",
330330
"\n",
@@ -365,7 +365,7 @@
365365
],
366366
"metadata": {
367367
"kernelspec": {
368-
"display_name": "Python 3 (ipykernel)",
368+
"display_name": "nemo093025",
369369
"language": "python",
370370
"name": "python3"
371371
},
@@ -379,7 +379,7 @@
379379
"name": "python",
380380
"nbconvert_exporter": "python",
381381
"pygments_lexer": "ipython3",
382-
"version": "3.9.7"
382+
"version": "3.10.12"
383383
},
384384
"pycharm": {
385385
"stem_cell": {

0 commit comments

Comments
 (0)