Skip to content

Commit 61aa919

Browse files
authored
Adding flexible input sources for Diarization Mixin (#15184)
* Adding flexible input source for Diarization Mixin

  Signed-off-by: taejinp <[email protected]>

* Apply isort and black reformatting

  Signed-off-by: tango4j <[email protected]>

---------

Signed-off-by: taejinp <[email protected]>
Signed-off-by: tango4j <[email protected]>
Co-authored-by: tango4j <[email protected]>
1 parent 9a9b596 commit 61aa919

File tree

5 files changed

+361
-189
lines changed

5 files changed

+361
-189
lines changed

nemo/collections/asr/data/audio_to_diar_label_lhotse.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import logging
1615
from typing import Dict, Optional, Tuple
1716

1817
import torch.utils.data
@@ -24,6 +23,7 @@
2423
speaker_to_target,
2524
)
2625
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType
26+
from nemo.utils import logging
2727

2828

2929
class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset):
@@ -58,11 +58,23 @@ def __init__(self, cfg):
5858
self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8))
5959

6060
def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
61-
audio, audio_lens, cuts = self.load_audio(cuts)
61+
# NOTE: This end-to-end diarization dataloader only loads the 1st ch of the audio file.
62+
# Process cuts in a single loop: convert to mono and compute speaker activities
63+
mono_cuts = []
6264
speaker_activities = []
6365
for cut in cuts:
66+
if cut.num_channels is not None and cut.num_channels > 1:
67+
logging.warning(
68+
"Multiple channels detected in cut '%s' (%d channels). "
69+
"Only the first channel will be used; remaining channels are ignored.",
70+
cut.id,
71+
cut.num_channels,
72+
)
73+
mono_cut = cut.with_channels(channels=[0])
74+
mono_cuts.append(mono_cut)
75+
6476
speaker_activity = speaker_to_target(
65-
a_cut=cut,
77+
a_cut=mono_cut,
6678
num_speakers=self.num_speakers,
6779
num_sample_per_mel_frame=self.num_sample_per_mel_frame,
6880
num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame,
@@ -79,6 +91,9 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
7991
)
8092
speaker_activity = speaker_activity[:, : self.num_speakers]
8193
speaker_activities.append(speaker_activity)
94+
95+
cuts = type(cuts).from_cuts(mono_cuts)
96+
audio, audio_lens, cuts = self.load_audio(cuts)
8297
targets = collate_matrices(speaker_activities).to(audio.dtype) # (B, T, N)
8398

8499
if targets.shape[2] > self.num_speakers:

nemo/collections/asr/models/sortformer_diar_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def _setup_diarize_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoade
425425
'session_len_sec': config['session_len_sec'],
426426
'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
427427
'pin_memory': True,
428+
'use_lhotse': config.get('use_lhotse', False),
428429
}
429430
temporary_datalayer = self.__setup_dataloader_from_config(config=DictConfig(dl_config))
430431
return temporary_datalayer
@@ -1112,6 +1113,7 @@ def on_validation_epoch_end(self) -> Optional[dict[str, dict[str, torch.Tensor]]
11121113
def diarize(
11131114
self,
11141115
audio: Union[str, List[str], np.ndarray, DataLoader],
1116+
sample_rate: Optional[int] = None,
11151117
batch_size: int = 1,
11161118
include_tensor_outputs: bool = False,
11171119
postprocessing_yaml: Optional[str] = None,

0 commit comments

Comments
 (0)