1
+ import base64
1
2
from tempfile import NamedTemporaryFile
2
3
from typing import Any , Dict
3
4
@@ -16,13 +17,37 @@ def load(self):
16
17
self ._model = WhisperModel (self ._config ["model_metadata" ]["model_id" ])
17
18
18
19
def preprocess(self, request: Dict) -> Dict:
    """Resolve the audio input of a request into raw bytes.

    Exactly one of the following keys must be present in ``request``:

    * ``"audio"`` -- the audio file encoded as a base64 string.
    * ``"url"`` -- a URL the audio file is downloaded from.

    Args:
        request: The incoming request payload.

    Returns:
        ``{"data": <bytes>}`` on success, or ``{"error": <message>}`` when
        the input is missing, ambiguous, not decodable, or cannot be
        downloaded. Problems are reported as a dict instead of being raised
        because ``predict`` returns any request carrying an ``"error"`` key
        unchanged.
    """
    audio_base64 = request.get("audio")
    url = request.get("url")

    if audio_base64 and url:
        return {
            "error": "Only a base64 audio file OR a URL can be passed to the API, not both of them.",
        }
    if not audio_base64 and not url:
        return {
            "error": "Please provide either an audio file in base64 string format or a URL to an audio file.",
        }

    if audio_base64:
        try:
            # binascii.Error (bad padding / alphabet) is a ValueError subclass;
            # TypeError covers payloads that are not str/bytes at all.
            return {"data": base64.b64decode(audio_base64)}
        except (ValueError, TypeError) as exc:
            return {"error": f"The provided audio is not valid base64: {exc}"}

    try:
        # (connect, read) timeout so a stalled server cannot hang the request
        # forever; the read timeout applies per socket read, not to the whole
        # download, so large files are still fine.
        resp = requests.get(url, timeout=(10, 60))
        # Reject 4xx/5xx responses instead of treating an error page as audio.
        resp.raise_for_status()
    except requests.RequestException as exc:
        return {"error": f"Failed to download audio from the URL: {exc}"}

    return {"data": resp.content}
21
41
22
42
def predict (self , request : Dict ) -> Dict :
43
+ if request .get ("error" ):
44
+ return request
45
+
46
+ audio_data = request .get ("data" )
23
47
result_segments = []
48
+
24
49
with NamedTemporaryFile () as fp :
25
- fp .write (request [ "response" ] )
50
+ fp .write (audio_data )
26
51
segments , info = self ._model .transcribe (
27
52
fp .name ,
28
53
temperature = 0 ,
0 commit comments