1111from boxmot .utils .checks import RequirementsChecker
1212
1313
14- def _load_existing_osnet_model_and_input ():
14+ def _load_existing_osnet_model_and_input (batch_size = 2 ):
1515 candidates = [
1616 ROOT / "osnet_x0_25_msmt17.pt" ,
1717 WEIGHTS / "osnet_x0_25_msmt17.pt" ,
@@ -22,7 +22,7 @@ def _load_existing_osnet_model_and_input():
2222
2323 backend = ReidAutoBackend (weights = weights , device = "cpu" , half = False )
2424 model = backend .model .model .eval ()
25- im = torch .randn (2 , 3 , 256 , 128 )
25+ im = torch .randn (batch_size , 3 , 256 , 128 )
2626 return model , im
2727
2828
@@ -42,75 +42,86 @@ def _disable_dep_sync(monkeypatch):
4242 monkeypatch .setattr (RequirementsChecker , "sync_extra" , lambda * args , ** kwargs : None )
4343
4444
45- def test_onnx_export_dynamic_uses_dynamic_shapes (monkeypatch , tmp_path ):
45+ @pytest .mark .parametrize ("batch_size" , [1 , 2 , 4 ])
46+ def test_onnx_export_dynamic_uses_dynamic_shapes (monkeypatch , tmp_path , batch_size ):
4647 _disable_dep_sync (monkeypatch )
4748 _install_fake_onnx (monkeypatch )
4849
4950 calls = []
5051
5152 def fake_export (model , args , f , ** kwargs ):
52- calls .append (kwargs )
53+ calls .append (( args , kwargs ) )
5354 Path (f ).touch ()
5455
5556 monkeypatch .setattr (torch .onnx , "export" , fake_export )
5657
57- model , im = _load_existing_osnet_model_and_input ()
58+ model , im = _load_existing_osnet_model_and_input (batch_size = batch_size )
5859 out_file = tmp_path / "osnet_x0_25_msmt17.pt"
5960
6061 exporter = ONNXExporter (model , im , out_file , opset = 17 , dynamic = True , half = False , simplify = False )
6162 exported = exporter .export ()
6263
6364 assert exported == out_file .with_suffix (".onnx" )
6465 assert len (calls ) == 1
65- assert "dynamic_shapes" in calls [0 ]
66- assert "dynamic_axes" not in calls [0 ]
66+ export_args , export_kwargs = calls [0 ]
67+ assert export_args [0 ].shape [0 ] == batch_size
68+ assert "dynamic_shapes" in export_kwargs
69+ assert "dynamic_axes" not in export_kwargs
6770
6871
69- def test_onnx_export_dynamic_fallback_uses_dynamic_axes (monkeypatch , tmp_path ):
72+ @pytest .mark .parametrize ("batch_size" , [1 , 3 ])
73+ def test_onnx_export_dynamic_fallback_uses_dynamic_axes (monkeypatch , tmp_path , batch_size ):
7074 _disable_dep_sync (monkeypatch )
7175 _install_fake_onnx (monkeypatch )
7276
7377 calls = []
7478
7579 def fake_export (model , args , f , ** kwargs ):
76- calls .append (kwargs )
80+ calls .append (( args , kwargs ) )
7781 if len (calls ) == 1 :
7882 raise RuntimeError ("force dynamic fallback" )
7983 Path (f ).touch ()
8084
8185 monkeypatch .setattr (torch .onnx , "export" , fake_export )
8286
83- model , im = _load_existing_osnet_model_and_input ()
87+ model , im = _load_existing_osnet_model_and_input (batch_size = batch_size )
8488 out_file = tmp_path / "osnet_x0_25_msmt17.pt"
8589
8690 exporter = ONNXExporter (model , im , out_file , opset = 17 , dynamic = True , half = False , simplify = False )
8791 exported = exporter .export ()
8892
8993 assert exported == out_file .with_suffix (".onnx" )
9094 assert len (calls ) == 2
91- assert "dynamic_shapes" in calls [0 ]
92- assert "dynamic_axes" in calls [1 ]
95+ first_args , first_kwargs = calls [0 ]
96+ second_args , second_kwargs = calls [1 ]
97+ assert first_args [0 ].shape [0 ] == batch_size
98+ assert second_args [0 ].shape [0 ] == batch_size
99+ assert "dynamic_shapes" in first_kwargs
100+ assert "dynamic_axes" in second_kwargs
93101
94102
95- def test_onnx_export_static_has_no_dynamic_shapes (monkeypatch , tmp_path ):
103+ @pytest .mark .parametrize ("batch_size" , [1 , 2 , 5 ])
104+ def test_onnx_export_static_has_no_dynamic_shapes (monkeypatch , tmp_path , batch_size ):
96105 _disable_dep_sync (monkeypatch )
97106 _install_fake_onnx (monkeypatch )
98107
99108 calls = []
100109
101110 def fake_export (model , args , f , ** kwargs ):
102- calls .append (kwargs )
111+ calls .append (( args , kwargs ) )
103112 Path (f ).touch ()
104113
105114 monkeypatch .setattr (torch .onnx , "export" , fake_export )
106115
107- model , im = _load_existing_osnet_model_and_input ()
116+ model , im = _load_existing_osnet_model_and_input (batch_size = batch_size )
108117 out_file = tmp_path / "osnet_x0_25_msmt17.pt"
109118
110119 exporter = ONNXExporter (model , im , out_file , opset = 17 , dynamic = False , half = False , simplify = False )
111120 exported = exporter .export ()
112121
113122 assert exported == out_file .with_suffix (".onnx" )
114123 assert len (calls ) == 1
115- assert "dynamic_shapes" not in calls [0 ]
116- assert "dynamic_axes" not in calls [0 ]
124+ export_args , export_kwargs = calls [0 ]
125+ assert export_args [0 ].shape [0 ] == batch_size
126+ assert "dynamic_shapes" not in export_kwargs
127+ assert "dynamic_axes" not in export_kwargs
0 commit comments