9
9
from huggingface_hub import snapshot_download
10
10
11
11
# Inference URL of the locally running TorchServe instance serving the
# whisper_base model; predict() POSTs audio bytes here.
TORCHSERVE_ENDPOINT = "http://0.0.0.0:8888/predictions/whisper_base"
# TorchServe health-check URL; polled during load() until the server is ready.
TORCHSERVE_HEALTH_ENDPOINT = "http://0.0.0.0:8888/ping"
12
13
13
14
14
15
class Model :
15
16
def __init__ (self , ** kwargs ):
16
17
self ._data_dir = kwargs ["data_dir" ]
17
18
self ._model = None
19
+ self .torchserver_ready = False
18
20
19
21
def start_tochserver (self ):
20
22
subprocess .run (
@@ -39,18 +41,29 @@ def load(self):
39
41
local_dir = os .path .join (self ._data_dir , "model_store" ),
40
42
max_workers = 4 ,
41
43
)
42
- print ( "Downloaded weights succesfully !" )
44
+ logging . info ( "⚡️ Weights Downloaded Successfully !" )
43
45
44
46
process = multiprocessing .Process (target = self .start_tochserver )
45
47
process .start ()
46
48
49
+ # Need to wait for the torchserve server to start up
50
+ while not self .torchserver_ready :
51
+ try :
52
+ res = requests .get (TORCHSERVE_HEALTH_ENDPOINT )
53
+ if res .status_code == 200 :
54
+ self .torchserver_ready = True
55
+ logging .info ("🔥Torchserve is ready!" )
56
+ except Exception as e :
57
+ logging .info ("⏳Torchserve is loading..." )
58
+ time .sleep (5 )
59
+
47
60
async def predict (self , request : Dict ):
48
61
audio_base64 = request .get ("audio" )
49
62
audio_bytes = base64 .b64decode (audio_base64 )
50
63
51
64
async with httpx .AsyncClient () as client :
52
65
res = await client .post (
53
- TORCHSERVE_ENDPOINT , files = {"data" : (None , audio_bytes )}
66
+ TORCHSERVE_ENDPOINT , files = {"data" : (None , audio_bytes )}, timeout = 120
54
67
)
55
68
transcription = res .text
56
69
return {"output" : transcription }
0 commit comments