Skip to content

Commit 5b0b196

Browse files
constrain dynamic batch range to satisfy torch.export shape guards on CUDA
1 parent bc97b65 commit 5b0b196

File tree

1 file changed

+47
-14
lines changed

1 file changed

+47
-14
lines changed

boxmot/reid/exporters/onnx_exporter.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
import inspect
22
import torch
33
from torch.export import Dim
44

@@ -27,21 +27,54 @@ def export(self):
2727

2828
# --- Export ---
2929
args = (self.im,)
30+
export_sig = inspect.signature(torch.onnx.export)
31+
has_dynamo_arg = "dynamo" in export_sig.parameters
32+
33+
export_kwargs = {
34+
"opset_version": opset,
35+
"input_names": ["images"],
36+
"output_names": output_names,
37+
}
3038

3139
if self.dynamic:
32-
dynamic_shapes = ({0: Dim("batch")},) # first (and only) input tensor: dim0 is dynamic
33-
else:
34-
dynamic_shapes = None
35-
36-
torch.onnx.export(
37-
self.model,
38-
args,
39-
str(f),
40-
opset_version=opset,
41-
input_names=["images"],
42-
output_names=output_names,
43-
dynamic_shapes=dynamic_shapes,
44-
)
40+
# Constrain dynamic batch range to satisfy torch.export shape guards on CUDA.
41+
export_kwargs["dynamic_shapes"] = ({0: Dim("batch", min=1, max=65535)},)
42+
43+
if has_dynamo_arg:
44+
export_kwargs["dynamo"] = True
45+
46+
try:
47+
torch.onnx.export(
48+
self.model,
49+
args,
50+
str(f),
51+
**export_kwargs,
52+
)
53+
except Exception as e:
54+
if not self.dynamic:
55+
raise
56+
57+
LOGGER.warning(
58+
f"Dynamic export via torch.export failed ({e}). "
59+
"Retrying with legacy dynamic_axes export..."
60+
)
61+
62+
# Fallback for torch.export/dynamo dynamic shape guard failures.
63+
fallback_kwargs = {
64+
"opset_version": opset,
65+
"input_names": ["images"],
66+
"output_names": output_names,
67+
"dynamic_axes": self._build_dynamic_axes(output_names),
68+
}
69+
if has_dynamo_arg:
70+
fallback_kwargs["dynamo"] = False
71+
72+
torch.onnx.export(
73+
self.model,
74+
args,
75+
str(f),
76+
**fallback_kwargs,
77+
)
4578

4679
# --- Load + validate ---
4780
model_onnx = onnx.load(str(f))

0 commit comments

Comments
 (0)