Skip to content

Commit ac51c4f

Browse files
authored
Adding SD for superb (speech-classification). (#225)
* Adding SD for superb (speech-classification). * Style. * Style. * Test dependencies. * Addressing @omar 's comments. * Forgot pytest.. * `speech-classification` -> `speech-segmentation`. * Fixing test cache in common + adding simple tests.
1 parent 2409e60 commit ac51c4f

26 files changed

+507
-14
lines changed

.github/workflows/python-api-tests.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ jobs:
2525
working-directory: api-inference-community
2626
run: |
2727
pip install --upgrade pip
28-
pip install pytest pillow httpx huggingface_hub
29-
pip install -e .
28+
pip install -e .[test]
3029
- run: make test
3130
working-directory: api-inference-community
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Helper classes to modify pipeline outputs from tensors to expected pipeline output
3+
"""
4+
5+
from typing import TYPE_CHECKING, Dict, List, Union
6+
7+
8+
# One output segment, e.g. {"class": "XXX", "start": 0.0, "end": 1.5}:
# the label is a string, the "start"/"end" markers are floats.
Classes = Dict[str, Union[str, float]]

if TYPE_CHECKING:
    # torch is only referenced from string annotations in this module, so it is
    # imported for type checkers only and is allowed to be absent.
    try:
        import torch
    except ImportError:
        # Only a missing package is expected here; anything else should surface.
        pass
15+
16+
17+
def speaker_diarization_normalize(
    tensor: "torch.Tensor", sampling_rate: int, classnames: List[str]
) -> List["Classes"]:
    """
    Turn a per-row speaker activity tensor into a list of timed segments.

    Args:
        tensor (:obj:`torch.Tensor`):
            2D tensor indexed as ``tensor[row, speaker]``. Every run of consecutive
            rows whose value equals 1 in a column becomes one segment for that speaker.
        sampling_rate (:obj:`int`):
            Number of rows of ``tensor`` per unit of time; row offsets are divided by
            it to produce the "start" and "end" markers.
        classnames (:obj:`List[str]`):
            One label per column of ``tensor``, in column order.
    Return:
        A :obj:`list`:. Each item in the list is like {"class": "XXX", "start": float, "end": float},
        sorted by "start" (segments with the same "start" keep the column order).
    Raises:
        :obj:`ValueError`: if ``tensor`` is not 2D or if the number of classnames
        does not match the number of columns.
    """
    if len(tensor.shape) != 2:
        # Fail with a clear message instead of an IndexError on `shape[1]`
        # (1D input) or silently mis-reading extra dimensions.
        raise ValueError(
            f"Expected a 2D tensor (rows, speakers), got {len(tensor.shape)} dimension(s)"
        )

    num_speakers = tensor.shape[1]
    if len(classnames) != num_speakers:
        raise ValueError(
            f"There is a mismatch between classnames ({len(classnames)}) and number of speakers ({num_speakers})"
        )

    segments = []
    for speaker, classname in enumerate(classnames):
        values, counts = tensor[:, speaker].unique_consecutive(return_counts=True)
        offset = 0
        # Convert to plain Python numbers once, rather than doing a tensor
        # comparison and two `.item()` calls for every run.
        for value, count in zip(values.tolist(), counts.tolist()):
            if value == 1:
                segments.append(
                    {
                        "class": classname,
                        "start": offset / sampling_rate,
                        "end": (offset + count) / sampling_rate,
                    }
                )
            # Inactive runs still advance the position.
            offset += count

    # Stable sort: ties on "start" stay in speaker (column) order.
    segments.sort(key=lambda segment: segment["start"])
    return segments

api-inference-community/api_inference_community/validation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,14 @@ def normalize_payload(
211211
if task in {
212212
"automatic-speech-recognition",
213213
"audio-to-audio",
214+
"speech-segmentation",
214215
}:
215216
if sampling_rate is None:
216217
raise EnvironmentError(
217218
"We cannot normalize audio file if we don't know the sampling rate"
218219
)
219-
return normalize_payload_audio(bpayload, sampling_rate)
220+
outputs = normalize_payload_audio(bpayload, sampling_rate)
221+
return outputs
220222
elif task in {
221223
"image-classification",
222224
"image-to-text",

api-inference-community/docker_images/common/app/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from app.pipelines.image_classification import ImageClassificationPipeline
99
from app.pipelines.question_answering import QuestionAnsweringPipeline
1010
from app.pipelines.sentence_similarity import SentenceSimilarityPipeline
11+
from app.pipelines.speech_segmentation import SpeechSegmentationPipeline
1112
from app.pipelines.structured_data_classification import (
1213
StructuredDataClassificationPipeline,
1314
)

api-inference-community/docker_images/common/app/pipelines/automatic_speech_recognition.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ def __call__(self, inputs: np.array) -> Dict[str, str]:
2121
"""
2222
Args:
2323
inputs (:obj:`np.array`):
24-
The raw waveform of audio received. By default at 16KHz.
25-
Check `app.validation` if a different sample rate is required
26-
or if it depends on the model
24+
The raw waveform of audio received. By default at self.sampling_rate, otherwise 16KHz.
2725
Return:
2826
A :obj:`dict`:. The object returned should be like {"text": "XXX"} containing
2927
the detected language from the input audio
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Dict, List, Union

import numpy as np
from app.pipelines import Pipeline
5+
6+
7+
class SpeechSegmentationPipeline(Pipeline):
    """
    Template pipeline for the `speech-segmentation` task.

    Both methods raise :obj:`NotImplementedError` until they are filled in;
    the `IMPLEMENT_THIS` markers show where the implementation goes.
    """

    def __init__(self, model_id: str):
        # IMPLEMENT_THIS
        # Preload all the elements you are going to need at inference.
        # For instance your model, processors, tokenizer that might be needed.
        # This function is only called once, so do all the heavy processing I/O here
        # IMPLEMENT_THIS : Please define a `self.sampling_rate` for this pipeline
        # to automatically read the input correctly
        self.sampling_rate = 16000
        raise NotImplementedError(
            "Please implement SpeechSegmentationPipeline __init__ function"
        )

    def __call__(self, inputs: np.array) -> List[Dict[str, Union[str, float]]]:
        """
        Args:
            inputs (:obj:`np.array`):
                The raw waveform of audio received. By default at self.sampling_rate, otherwise 16KHz.
        Return:
            A :obj:`list`:. Each item in the list is like {"class": "XXX", "start": float, "end": float}
            "class" is the associated class of the audio segment, "start" and "end" are markers expressed in seconds
            within the audio file.
        """
        # IMPLEMENT_THIS
        # api_inference_community.normalizers.speaker_diarization_normalize could help.
        raise NotImplementedError(
            "Please implement SpeechSegmentationPipeline __call__ function"
        )

api-inference-community/docker_images/common/tests/test_api_audio_to_audio.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def setUp(self):
2424

2525
self.app = app
2626

27+
@classmethod
def setUpClass(cls):
    """Clear the `get_pipeline` cache once before this class's tests run."""
    # Function-scope import. NOTE(review): presumably deferred so `app.main`
    # is not imported when this test module is collected; confirm.
    from app.main import get_pipeline

    # NOTE(review): presumably prevents reusing a pipeline cached by another
    # test module (the commit message mentions fixing the test cache); confirm.
    get_pipeline.cache_clear()
32+
2733
def tearDown(self):
2834
if self.old_model_id is not None:
2935
os.environ["MODEL_ID"] = self.old_model_id

api-inference-community/docker_images/common/tests/test_api_automatic_speech_recognition.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ def setUp(self):
2222

2323
self.app = app
2424

25+
@classmethod
def setUpClass(cls):
    """Clear the `get_pipeline` cache once before this class's tests run."""
    # Function-scope import. NOTE(review): presumably deferred so `app.main`
    # is not imported when this test module is collected; confirm.
    from app.main import get_pipeline

    # NOTE(review): presumably prevents reusing a pipeline cached by another
    # test module (the commit message mentions fixing the test cache); confirm.
    get_pipeline.cache_clear()
30+
2531
def tearDown(self):
2632
if self.old_model_id is not None:
2733
os.environ["MODEL_ID"] = self.old_model_id

api-inference-community/docker_images/common/tests/test_api_feature_extraction.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ def setUp(self):
2222

2323
self.app = app
2424

25+
@classmethod
def setUpClass(cls):
    """Clear the `get_pipeline` cache once before this class's tests run."""
    # Function-scope import. NOTE(review): presumably deferred so `app.main`
    # is not imported when this test module is collected; confirm.
    from app.main import get_pipeline

    # NOTE(review): presumably prevents reusing a pipeline cached by another
    # test module (the commit message mentions fixing the test cache); confirm.
    get_pipeline.cache_clear()
30+
2531
def tearDown(self):
2632
if self.old_model_id is not None:
2733
os.environ["MODEL_ID"] = self.old_model_id

api-inference-community/docker_images/common/tests/test_api_image_classification.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ def setUp(self):
2222

2323
self.app = app
2424

25+
@classmethod
def setUpClass(cls):
    """Clear the `get_pipeline` cache once before this class's tests run."""
    # Function-scope import. NOTE(review): presumably deferred so `app.main`
    # is not imported when this test module is collected; confirm.
    from app.main import get_pipeline

    # NOTE(review): presumably prevents reusing a pipeline cached by another
    # test module (the commit message mentions fixing the test cache); confirm.
    get_pipeline.cache_clear()
30+
2531
def tearDown(self):
2632
if self.old_model_id is not None:
2733
os.environ["MODEL_ID"] = self.old_model_id

0 commit comments

Comments
 (0)