77from starlette .responses import PlainTextResponse , Response
88from starlette .routing import Route
99
10- from huggingface_inference_toolkit .async_utils import async_handler_call
10+ from huggingface_inference_toolkit .async_utils import MAX_CONCURRENT_THREADS , MAX_THREADS_GUARD , async_handler_call
1111from huggingface_inference_toolkit .const import (
1212 HF_FRAMEWORK ,
1313 HF_HUB_TOKEN ,
@@ -69,6 +69,18 @@ async def health(request):
6969 return PlainTextResponse ("Ok" )
7070
7171
# Report Prometheus metrics
# inf_batch_current_size: Current number of requests being processed
# inf_queue_size: Number of requests waiting in the queue
async def metrics(request):
    """Expose concurrency metrics in the Prometheus text exposition format.

    Derives both gauges from the thread guard:
    - `inf_batch_current_size`: requests currently being processed
      (total capacity minus remaining guard slots).
    - `inf_queue_size`: requests waiting for a slot
      (tasks blocked on the guard).

    The `request` argument is required by Starlette's route signature but
    is not used. Returns a `PlainTextResponse` with one `name value` line
    per metric, newline-terminated, with no stray whitespace (the
    Prometheus format expects exactly `metric_name value\n` per sample).
    """
    batch_current_size = MAX_CONCURRENT_THREADS - MAX_THREADS_GUARD.value
    # NOTE(review): assumes MAX_THREADS_GUARD is an anyio CapacityLimiter
    # (has `.value` and `.statistics().tasks_waiting`) — confirm in async_utils.
    queue_size = MAX_THREADS_GUARD.statistics().tasks_waiting
    # Adjacent f-string literals concatenate implicitly; no `+` needed.
    # Values are embedded without surrounding spaces so scrapers parse them.
    return PlainTextResponse(
        f"inf_batch_current_size {batch_current_size}\n"
        f"inf_queue_size {queue_size}\n"
    )
82+
83+
7284async def predict (request ):
7385 try :
7486 # extracts content from request
@@ -132,6 +144,7 @@ async def predict(request):
132144 routes = [
133145 Route (_health_route , health , methods = ["GET" ]),
134146 Route (_predict_route , predict , methods = ["POST" ]),
147+ Route ("/metrics" , metrics , methods = ["GET" ]),
135148 ],
136149 on_startup = [prepare_model_artifacts ],
137150 )
@@ -143,6 +156,7 @@ async def predict(request):
143156 Route ("/health" , health , methods = ["GET" ]),
144157 Route ("/" , predict , methods = ["POST" ]),
145158 Route ("/predict" , predict , methods = ["POST" ]),
159+ Route ("/metrics" , metrics , methods = ["GET" ]),
146160 ],
147161 on_startup = [prepare_model_artifacts ],
148162 )
0 commit comments