4141
4242# from supervision.detection.utils import mask_to_xyxy
4343from focoos .ports import (
44+ FocoosTask ,
4445 LatencyMetrics ,
4546 ModelMetadata ,
4647 OnnxRuntimeOpts ,
@@ -210,7 +211,12 @@ def _setup_providers(self, model_dir: str):
210211
211212 def _warmup (self ):
212213 self .logger .info ("⏱️ [onnxruntime] Warming up model .." )
213- np_image = np .random .rand (1 , 3 , 640 , 640 ).astype (self .dtype )
214+ size = (
215+ self .model_metadata .im_size
216+ if self .model_metadata .task == FocoosTask .DETECTION and self .model_metadata .im_size
217+ else 640
218+ )
219+ np_image = np .random .rand (1 , 3 , size , size ).astype (self .dtype )
214220 input_name = self .ort_sess .get_inputs ()[0 ].name
215221 out_name = [output .name for output in self .ort_sess .get_outputs ()]
216222
@@ -312,7 +318,7 @@ def __init__(
312318 self .logger = get_logger (name = "TorchscriptEngine" )
313319 self .logger .info (f"🔧 [torchscript] Device: { self .device } " )
314320 self .opts = opts
315-
321+ self . model_metadata = model_metadata
316322 map_location = None if torch .cuda .is_available () else "cpu"
317323
318324 self .model = torch .jit .load (model_path , map_location = map_location )
@@ -321,7 +327,12 @@ def __init__(
321327 if self .opts .warmup_iter > 0 :
322328 self .logger .info ("⏱️ [torchscript] Warming up model.." )
323329 with torch .no_grad ():
324- np_image = torch .rand (1 , 3 , 640 , 640 , device = self .device )
330+ size = (
331+ self .model_metadata .im_size
332+ if self .model_metadata .task == FocoosTask .DETECTION and self .model_metadata .im_size
333+ else 640
334+ )
335+ np_image = torch .rand (1 , 3 , size , size , device = self .device )
325336 for _ in range (self .opts .warmup_iter ):
326337 self .model (np_image )
327338 self .logger .info ("⏱️ [torchscript] WARMUP DONE" )
0 commit comments