|
1 | | - |
| 1 | +import inspect |
2 | 2 | import torch |
3 | 3 | from torch.export import Dim |
4 | 4 |
|
@@ -27,21 +27,54 @@ def export(self): |
27 | 27 |
|
28 | 28 | # --- Export --- |
29 | 29 | args = (self.im,) |
| 30 | + export_sig = inspect.signature(torch.onnx.export) |
| 31 | + has_dynamo_arg = "dynamo" in export_sig.parameters |
| 32 | + |
| 33 | + export_kwargs = { |
| 34 | + "opset_version": opset, |
| 35 | + "input_names": ["images"], |
| 36 | + "output_names": output_names, |
| 37 | + } |
30 | 38 |
|
31 | 39 | if self.dynamic: |
32 | | - dynamic_shapes = ({0: Dim("batch")},) # first (and only) input tensor: dim0 is dynamic |
33 | | - else: |
34 | | - dynamic_shapes = None |
35 | | - |
36 | | - torch.onnx.export( |
37 | | - self.model, |
38 | | - args, |
39 | | - str(f), |
40 | | - opset_version=opset, |
41 | | - input_names=["images"], |
42 | | - output_names=output_names, |
43 | | - dynamic_shapes=dynamic_shapes, |
44 | | - ) |
| 40 | + # Constrain dynamic batch range to satisfy torch.export shape guards on CUDA. |
| 41 | + export_kwargs["dynamic_shapes"] = ({0: Dim("batch", min=1, max=65535)},) |
| 42 | + |
| 43 | + if has_dynamo_arg: |
| 44 | + export_kwargs["dynamo"] = True |
| 45 | + |
| 46 | + try: |
| 47 | + torch.onnx.export( |
| 48 | + self.model, |
| 49 | + args, |
| 50 | + str(f), |
| 51 | + **export_kwargs, |
| 52 | + ) |
| 53 | + except Exception as e: |
| 54 | + if not self.dynamic: |
| 55 | + raise |
| 56 | + |
| 57 | + LOGGER.warning( |
| 58 | + f"Dynamic export via torch.export failed ({e}). " |
| 59 | + "Retrying with legacy dynamic_axes export..." |
| 60 | + ) |
| 61 | + |
| 62 | + # Fallback for torch.export/dynamo dynamic shape guard failures. |
| 63 | + fallback_kwargs = { |
| 64 | + "opset_version": opset, |
| 65 | + "input_names": ["images"], |
| 66 | + "output_names": output_names, |
| 67 | + "dynamic_axes": self._build_dynamic_axes(output_names), |
| 68 | + } |
| 69 | + if has_dynamo_arg: |
| 70 | + fallback_kwargs["dynamo"] = False |
| 71 | + |
| 72 | + torch.onnx.export( |
| 73 | + self.model, |
| 74 | + args, |
| 75 | + str(f), |
| 76 | + **fallback_kwargs, |
| 77 | + ) |
45 | 78 |
|
46 | 79 | # --- Load + validate --- |
47 | 80 | model_onnx = onnx.load(str(f)) |
|
0 commit comments