Skip to content

Commit b82bb62

Browse files
committed
Large ONNX file support
1 parent 202c4d8 commit b82bb62

File tree

14 files changed

+552
-49
lines changed

14 files changed

+552
-49
lines changed

src/onnx2fx/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
UnsupportedOpError,
4848
ConversionError,
4949
ValueNotFoundError,
50+
UnsupportedDTypeError,
51+
ExternalDataError,
52+
InferenceOnlyError,
5053
)
5154
from .op_registry import (
5255
register_op,
@@ -79,4 +82,7 @@
7982
"UnsupportedOpError",
8083
"ConversionError",
8184
"ValueNotFoundError",
85+
"UnsupportedDTypeError",
86+
"ExternalDataError",
87+
"InferenceOnlyError",
8288
]

src/onnx2fx/converter.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,34 @@
55
ONNX models into equivalent PyTorch FX GraphModules.
66
"""
77

8-
from typing import Union
8+
import os
9+
from typing import Optional, Union
910

1011
import onnx
1112
import torch
1213

1314
from .graph_builder import GraphBuilder
15+
from .utils.external_data import validate_external_data_model
1416

1517

1618
def convert(
1719
model: Union[onnx.ModelProto, str],
20+
*,
21+
base_dir: Optional[str] = None,
22+
memmap_external_data: bool = False,
1823
) -> torch.fx.GraphModule:
1924
"""Convert an ONNX model into a ``torch.fx.GraphModule``.
2025
2126
Parameters
2227
----------
2328
model : Union[onnx.ModelProto, str]
2429
Either an in-memory ``onnx.ModelProto`` or a file path to an ONNX model.
30+
base_dir : Optional[str], optional
31+
Base directory for resolving external data tensors. Required when
32+
``memmap_external_data=True`` and a relative external data path is used.
33+
memmap_external_data : bool, optional
34+
If True, do not load external data into memory. Instead, keep external
35+
data references for memmap-based loading during conversion.
2536
2637
Returns
2738
-------
@@ -30,10 +41,22 @@ def convert(
3041
"""
3142

3243
if isinstance(model, str):
33-
model = onnx.load(model)
44+
if base_dir is None:
45+
base_dir = os.path.dirname(os.path.abspath(model))
46+
if memmap_external_data:
47+
model = onnx.load_model(model, load_external_data=False)
48+
else:
49+
model = onnx.load_model(model)
3450
elif isinstance(model, onnx.ModelProto):
3551
model = model
3652
else:
3753
raise TypeError("model must be a path or onnx.ModelProto")
3854

39-
return GraphBuilder(model).build()
55+
if memmap_external_data:
56+
validate_external_data_model(model, base_dir=base_dir, strict=True)
57+
58+
return GraphBuilder(
59+
model,
60+
base_dir=base_dir,
61+
memmap_external_data=memmap_external_data,
62+
).build()

src/onnx2fx/exceptions.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,50 @@ def __init__(self, name: str, available: list[str] | None = None):
106106
if available:
107107
message += f". Available: {available}"
108108
super().__init__(message)
109+
110+
111+
class UnsupportedDTypeError(Onnx2FxError):
    """Raised when an ONNX tensor dtype is not supported.

    Parameters
    ----------
    onnx_dtype : int
        ONNX TensorProto data type enum value.
    tensor_name : str
        Name of the tensor.
    details : str, optional
        Additional details about the failure.
    """

    def __init__(self, onnx_dtype: int, tensor_name: str, details: str | None = None):
        self.onnx_dtype = onnx_dtype
        self.tensor_name = tensor_name
        self.details = details

        # Best effort: render the symbolic enum name when onnx is importable,
        # otherwise fall back to the raw integer value.
        try:
            import onnx

            dtype_name = onnx.TensorProto.DataType.Name(onnx_dtype)
        except Exception:
            dtype_name = f"{onnx_dtype}"

        suffix = f" ({details})" if details else ""
        message = f"Unsupported dtype for tensor '{tensor_name}': {dtype_name}" + suffix
        super().__init__(message)
141+
142+
143+
class ExternalDataError(Onnx2FxError):
    """Raised when external data metadata is invalid or inaccessible."""

    def __init__(self, tensor_name: str, message: str):
        # Keep the offending tensor name for programmatic inspection.
        self.tensor_name = tensor_name
        full_message = f"External data error for '{tensor_name}': {message}"
        super().__init__(full_message)
149+
150+
151+
class InferenceOnlyError(Onnx2FxError):
    """Raised when an inference-only model is used for training."""

    def __init__(self, message: str):
        # Plain pass-through; the caller-supplied message fully describes
        # the failure, so no extra formatting is applied here.
        super().__init__(message)

src/onnx2fx/graph_builder.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
from collections import deque
33
from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence, Union
44

5+
import numpy as np
6+
57
import torch
68
import torch.fx
79
import onnx
810
from onnx import numpy_helper
911

10-
from .exceptions import UnsupportedOpError, ValueNotFoundError
12+
from .exceptions import UnsupportedDTypeError, UnsupportedOpError, ValueNotFoundError
1113
from .op_registry import get_handler
1214
from .utils.dtype import DTYPE_MAP
15+
from .utils.external_data import resolve_external_data
1316
from .utils.names import sanitize_name
1417

1518
# Import ops module to register all operators
@@ -178,7 +181,13 @@ class GraphBuilder:
178181
The opset version for the default ONNX domain.
179182
"""
180183

181-
def __init__(self, model: onnx.ModelProto) -> None:
184+
def __init__(
185+
self,
186+
model: onnx.ModelProto,
187+
*,
188+
base_dir: Optional[str] = None,
189+
memmap_external_data: bool = False,
190+
) -> None:
182191
# Try shape inference but preserve original model if it fails
183192
# (shape_inference may drop graph contents for large models with external data)
184193
try:
@@ -191,6 +200,8 @@ def __init__(self, model: onnx.ModelProto) -> None:
191200
pass
192201
self.model: onnx.ModelProto = model
193202
self.graph: torch.fx.Graph = torch.fx.Graph()
203+
self._base_dir = base_dir
204+
self._memmap_external_data = memmap_external_data
194205
self.value_info_map = self._create_value_info_map()
195206
self.initializer_map = self._create_initializer_map()
196207
self.input_names: List[str] = []
@@ -299,6 +310,8 @@ def build(self) -> torch.fx.GraphModule:
299310
for name, submod in self._submodules.items():
300311
root_module.add_module(name, submod)
301312
module = torch.fx.GraphModule(root_module, self.graph)
313+
if self._memmap_external_data:
314+
module._onnx2fx_inference_only = True
302315
module.graph.lint()
303316
return module
304317

@@ -456,7 +469,19 @@ def extract_tensor_shape(
456469
def extract_tensor_dtype(value: onnx.ValueInfoProto) -> Optional[torch.dtype]:
    """Extract the Torch dtype that corresponds to a value info."""

    elem_type = value.type.tensor_type.elem_type
    # elem_type 0 is TensorProto.UNDEFINED: no dtype information present.
    if not elem_type:
        return None
    mapped = DTYPE_MAP.get(elem_type)
    if mapped is not None:
        return mapped
    # Strings have no torch equivalent; treat them as "no dtype" rather
    # than an error so string-valued value_infos do not abort conversion.
    if elem_type == onnx.TensorProto.STRING:
        return None
    raise UnsupportedDTypeError(
        onnx_dtype=elem_type,
        tensor_name=value.name,
        details="value_info dtype not supported",
    )
460485

461486
info_map = {}
462487
for value_info in (
@@ -501,10 +526,39 @@ def _create_initializer_map(self) -> Dict[str, torch.Tensor]:
501526
"""Build a mapping from initializer names to PyTorch tensors."""
502527
init_map = {}
503528
for initializer in self.model.graph.initializer:
504-
np_array = numpy_helper.to_array(initializer)
505-
init_map[initializer.name] = torch.from_numpy(np_array.copy())
529+
init_map[initializer.name] = self.load_tensor(initializer)
506530
return init_map
507531

532+
def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor:
    """Load an ONNX TensorProto into a Torch tensor.

    Raises
    ------
    UnsupportedDTypeError
        If the tensor's ONNX dtype has no torch equivalent.
    """
    dtype = tensor.data_type
    if DTYPE_MAP.get(dtype) is None:
        raise UnsupportedDTypeError(
            onnx_dtype=dtype,
            tensor_name=tensor.name or "<unnamed>",
            details="initializer dtype not supported",
        )

    is_external = (
        tensor.data_location == onnx.TensorProto.EXTERNAL or bool(tensor.external_data)
    )
    if self._memmap_external_data and is_external:
        # Map the raw bytes straight from disk instead of copying them into
        # RAM; the resulting tensor aliases the (read-only) file mapping.
        info = resolve_external_data(tensor, base_dir=self._base_dir, strict=True)
        mapped = np.memmap(
            info.path,
            dtype=info.numpy_dtype,
            mode="r",
            offset=info.offset,
            shape=info.shape,
        )
        return torch.from_numpy(mapped)

    # In-memory tensor: copy so the result does not alias protobuf storage.
    return torch.from_numpy(numpy_helper.to_array(tensor).copy())
561+
508562
def _load_initializers(self) -> None:
509563
"""Load ONNX initializers as constant nodes in the FX graph."""
510564
for name, tensor in self.initializer_map.items():

src/onnx2fx/ops/control_flow.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
This module implements ONNX control flow operators like Loop and If.
55
"""
66

7-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
7+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
88

99
import onnx
1010
import torch
@@ -70,6 +70,7 @@ def _build_subgraph_module(
7070
parent_env: Dict[str, torch.fx.Node],
7171
parent_opset_versions: Dict[str, int],
7272
parent_type_info: Optional[Dict[str, bool]] = None,
73+
tensor_loader: Optional[Callable[[onnx.TensorProto], torch.Tensor]] = None,
7374
) -> Tuple[torch.fx.GraphModule, List[str], List[str], List[str]]:
7475
"""Build an FX GraphModule from an ONNX subgraph.
7576
@@ -106,8 +107,11 @@ def _build_subgraph_module(
106107
# Load initializers from subgraph
107108
initializer_map: Dict[str, torch.Tensor] = {}
108109
for initializer in body_graph.initializer:
109-
np_array = numpy_helper.to_array(initializer)
110-
initializer_map[initializer.name] = torch.from_numpy(np_array.copy())
110+
if tensor_loader is not None:
111+
initializer_map[initializer.name] = tensor_loader(initializer)
112+
else:
113+
np_array = numpy_helper.to_array(initializer)
114+
initializer_map[initializer.name] = torch.from_numpy(np_array.copy())
111115

112116
# Register initializers as constants
113117
for name, tensor in initializer_map.items():
@@ -159,9 +163,16 @@ def __init__(self):
159163
self.initializer_map = initializer_map
160164
self._body_graph = body_graph
161165
self._parent_type_info = parent_type_info
166+
self._tensor_loader = tensor_loader
162167
# Build type info for this subgraph (to pass to nested subgraphs)
163168
self._type_info = self._build_type_info()
164169

170+
def load_tensor(self, tensor: onnx.TensorProto) -> torch.Tensor:
    # Delegate to the parent builder's loader (which understands memmapped
    # external data) when one was provided; otherwise copy out of protobuf.
    loader = self._tensor_loader
    if loader is None:
        return torch.from_numpy(numpy_helper.to_array(tensor).copy())
    return loader(tensor)
175+
165176
def _build_type_info(self) -> Dict[str, bool]:
166177
"""Build a mapping of value names to whether they are optional types."""
167178
info: Dict[str, bool] = {}
@@ -437,7 +448,11 @@ def loop_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
437448
# Build subgraph module
438449
body_module, body_input_names, body_output_names, outer_refs = (
439450
_build_subgraph_module(
440-
body_graph, builder.env, builder._opset_versions, parent_type_info
451+
body_graph,
452+
builder.env,
453+
builder._opset_versions,
454+
parent_type_info,
455+
tensor_loader=builder.load_tensor,
441456
)
442457
)
443458

@@ -628,7 +643,11 @@ def scan_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
628643
# Build subgraph module
629644
body_module, body_input_names, body_output_names, outer_refs = (
630645
_build_subgraph_module(
631-
body_graph, builder.env, builder._opset_versions, parent_type_info
646+
body_graph,
647+
builder.env,
648+
builder._opset_versions,
649+
parent_type_info,
650+
tensor_loader=builder.load_tensor,
632651
)
633652
)
634653

@@ -712,7 +731,11 @@ def scan_op_v8(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
712731
# Build subgraph module
713732
body_module, body_input_names, body_output_names, outer_refs = (
714733
_build_subgraph_module(
715-
body_graph, builder.env, builder._opset_versions, parent_type_info
734+
body_graph,
735+
builder.env,
736+
builder._opset_versions,
737+
parent_type_info,
738+
tensor_loader=builder.load_tensor,
716739
)
717740
)
718741

@@ -884,12 +907,20 @@ def if_op(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
884907
# Build subgraph modules for both branches
885908
then_module, then_input_names, then_output_names, then_outer_refs = (
886909
_build_subgraph_module(
887-
then_graph, builder.env, builder._opset_versions, parent_type_info
910+
then_graph,
911+
builder.env,
912+
builder._opset_versions,
913+
parent_type_info,
914+
tensor_loader=builder.load_tensor,
888915
)
889916
)
890917
else_module, else_input_names, else_output_names, else_outer_refs = (
891918
_build_subgraph_module(
892-
else_graph, builder.env, builder._opset_versions, parent_type_info
919+
else_graph,
920+
builder.env,
921+
builder._opset_versions,
922+
parent_type_info,
923+
tensor_loader=builder.load_tensor,
893924
)
894925
)
895926

src/onnx2fx/ops/tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@register("Constant")
2525
def constant(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
2626
"""Create a constant tensor."""
27-
value = get_attribute(node, "value")
27+
value = get_attribute(node, "value", tensor_loader=builder.load_tensor)
2828
if value is None:
2929
value_float = get_attribute(node, "value_float")
3030
if value_float is not None:
@@ -781,7 +781,7 @@ def size(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
781781
def constant_of_shape(builder: "GraphBuilder", node: onnx.NodeProto) -> torch.fx.Node:
782782
"""Create tensor filled with constant value."""
783783
shape = builder.get_value(node.input[0])
784-
value = get_attribute(node, "value")
784+
value = get_attribute(node, "value", tensor_loader=builder.load_tensor)
785785

786786
if value is not None:
787787
fill_value = (

0 commit comments

Comments
 (0)