Skip to content

Commit b95e256

Browse files
infer with different btch sizes
1 parent c290903 commit b95e256

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

tests/unit/test_exporters_dynamic.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from boxmot.utils.checks import RequirementsChecker
1212

1313

14-
def _load_existing_osnet_model_and_input():
14+
def _load_existing_osnet_model_and_input(batch_size=2):
1515
candidates = [
1616
ROOT / "osnet_x0_25_msmt17.pt",
1717
WEIGHTS / "osnet_x0_25_msmt17.pt",
@@ -22,7 +22,7 @@ def _load_existing_osnet_model_and_input():
2222

2323
backend = ReidAutoBackend(weights=weights, device="cpu", half=False)
2424
model = backend.model.model.eval()
25-
im = torch.randn(2, 3, 256, 128)
25+
im = torch.randn(batch_size, 3, 256, 128)
2626
return model, im
2727

2828

@@ -42,75 +42,86 @@ def _disable_dep_sync(monkeypatch):
4242
monkeypatch.setattr(RequirementsChecker, "sync_extra", lambda *args, **kwargs: None)
4343

4444

45-
def test_onnx_export_dynamic_uses_dynamic_shapes(monkeypatch, tmp_path):
45+
@pytest.mark.parametrize("batch_size", [1, 2, 4])
46+
def test_onnx_export_dynamic_uses_dynamic_shapes(monkeypatch, tmp_path, batch_size):
4647
_disable_dep_sync(monkeypatch)
4748
_install_fake_onnx(monkeypatch)
4849

4950
calls = []
5051

5152
def fake_export(model, args, f, **kwargs):
52-
calls.append(kwargs)
53+
calls.append((args, kwargs))
5354
Path(f).touch()
5455

5556
monkeypatch.setattr(torch.onnx, "export", fake_export)
5657

57-
model, im = _load_existing_osnet_model_and_input()
58+
model, im = _load_existing_osnet_model_and_input(batch_size=batch_size)
5859
out_file = tmp_path / "osnet_x0_25_msmt17.pt"
5960

6061
exporter = ONNXExporter(model, im, out_file, opset=17, dynamic=True, half=False, simplify=False)
6162
exported = exporter.export()
6263

6364
assert exported == out_file.with_suffix(".onnx")
6465
assert len(calls) == 1
65-
assert "dynamic_shapes" in calls[0]
66-
assert "dynamic_axes" not in calls[0]
66+
export_args, export_kwargs = calls[0]
67+
assert export_args[0].shape[0] == batch_size
68+
assert "dynamic_shapes" in export_kwargs
69+
assert "dynamic_axes" not in export_kwargs
6770

6871

69-
def test_onnx_export_dynamic_fallback_uses_dynamic_axes(monkeypatch, tmp_path):
72+
@pytest.mark.parametrize("batch_size", [1, 3])
73+
def test_onnx_export_dynamic_fallback_uses_dynamic_axes(monkeypatch, tmp_path, batch_size):
7074
_disable_dep_sync(monkeypatch)
7175
_install_fake_onnx(monkeypatch)
7276

7377
calls = []
7478

7579
def fake_export(model, args, f, **kwargs):
76-
calls.append(kwargs)
80+
calls.append((args, kwargs))
7781
if len(calls) == 1:
7882
raise RuntimeError("force dynamic fallback")
7983
Path(f).touch()
8084

8185
monkeypatch.setattr(torch.onnx, "export", fake_export)
8286

83-
model, im = _load_existing_osnet_model_and_input()
87+
model, im = _load_existing_osnet_model_and_input(batch_size=batch_size)
8488
out_file = tmp_path / "osnet_x0_25_msmt17.pt"
8589

8690
exporter = ONNXExporter(model, im, out_file, opset=17, dynamic=True, half=False, simplify=False)
8791
exported = exporter.export()
8892

8993
assert exported == out_file.with_suffix(".onnx")
9094
assert len(calls) == 2
91-
assert "dynamic_shapes" in calls[0]
92-
assert "dynamic_axes" in calls[1]
95+
first_args, first_kwargs = calls[0]
96+
second_args, second_kwargs = calls[1]
97+
assert first_args[0].shape[0] == batch_size
98+
assert second_args[0].shape[0] == batch_size
99+
assert "dynamic_shapes" in first_kwargs
100+
assert "dynamic_axes" in second_kwargs
93101

94102

95-
def test_onnx_export_static_has_no_dynamic_shapes(monkeypatch, tmp_path):
103+
@pytest.mark.parametrize("batch_size", [1, 2, 5])
104+
def test_onnx_export_static_has_no_dynamic_shapes(monkeypatch, tmp_path, batch_size):
96105
_disable_dep_sync(monkeypatch)
97106
_install_fake_onnx(monkeypatch)
98107

99108
calls = []
100109

101110
def fake_export(model, args, f, **kwargs):
102-
calls.append(kwargs)
111+
calls.append((args, kwargs))
103112
Path(f).touch()
104113

105114
monkeypatch.setattr(torch.onnx, "export", fake_export)
106115

107-
model, im = _load_existing_osnet_model_and_input()
116+
model, im = _load_existing_osnet_model_and_input(batch_size=batch_size)
108117
out_file = tmp_path / "osnet_x0_25_msmt17.pt"
109118

110119
exporter = ONNXExporter(model, im, out_file, opset=17, dynamic=False, half=False, simplify=False)
111120
exported = exporter.export()
112121

113122
assert exported == out_file.with_suffix(".onnx")
114123
assert len(calls) == 1
115-
assert "dynamic_shapes" not in calls[0]
116-
assert "dynamic_axes" not in calls[0]
124+
export_args, export_kwargs = calls[0]
125+
assert export_args[0].shape[0] == batch_size
126+
assert "dynamic_shapes" not in export_kwargs
127+
assert "dynamic_axes" not in export_kwargs

0 commit comments

Comments
 (0)