Skip to content

Commit 116a8c1

Browse files
Merge pull request #94 from FocoosAI/fix/warmup-runtime
fix(runtime): adjust warmup image size based on model metadata
2 parents 2e4183e + e813633 commit 116a8c1

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

focoos/runtime.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
# from supervision.detection.utils import mask_to_xyxy
4343
from focoos.ports import (
44+
FocoosTask,
4445
LatencyMetrics,
4546
ModelMetadata,
4647
OnnxRuntimeOpts,
@@ -210,7 +211,12 @@ def _setup_providers(self, model_dir: str):
210211

211212
def _warmup(self):
212213
self.logger.info("⏱️ [onnxruntime] Warming up model ..")
213-
np_image = np.random.rand(1, 3, 640, 640).astype(self.dtype)
214+
size = (
215+
self.model_metadata.im_size
216+
if self.model_metadata.task == FocoosTask.DETECTION and self.model_metadata.im_size
217+
else 640
218+
)
219+
np_image = np.random.rand(1, 3, size, size).astype(self.dtype)
214220
input_name = self.ort_sess.get_inputs()[0].name
215221
out_name = [output.name for output in self.ort_sess.get_outputs()]
216222

@@ -312,7 +318,7 @@ def __init__(
312318
self.logger = get_logger(name="TorchscriptEngine")
313319
self.logger.info(f"🔧 [torchscript] Device: {self.device}")
314320
self.opts = opts
315-
321+
self.model_metadata = model_metadata
316322
map_location = None if torch.cuda.is_available() else "cpu"
317323

318324
self.model = torch.jit.load(model_path, map_location=map_location)
@@ -321,7 +327,12 @@ def __init__(
321327
if self.opts.warmup_iter > 0:
322328
self.logger.info("⏱️ [torchscript] Warming up model..")
323329
with torch.no_grad():
324-
np_image = torch.rand(1, 3, 640, 640, device=self.device)
330+
size = (
331+
self.model_metadata.im_size
332+
if self.model_metadata.task == FocoosTask.DETECTION and self.model_metadata.im_size
333+
else 640
334+
)
335+
np_image = torch.rand(1, 3, size, size, device=self.device)
325336
for _ in range(self.opts.warmup_iter):
326337
self.model(np_image)
327338
self.logger.info("⏱️ [torchscript] WARMUP DONE")

0 commit comments

Comments
 (0)