Skip to content

Commit 9fd4602

Browse files
authored
Making faster whisper input more uniform (#218)
1 parent 194d19c commit 9fd4602

File tree

3 files changed

+84
-9
lines changed

3 files changed

+84
-9
lines changed

templates/faster-whisper-truss/model/model.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
from tempfile import NamedTemporaryFile
23
from typing import Any, Dict
34

@@ -16,13 +17,37 @@ def load(self):
1617
self._model = WhisperModel(self._config["model_metadata"]["model_id"])
1718

1819
def preprocess(self, request: Dict) -> Dict:
    """Resolve the audio input of a request into raw bytes.

    Exactly one of ``request["audio"]`` (a base64-encoded audio file) or
    ``request["url"]`` (a location to download the audio from) must be set.

    Args:
        request: Incoming request payload.

    Returns:
        ``{"data": <bytes>}`` on success, or ``{"error": <message>}`` when the
        input is missing, ambiguous, not decodable, or cannot be downloaded.
    """
    audio_base64 = request.get("audio")
    url = request.get("url")

    if audio_base64 and url:
        return {
            "error": "Only a base64 audio file OR a URL can be passed to the API, not both of them.",
        }
    if not audio_base64 and not url:
        return {
            "error": "Please provide either an audio file in base64 string format or a URL to an audio file.",
        }

    if audio_base64:
        try:
            binary_data = base64.b64decode(audio_base64)
        except (ValueError, TypeError) as exc:
            # binascii.Error is a ValueError subclass; TypeError covers
            # non-string payloads. Report it the same way as other bad input
            # instead of letting the exception escape.
            return {"error": f"The provided audio is not valid base64: {exc}"}
    else:
        try:
            # A timeout keeps an unresponsive host from blocking the request
            # indefinitely; raise_for_status stops an HTTP error page from
            # being handed on as if it were audio.
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
        except requests.RequestException as exc:
            return {"error": f"Could not download audio from the URL: {exc}"}
        binary_data = resp.content

    return {"data": binary_data}
2141

2242
def predict(self, request: Dict) -> Dict:
43+
if request.get("error"):
44+
return request
45+
46+
audio_data = request.get("data")
2347
result_segments = []
48+
2449
with NamedTemporaryFile() as fp:
25-
fp.write(request["response"])
50+
fp.write(audio_data)
2651
segments, info = self._model.transcribe(
2752
fp.name,
2853
temperature=0,

whisper/faster-whisper-v2/model/model.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
from tempfile import NamedTemporaryFile
23
from typing import Any, Dict
34

@@ -16,13 +17,37 @@ def load(self):
1617
self._model = WhisperModel(self._config["model_metadata"]["model_id"])
1718

1819
def preprocess(self, request: Dict) -> Dict:
    """Resolve the audio input of a request into raw bytes.

    Exactly one of ``request["audio"]`` (a base64-encoded audio file) or
    ``request["url"]`` (a location to download the audio from) must be set.

    Args:
        request: Incoming request payload.

    Returns:
        ``{"data": <bytes>}`` on success, or ``{"error": <message>}`` when the
        input is missing, ambiguous, not decodable, or cannot be downloaded.
    """
    audio_base64 = request.get("audio")
    url = request.get("url")

    if audio_base64 and url:
        return {
            "error": "Only a base64 audio file OR a URL can be passed to the API, not both of them.",
        }
    if not audio_base64 and not url:
        return {
            "error": "Please provide either an audio file in base64 string format or a URL to an audio file.",
        }

    if audio_base64:
        try:
            binary_data = base64.b64decode(audio_base64)
        except (ValueError, TypeError) as exc:
            # binascii.Error is a ValueError subclass; TypeError covers
            # non-string payloads. Report it the same way as other bad input
            # instead of letting the exception escape.
            return {"error": f"The provided audio is not valid base64: {exc}"}
    else:
        try:
            # A timeout keeps an unresponsive host from blocking the request
            # indefinitely; raise_for_status stops an HTTP error page from
            # being handed on as if it were audio.
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
        except requests.RequestException as exc:
            return {"error": f"Could not download audio from the URL: {exc}"}
        binary_data = resp.content

    return {"data": binary_data}
2141

2242
def predict(self, request: Dict) -> Dict:
43+
if request.get("error"):
44+
return request
45+
46+
audio_data = request.get("data")
2347
result_segments = []
48+
2449
with NamedTemporaryFile() as fp:
25-
fp.write(request["response"])
50+
fp.write(audio_data)
2651
segments, info = self._model.transcribe(
2752
fp.name,
2853
temperature=0,

whisper/faster-whisper-v3/model/model.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
from tempfile import NamedTemporaryFile
23
from typing import Any, Dict
34

@@ -16,13 +17,37 @@ def load(self):
1617
self._model = WhisperModel(self._config["model_metadata"]["model_id"])
1718

1819
def preprocess(self, request: Dict) -> Dict:
    """Resolve the audio input of a request into raw bytes.

    Exactly one of ``request["audio"]`` (a base64-encoded audio file) or
    ``request["url"]`` (a location to download the audio from) must be set.

    Args:
        request: Incoming request payload.

    Returns:
        ``{"data": <bytes>}`` on success, or ``{"error": <message>}`` when the
        input is missing, ambiguous, not decodable, or cannot be downloaded.
    """
    audio_base64 = request.get("audio")
    url = request.get("url")

    if audio_base64 and url:
        return {
            "error": "Only a base64 audio file OR a URL can be passed to the API, not both of them.",
        }
    if not audio_base64 and not url:
        return {
            "error": "Please provide either an audio file in base64 string format or a URL to an audio file.",
        }

    if audio_base64:
        try:
            binary_data = base64.b64decode(audio_base64)
        except (ValueError, TypeError) as exc:
            # binascii.Error is a ValueError subclass; TypeError covers
            # non-string payloads. Report it the same way as other bad input
            # instead of letting the exception escape.
            return {"error": f"The provided audio is not valid base64: {exc}"}
    else:
        try:
            # A timeout keeps an unresponsive host from blocking the request
            # indefinitely; raise_for_status stops an HTTP error page from
            # being handed on as if it were audio.
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
        except requests.RequestException as exc:
            return {"error": f"Could not download audio from the URL: {exc}"}
        binary_data = resp.content

    return {"data": binary_data}
2141

2242
def predict(self, request: Dict) -> Dict:
43+
if request.get("error"):
44+
return request
45+
46+
audio_data = request.get("data")
2347
result_segments = []
48+
2449
with NamedTemporaryFile() as fp:
25-
fp.write(request["response"])
50+
fp.write(audio_data)
2651
segments, info = self._model.transcribe(
2752
fp.name,
2853
temperature=0,

0 commit comments

Comments
 (0)