Skip to content

Commit 29ca18c

Browse files
authored
Add warm-up to onnx as some GPUs require kernel compilation before accepting inferences (#22685)
1 parent 148e11a commit 29ca18c

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

frigate/detectors/plugins/onnx.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from frigate.detectors.detection_runners import get_optimized_runner
99
from frigate.detectors.detector_config import (
1010
BaseDetectorConfig,
11+
InputDTypeEnum,
12+
InputTensorEnum,
1113
ModelTypeEnum,
1214
)
1315
from frigate.util.model import (
@@ -59,8 +61,34 @@ def __init__(self, detector_config: ONNXDetectorConfig):
5961
if self.onnx_model_type == ModelTypeEnum.yolox:
6062
self.calculate_grids_strides()
6163

64+
self._warmup(detector_config)
6265
logger.info(f"ONNX: {path} loaded")
6366

67+
def _warmup(self, detector_config: ONNXDetectorConfig) -> None:
    """Run a warmup inference to front-load one-time compilation costs.

    Some GPU backends have a slow first inference: CUDA may need PTX JIT
    compilation on newer architectures (e.g. NVIDIA 50-series / Blackwell),
    and MIGraphX compiles the model graph on first run. Running it here
    (during detector creation) keeps the watchdog start_time at 0.0 so the
    process won't be killed.
    """
    model = detector_config.model
    height, width = model.height, model.width

    # Build a zero tensor matching the configured input layout.
    if model.input_tensor == InputTensorEnum.nchw:
        warmup_shape = (1, 3, height, width)
    else:
        warmup_shape = (1, height, width, 3)

    # Float-typed models (normalized or denormalized) get float32 input;
    # everything else warms up with uint8.
    wants_float = model.input_dtype in (
        InputDTypeEnum.float,
        InputDTypeEnum.float_denorm,
    )
    warmup_dtype = np.float32 if wants_float else np.uint8

    logger.info("ONNX: warming up detector (may take a while on first run)...")
    self.detect_raw(np.zeros(warmup_shape, dtype=warmup_dtype))
91+
6492
def detect_raw(self, tensor_input: np.ndarray):
6593
if self.onnx_model_type == ModelTypeEnum.dfine:
6694
tensor_output = self.runner.run(

0 commit comments

Comments
 (0)