Skip to content

Commit 1a3c291

Browse files
authored
Fixing the missing sample_rate argument in mixin calling in Sortformer model file (#15228)
* Adding flexible input source for Diarization Mixin Signed-off-by: taejinp <tango4j@gmail.com> * Apply isort and black reformatting Signed-off-by: tango4j <tango4j@users.noreply.github.com> * Letting diarize() function to use lhotse dataloader Signed-off-by: taejinp <tango4j@gmail.com> * Apply isort and black reformatting Signed-off-by: tango4j <tango4j@users.noreply.github.com> * One for loop to handle everything Signed-off-by: taejinp <tango4j@gmail.com> * Fixing the missing sample_rate and fixed some outdated comments Signed-off-by: taejinp <tango4j@gmail.com> * Apply isort and black reformatting Signed-off-by: tango4j <tango4j@users.noreply.github.com> * Added sortformer_diar_models.py Signed-off-by: taejinp <tango4j@gmail.com> --------- Signed-off-by: taejinp <tango4j@gmail.com> Signed-off-by: tango4j <tango4j@users.noreply.github.com> Co-authored-by: tango4j <tango4j@users.noreply.github.com>
1 parent 6442018 commit 1a3c291

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

nemo/collections/asr/models/sortformer_diar_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,7 @@ def diarize(
11421142
"""
11431143
return super().diarize(
11441144
audio=audio,
1145+
sample_rate=sample_rate,
11451146
batch_size=batch_size,
11461147
include_tensor_outputs=include_tensor_outputs,
11471148
postprocessing_yaml=postprocessing_yaml,

nemo/collections/asr/parts/mixins/diarization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def resample_audio(samples: np.ndarray, orig_sr: int, target_sr: int) -> np.ndar
4949
return samples.astype(np.float32, copy=False)
5050

5151
resampled_samples = samples.astype(np.float32, copy=False)
52-
# User-requested API
5352
resampled_samples = librosa.core.resample(resampled_samples, orig_sr=orig_sr, target_sr=target_sr)
5453
return resampled_samples.astype(np.float32, copy=False)
5554

@@ -132,7 +131,7 @@ class SpkDiarizationMixin(ABC):
132131
"""
133132
An abstract class for diarize-able models.
134133
135-
Creates a template function `diarize()` that provides an interface to perform transcription of audio tensors or
134+
Creates a template function `diarize()` that provides an interface to perform diarization of audio tensors or
136135
filepaths.
137136
"""
138137

@@ -409,7 +408,7 @@ def _diarize_on_begin(self, audio: Union[str, List[str]], diarcfg: DiarizeConfig
409408
# Model's mode and device
410409
diarcfg._internal.training_mode = self.training
411410

412-
# Switch model to evaluation mode
411+
# Save preprocessor settings before switching to evaluation mode
413412
if hasattr(self, 'preprocessor'):
414413
if hasattr(self.preprocessor, 'featurizer') and hasattr(self.preprocessor.featurizer, 'dither'):
415414
diarcfg._internal.dither_value = self.preprocessor.featurizer.dither
@@ -541,7 +540,8 @@ def _diarize_input_processing(self, audio, diarcfg: DiarizeConfig):
541540

542541
else:
543542
raise ValueError(
544-
f"Input `audio` is of type {type(audio[0])}. " "Only `str` (path to audio file) is supported as input."
543+
f"Input `audio` is of type {type(audio[0])}. "
544+
"Only `str` (path to audio file) or `np.ndarray` are supported as input."
545545
)
546546

547547
def _diarize_input_manifest_processing(
@@ -632,7 +632,7 @@ def _diarize_output_processing(self, outputs, uniq_ids, diarcfg: DiarizeConfig)
632632

633633
def _diarize_on_end(self, diarcfg: DiarizeConfig):
634634
"""
635-
Internal function to teardown the model after transcription. Perform all teardown and post-checks here.
635+
Internal function to teardown the model after diarization. Perform all teardown and post-checks here.
636636
637637
Args:
638638
diarcfg: The diarization config dataclass. Subclasses can change this to a different dataclass if needed.

tests/collections/speaker_tasks/mixins/test_diarization.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,34 @@ def test_diarize_manifest_jsonl_path(self, dummy_model, audio_files, tmp_path: P
200200
def test_diarize_numpy_single_requires_sample_rate(self, dummy_model, audio_files):
201201
dummy_model = dummy_model.eval()
202202
audio1, _, _, _, _ = audio_files
203+
204+
# Check if it raises an error without sample rate when using a single numpy variable input
203205
with pytest.raises(ValueError):
204-
_ = dummy_model.diarize(audio1, batch_size=1)
206+
_ = dummy_model.diarize(audio=audio1, batch_size=1)
207+
208+
# Set sample rate and check if it works
209+
sample_rate = 16000
210+
outputs = dummy_model.diarize(audio1, batch_size=1, sample_rate=sample_rate)
211+
assert isinstance(outputs, list)
212+
assert len(outputs) == 1
213+
assert outputs[0] > 0
214+
215+
@pytest.mark.unit
216+
def test_diarize_numpy_list_requires_sample_rate(self, dummy_model, audio_files):
217+
dummy_model = dummy_model.eval()
218+
audio1, audio2, _, _, _ = audio_files
219+
numpy_audio_list = [audio1, audio2]
220+
# Check if it raises an error without sample rate when using numpy list input
221+
with pytest.raises(ValueError):
222+
_ = dummy_model.diarize(audio=numpy_audio_list, batch_size=2)
223+
224+
# Set sample rate and check if it works
225+
sample_rate = 16000
226+
outputs = dummy_model.diarize(audio=numpy_audio_list, batch_size=2, sample_rate=sample_rate)
227+
assert isinstance(outputs, list)
228+
assert len(outputs) == 2
229+
assert outputs[0] > 0
230+
assert outputs[1] > 0
205231

206232
@pytest.mark.unit
207233
def test_diarize_numpy_single(self, dummy_model, audio_files):

0 commit comments

Comments
 (0)