
Commit 0f0585e

Add fairseq S2T/S2S pipelines (#570)

* add s2s and s2t
* updates
* update fairseq commit
* update fairseq commit
* resolve PR comment
* isort
* isort latest version
* fix sample rate
* fix
* update
* fix
* update test
* update test
* update
* update

1 parent 2a3f2c3 commit 0f0585e

File tree: 7 files changed (+163, -2 lines)


api-inference-community/docker_images/fairseq/app/main.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 from typing import Dict, Type

 from api_inference_community.routes import pipeline_route, status_ok
-from app.pipelines import Pipeline, TextToSpeechPipeline
+from app.pipelines import Pipeline, SpeechToSpeechPipeline, TextToSpeechPipeline
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.gzip import GZipMiddleware
@@ -34,6 +34,7 @@
 # directories. Implement directly within the directories.
 ALLOWED_TASKS: Dict[str, Type[Pipeline]] = {
     "text-to-speech": TextToSpeechPipeline,
+    "audio-to-audio": SpeechToSpeechPipeline,
 }
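This registration is what routes requests: the app reads the TASK and MODEL_ID environment variables (as the tests in this commit do) and looks the task up in ALLOWED_TASKS. A minimal sketch of that lookup, assuming the same environment setup the tests use; the model id is the one exercised in test_dockers.py below:

    import os

    # Select the task and model before importing the app, mirroring the tests.
    os.environ["TASK"] = "audio-to-audio"
    os.environ["MODEL_ID"] = "facebook/xm_transformer_600m-es_en-multi_domain"

    from app.main import ALLOWED_TASKS

    # "audio-to-audio" now resolves to the new pipeline class.
    print(ALLOWED_TASKS[os.environ["TASK"]].__name__)  # SpeechToSpeechPipeline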

api-inference-community/docker_images/fairseq/app/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from app.pipelines.base import Pipeline, PipelineException  # isort:skip

+from app.pipelines.audio_to_audio import SpeechToSpeechPipeline
 from app.pipelines.text_to_speech import TextToSpeechPipeline
api-inference-community/docker_images/fairseq/app/pipelines/audio_to_audio.py

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+import os
+from typing import List, Tuple
+
+import numpy as np
+import torch
+from app.pipelines import Pipeline
+from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
+from fairseq.models.speech_to_text.hub_interface import S2THubInterface
+from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
+
+
+class SpeechToSpeechPipeline(Pipeline):
+    def __init__(self, model_id: str):
+        models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
+            model_id,
+            arg_overrides={"config_yaml": "config.yaml"},
+            cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
+        )
+        self.model = models[0].cpu()
+        self.model.eval()
+        cfg["task"].cpu = True
+        self.task = task
+        self.generator = task.build_generator([self.model], cfg)
+
+        self.sampling_rate = getattr(self.task, "sr", None) or 16_000
+
+        tgt_lang = self.task.data_cfg.hub.get("tgt_lang", None)
+        pfx = f"{tgt_lang}_" if self.task.data_cfg.prepend_tgt_lang_tag else ""
+        tts_model_id = self.task.data_cfg.hub.get(f"{pfx}tts_model_id", None)
+        self.tts_model, self.tts_task, self.tts_generator = None, None, None
+        if tts_model_id is not None:
+            _repo, _id = tts_model_id.split(":")
+            (
+                tts_models,
+                tts_cfg,
+                self.tts_task,
+            ) = load_model_ensemble_and_task_from_hf_hub(
+                f"facebook/{_id}",
+                arg_overrides={"vocoder": "griffin_lim", "fp16": False},
+                cache_dir=os.getenv("HUGGINGFACE_HUB_CACHE"),
+            )
+            self.tts_model = tts_models[0].cpu()
+            self.tts_model.eval()
+            tts_cfg["task"].cpu = True
+            TTSHubInterface.update_cfg_with_data_cfg(tts_cfg, self.tts_task.data_cfg)
+            self.tts_generator = self.tts_task.build_generator(
+                [self.tts_model], tts_cfg
+            )
+
+    def __call__(self, inputs: np.array) -> Tuple[np.array, int, List[str]]:
+        """
+        Args:
+            inputs (:obj:`np.array`):
+                The raw waveform of audio received. By default sampled at `self.sampling_rate`.
+                The shape of this array is `T`, where `T` is the time axis.
+        Return:
+            A :obj:`tuple` containing:
+              - :obj:`np.array`:
+                 The return shape of the array must be `C'` x `T'`.
+              - an :obj:`int`: the sampling rate in Hz.
+              - a :obj:`List[str]`: the annotation for each output channel.
+                This can be the name of the instruments for audio source separation
+                or some annotation for speech enhancement. The length must be `C'`.
+        """
+        _inputs = torch.from_numpy(inputs).unsqueeze(0)
+        sample = S2THubInterface.get_model_input(self.task, _inputs)
+        text = S2THubInterface.get_prediction(
+            self.task, self.model, self.generator, sample
+        )
+
+        if self.tts_model is None:
+            return np.zeros((0,)), self.sampling_rate, [text]
+        else:
+            tts_sample = TTSHubInterface.get_model_input(self.tts_task, text)
+            wav, sr = TTSHubInterface.get_prediction(
+                self.tts_task, self.tts_model, self.tts_generator, tts_sample
+            )
+            return wav.numpy(), sr, [text]
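
For orientation, a minimal local-usage sketch of the new pipeline. The input file name and the use of librosa/soundfile are illustrative assumptions, not part of this commit; the model id is the one exercised in test_dockers.py below:

    import librosa
    import soundfile as sf

    from app.pipelines.audio_to_audio import SpeechToSpeechPipeline

    pipe = SpeechToSpeechPipeline("facebook/xm_transformer_600m-es_en-multi_domain")

    # Load a mono waveform at the rate the pipeline expects (16 kHz fallback).
    waveform, _ = librosa.load("utterance.wav", sr=pipe.sampling_rate, mono=True)

    # Returns (audio, sampling_rate, labels); labels[0] is the S2T text, and the
    # audio is empty when the checkpoint config references no TTS model.
    wav, sr, labels = pipe(waveform)
    print(labels[0])
    if wav.size:
        sf.write("translated.wav", wav.T, sr)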

api-inference-community/docker_images/fairseq/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -4,4 +4,5 @@ g2pc==0.9.9.3
 phonemizer==2.2.1
 librosa==0.8.1
 hanziconv==0.3.2
-git+git://github.com/pytorch/fairseq.git@43defa1bcb9cc3d5c532d12cba5e01f37dad0350
+sentencepiece==0.1.91
+git+git://github.com/pytorch/fairseq.git@1d5da6d5b954ba01fc3df12d25d63df27437e20e

api-inference-community/docker_images/fairseq/tests/test_api.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@

 ALL_TASKS = {
     "text-to-speech",
+    "audio-to-audio",
 }

api-inference-community/docker_images/fairseq/tests/test_audio_to_audio.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+import base64
+import json
+import os
+from unittest import TestCase, skipIf
+
+from api_inference_community.validation import ffmpeg_read
+from app.main import ALLOWED_TASKS
+from starlette.testclient import TestClient
+from tests.test_api import TESTABLE_MODELS
+
+
+@skipIf(
+    "audio-to-audio" not in ALLOWED_TASKS,
+    "audio-to-audio not implemented",
+)
+class AudioToAudioTestCase(TestCase):
+    def setUp(self):
+        model_id = TESTABLE_MODELS["audio-to-audio"]
+        self.old_model_id = os.getenv("MODEL_ID")
+        self.old_task = os.getenv("TASK")
+        os.environ["MODEL_ID"] = model_id
+        os.environ["TASK"] = "audio-to-audio"
+        from app.main import app
+
+        self.app = app
+
+    @classmethod
+    def setUpClass(cls):
+        from app.main import get_pipeline
+
+        get_pipeline.cache_clear()
+
+    def tearDown(self):
+        if self.old_model_id is not None:
+            os.environ["MODEL_ID"] = self.old_model_id
+        else:
+            del os.environ["MODEL_ID"]
+        if self.old_task is not None:
+            os.environ["TASK"] = self.old_task
+        else:
+            del os.environ["TASK"]
+
+    def test_simple(self):
+        bpayload = self.read("sample1.flac")
+
+        with TestClient(self.app) as client:
+            response = client.post("/", data=bpayload)
+        self.assertEqual(
+            response.status_code,
+            200,
+        )
+        self.assertEqual(response.headers["content-type"], "application/json")
+        audio = json.loads(response.content)
+
+        self.assertTrue(isinstance(audio, list))
+        self.assertEqual(set(audio[0].keys()), {"blob", "content-type", "label"})
+
+        data = base64.b64decode(audio[0]["blob"])
+        wavform = ffmpeg_read(data, 16000)
+        self.assertGreater(wavform.shape[0], 1000)
+        self.assertTrue(isinstance(audio[0]["content-type"], str))
+        self.assertTrue(isinstance(audio[0]["label"], str))
+
+    def test_malformed_audio(self):
+        bpayload = self.read("malformed.flac")
+
+        with TestClient(self.app) as client:
+            response = client.post("/", data=bpayload)
+
+        self.assertEqual(
+            response.status_code,
+            400,
+        )
+        self.assertEqual(response.content, b'{"error":"Malformed soundfile"}')
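
The test pins down the response contract: a JSON list with one entry per output channel, each carrying a base64-encoded "blob", a "content-type", and a "label". A client-side sketch of consuming that payload; the server URL and output file names are illustrative assumptions:

    import base64
    import json

    import requests

    # Post raw audio bytes to a running fairseq audio-to-audio server.
    with open("sample1.flac", "rb") as f:
        response = requests.post("http://localhost:8000/", data=f.read())

    channels = json.loads(response.content)
    for i, channel in enumerate(channels):
        print(channel["label"])  # e.g. the intermediate translation text
        ext = channel["content-type"].split("/")[-1]  # e.g. "wav"
        with open(f"channel_{i}.{ext}", "wb") as out:
            out.write(base64.b64decode(channel["blob"]))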

api-inference-community/tests/test_dockers.py

Lines changed: 5 additions & 0 deletions
@@ -85,6 +85,11 @@ def test_fairseq(self):
             "text-to-speech",
             "facebook/fastspeech2-en-ljspeech",
         )
+        self.framework_docker_test(
+            "fairseq",
+            "audio-to-audio",
+            "facebook/xm_transformer_600m-es_en-multi_domain",
+        )
         self.framework_invalid_test("fairseq")

     def test_fasttext(self):