Skip to content

Commit c290903

Browse files
add dynamic batch tests
1 parent 5b0b196 commit c290903

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import types
2+
import sys
3+
from pathlib import Path
4+
5+
import pytest
6+
import torch
7+
8+
from boxmot.reid.core.auto_backend import ReidAutoBackend
9+
from boxmot.reid.exporters.onnx_exporter import ONNXExporter
10+
from boxmot.utils import ROOT, WEIGHTS
11+
from boxmot.utils.checks import RequirementsChecker
12+
13+
14+
def _load_existing_osnet_model_and_input():
15+
candidates = [
16+
ROOT / "osnet_x0_25_msmt17.pt",
17+
WEIGHTS / "osnet_x0_25_msmt17.pt",
18+
]
19+
weights = next((p for p in candidates if p.exists()), None)
20+
if weights is None:
21+
pytest.skip("Missing osnet_x0_25_msmt17.pt in repository root or engine/weights.")
22+
23+
backend = ReidAutoBackend(weights=weights, device="cpu", half=False)
24+
model = backend.model.model.eval()
25+
im = torch.randn(2, 3, 256, 128)
26+
return model, im
27+
28+
29+
def _install_fake_onnx(monkeypatch):
30+
fake_onnx_model = types.SimpleNamespace(ir_version=9)
31+
fake_onnx = types.SimpleNamespace(
32+
__version__="0.0-test",
33+
defs=types.SimpleNamespace(onnx_opset_version=lambda: 18),
34+
checker=types.SimpleNamespace(check_model=lambda _m: None),
35+
load=lambda _p: fake_onnx_model,
36+
save=lambda _m, _p: None,
37+
)
38+
monkeypatch.setitem(sys.modules, "onnx", fake_onnx)
39+
40+
41+
def _disable_dep_sync(monkeypatch):
42+
monkeypatch.setattr(RequirementsChecker, "sync_extra", lambda *args, **kwargs: None)
43+
44+
45+
def test_onnx_export_dynamic_uses_dynamic_shapes(monkeypatch, tmp_path):
46+
_disable_dep_sync(monkeypatch)
47+
_install_fake_onnx(monkeypatch)
48+
49+
calls = []
50+
51+
def fake_export(model, args, f, **kwargs):
52+
calls.append(kwargs)
53+
Path(f).touch()
54+
55+
monkeypatch.setattr(torch.onnx, "export", fake_export)
56+
57+
model, im = _load_existing_osnet_model_and_input()
58+
out_file = tmp_path / "osnet_x0_25_msmt17.pt"
59+
60+
exporter = ONNXExporter(model, im, out_file, opset=17, dynamic=True, half=False, simplify=False)
61+
exported = exporter.export()
62+
63+
assert exported == out_file.with_suffix(".onnx")
64+
assert len(calls) == 1
65+
assert "dynamic_shapes" in calls[0]
66+
assert "dynamic_axes" not in calls[0]
67+
68+
69+
def test_onnx_export_dynamic_fallback_uses_dynamic_axes(monkeypatch, tmp_path):
70+
_disable_dep_sync(monkeypatch)
71+
_install_fake_onnx(monkeypatch)
72+
73+
calls = []
74+
75+
def fake_export(model, args, f, **kwargs):
76+
calls.append(kwargs)
77+
if len(calls) == 1:
78+
raise RuntimeError("force dynamic fallback")
79+
Path(f).touch()
80+
81+
monkeypatch.setattr(torch.onnx, "export", fake_export)
82+
83+
model, im = _load_existing_osnet_model_and_input()
84+
out_file = tmp_path / "osnet_x0_25_msmt17.pt"
85+
86+
exporter = ONNXExporter(model, im, out_file, opset=17, dynamic=True, half=False, simplify=False)
87+
exported = exporter.export()
88+
89+
assert exported == out_file.with_suffix(".onnx")
90+
assert len(calls) == 2
91+
assert "dynamic_shapes" in calls[0]
92+
assert "dynamic_axes" in calls[1]
93+
94+
95+
def test_onnx_export_static_has_no_dynamic_shapes(monkeypatch, tmp_path):
96+
_disable_dep_sync(monkeypatch)
97+
_install_fake_onnx(monkeypatch)
98+
99+
calls = []
100+
101+
def fake_export(model, args, f, **kwargs):
102+
calls.append(kwargs)
103+
Path(f).touch()
104+
105+
monkeypatch.setattr(torch.onnx, "export", fake_export)
106+
107+
model, im = _load_existing_osnet_model_and_input()
108+
out_file = tmp_path / "osnet_x0_25_msmt17.pt"
109+
110+
exporter = ONNXExporter(model, im, out_file, opset=17, dynamic=False, half=False, simplify=False)
111+
exported = exporter.export()
112+
113+
assert exported == out_file.with_suffix(".onnx")
114+
assert len(calls) == 1
115+
assert "dynamic_shapes" not in calls[0]
116+
assert "dynamic_axes" not in calls[0]

0 commit comments

Comments
 (0)