
Commit d497571

[onnx][importer] Add support for externalized params (#18880)
This patch adds support for externalizing parameters and storing them at the given path as an IRPA file. IR imported with externalization should avoid the OOM errors that large inlined parameters can otherwise cause.
1 parent 34d9d5f commit d497571
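
With the new flags, a typical import that externalizes parameters might look like the following (file and scope names here are hypothetical):

    python -m iree.compiler.tools.import_onnx model.onnx -o model.mlir \
        --externalize-params --num-elements-threshold 100 --params-scope model

Since --save-params-to is not given, the parameters would be written to model_params.irpa next to the output file.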

File tree

5 files changed

+400 −19 lines changed

compiler/bindings/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ declare_mlir_python_sources(IREECompilerAPIPythonTools
   tools/tf.py
   tools/tflite.py
   tools/import_onnx/__main__.py
+  tools/import_onnx/importer_externalization_overrides.py
   tools/ir_tool/__main__.py
   tools/scripts/iree_compile/__main__.py
   tools/scripts/iree_opt/__main__.py

compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py

Lines changed: 57 additions & 19 deletions
@@ -14,42 +14,43 @@
 
 python -m iree.compiler.tools.import_onnx ...
 """
+
 import argparse
 import os
 from pathlib import Path
 import sys
 import tempfile
 
-try:
-    import onnx
-except ModuleNotFoundError as e:
-    raise ModuleNotFoundError(
-        f"iree-import-onnx requires that the `onnx` Python package is installed "
-        f"(typically `{sys.executable} -m pip install onnx`)"
-    ) from e
-
-try:
-    from ...extras import onnx_importer
-except ModuleNotFoundError as e:
-    raise ModuleNotFoundError(
-        "iree-import-onnx is only available if IREE was built with Torch support"
-    ) from e
-
-from ...ir import (
-    Context,
-)
+from .importer_externalization_overrides import *
 
 
 def main(args: argparse.Namespace):
     model_proto = load_onnx_model(args)
     context = Context()
     model_info = onnx_importer.ModelInfo(model_proto)
     m = model_info.create_module(context=context).operation
-    imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
+
+    imp: Any = None
+    if args.externalize_params:
+        imp = IREENodeImporter.define_function(
+            model_info.main_graph, m, args.num_elements_threshold, args.params_scope
+        )
+    else:
+        imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
     imp.import_all()
+
     if not args.no_verify:
         m.verify()
 
+    if args.externalize_params:
+        default_param_path = Path(args.output_file).parent / Path(args.output_file).stem
+        param_path = (
+            (str(default_param_path) + "_params.irpa")
+            if args.save_params_to is None
+            else str(args.save_params_to)
+        )
+        imp.param_archive.create_archive_file(param_path)
+
     # TODO: This isn't very efficient output. If these files ever
     # get large, enable bytecode and direct binary emission to save
     # some copies.
@@ -71,6 +72,12 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
         raw_model = onnx.load(args.input_file, load_external_data=False)
         onnx.load_external_data_for_model(raw_model, str(args.data_dir))
 
+    # Only change the opset version if it is greater than the current one.
+    if args.opset_version and args.opset_version > raw_model.opset_import[0].version:
+        raw_model = onnx.version_converter.convert_version(
+            raw_model, args.opset_version
+        )
+
     # Do shape inference two ways. First, attempt in-memory to avoid redundant
     # loading and the need for writing a temporary file somewhere. If that
     # fails, typically because of the 2 GB protobuf size limit, try again via
@@ -132,6 +139,37 @@ def parse_arguments(argv=None) -> argparse.Namespace:
         " Defaults to the directory of the input file.",
         type=Path,
     )
+    parser.add_argument(
+        "--opset-version",
+        help="Allows specification of a newer opset_version to update the model"
+        " to before importing to MLIR. This can sometimes assist with shape inference.",
+        type=int,
+    )
+    parser.add_argument(
+        "--num-elements-threshold",
+        help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.",
+        type=int,
+        default=100,
+    )
+    parser.add_argument(
+        "--externalize-params",
+        help="Externalize large parameters and store them on the disk, to load at runtime.",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+    )
+    parser.add_argument(
+        "--save-params-to",
+        help="Location to save the externalized parameters. When not set, the parameters will be written to '<output_file_name>_params.irpa'"
+        " under the namespace 'model', which can be configured by passing the namespace string to 'params-scope'.",
+        default=None,
+        type=Path,
+    )
+    parser.add_argument(
+        "--params-scope",
+        help="The namespace or the scope in which the externalized parameters are placed. Default is 'model'.",
+        type=str,
+        default="model",
+    )
     args = parser.parse_args(argv)
     return args
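
For reference, the default parameter path in main() above is derived from the output file name with pathlib; a minimal sketch of that derivation, using a hypothetical output path:

    from pathlib import Path

    output_file = "build/model.mlir"  # hypothetical output path
    default_param_path = Path(output_file).parent / Path(output_file).stem
    # Fallback used when --save-params-to is unset:
    print(str(default_param_path) + "_params.irpa")  # build/model_params.irpa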

compiler/bindings/python/iree/compiler/tools/import_onnx/importer_externalization_overrides.py

Lines changed: 264 additions & 0 deletions

@@ -0,0 +1,264 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import copy
+import random
+import string
+import sys  # needed for the sys.executable reference below
+import iree.runtime as rt
+
+from ...dialects import util
+from typing import Optional, Tuple, Any
+
+try:
+    import onnx
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(
+        f"iree-import-onnx requires that the `onnx` Python package is installed "
+        f"(typically `{sys.executable} -m pip install onnx`)"
+    ) from e
+
+try:
+    from ...extras import onnx_importer
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(
+        "iree-import-onnx is only available if IREE was built with Torch support"
+    ) from e
+
+from onnx import numpy_helper
+
+from ...ir import (
+    Context,
+    Type as IrType,
+    TypeAttr,
+    RankedTensorType,
+    StringAttr,
+    Attribute,
+    Operation,
+    Location,
+    InsertionPoint,
+    Value,
+    SymbolTable,
+    IntegerType,
+)
+
+
+class IREENodeImporter(onnx_importer.NodeImporter):
+    def __init__(
+        self,
+        graph_info: onnx_importer.GraphInfo,
+        *,
+        parent_op: Operation,
+        block: onnx_importer.Block,
+        context_cache: "onnx_importer.ContextCache",
+        module_op: Operation,
+        module_cache: "onnx_importer.ModuleCache",
+        num_elements_threshold: int,
+        params_scope: str,
+    ):
+        super().__init__(
+            graph_info,
+            parent_op=parent_op,
+            block=block,
+            context_cache=context_cache,
+            module_op=module_op,
+            module_cache=module_cache,
+        )
+        self.last_global_op = None
+        self.symbol_table = SymbolTable(module_op)
+        self.symbol_table.insert(parent_op)
+        self.num_elements_threshold = num_elements_threshold
+        self.param_archive = rt.ParameterIndex()
+        self.params_scope = params_scope
+
+    def sanitize_name(self, name: str) -> str:
+        # There are often initializers in the models that have no name
+        # labels, or that contain substrings like '::', which can cause
+        # conflicts and invalid symbol names in symbolic references. This
+        # function replaces each ':' with '_' when the name is non-empty,
+        # and generates a random placeholder name when it is empty.
+        new_name: str = ""
+        for c in range(len(name)):
+            if name[c] == ":":
+                new_name += "_"
+            else:
+                new_name += name[c]
+
+        if len(new_name) == 0:
+            alpha = string.ascii_lowercase
+            ch = random.choice(alpha)
+            new_name = str(random.randrange(1, 1000)) + "__" + ch
+        return new_name
+
+    def create_tensor_global(
+        self,
+        t: onnx.TensorProto,
+    ) -> Tuple[str, IrType]:
+        # Always create globals at the top. Then after created, if there was
+        # a prior one, move the new one to after it to maintain declaration
+        # order.
+        name = self.sanitize_name(t.name)
+        with InsertionPoint.at_block_begin(
+            self._m.regions[0].blocks[0]
+        ), Location.unknown():
+            # After lowering to linalg-on-tensors, the data type needs to be signless.
+            # So, we construct the globals to have signless types, and use
+            # torch_c.from_builtin_tensor to convert to the correct frontend type.
+            vtensor_type = RankedTensorType.get(
+                tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]()
+            )
+            ir_attrs = {
+                "sym_name": StringAttr.get(name),
+                "sym_visibility": StringAttr.get("private"),
+                "type": TypeAttr.get(vtensor_type),
+            }
+
+            external_scope_attr = StringAttr.get(self.params_scope)
+            external_name_attr = StringAttr.get(name)
+            ir_attrs["initial_value"] = Attribute.parse(
+                f"#stream.parameter.named<{external_scope_attr}::{external_name_attr}> : {vtensor_type}"
+            )
+            global_op = util.GlobalOp(
+                ir_attrs["sym_name"],
+                ir_attrs["type"],
+                sym_visibility=ir_attrs["sym_visibility"],
+                initial_value=ir_attrs["initial_value"],
+            )
+            self.symbol_table.insert(global_op)
+            if self.last_global_op is not None:
+                global_op.move_after(self.last_global_op)
+            self.last_global_op = global_op
+            actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
+        return actual_symbol_name, vtensor_type
+
+    @classmethod
+    def define_function(
+        cls,
+        graph_info: onnx_importer.GraphInfo,
+        module_op: Operation,
+        num_elements_threshold: int,
+        params_scope: str,
+        context_cache: Optional["onnx_importer.ContextCache"] = None,
+        module_cache: Optional["onnx_importer.ModuleCache"] = None,
+        private: bool = False,
+    ) -> "IREENodeImporter":
+        # Recover per-context caches of various attributes.
+        # Allows modifications in the same context without
+        # loss of current state.
+        cc = (
+            context_cache
+            if context_cache is not None
+            else onnx_importer.ContextCache(module_op.context)
+        )
+        # Recover per-module caches of various attributes.
+        # Allows modification in the same module_op without
+        # loss of current state.
+        mc = (
+            module_cache
+            if module_cache is not None
+            else onnx_importer.ModuleCache(module_op, cc)
+        )
+        with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
+            body = module_op.regions[0].blocks[0]
+            func_name = graph_info.graph_proto.name
+            input_types = [
+                cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
+            ]
+            output_types = [
+                cc.type_proto_to_type(out.type)
+                for out in graph_info.output_map.values()
+            ]
+            ftype = onnx_importer.FunctionType.get(input_types, output_types)
+            func_op = onnx_importer.func_dialect.FuncOp(
+                func_name,
+                ftype,
+                ip=InsertionPoint(body),
+                visibility="private" if private else None,
+            )
+            block = func_op.add_entry_block(
+                [Location.name(k) for k in graph_info.input_map.keys()]
+            )
+        imp = IREENodeImporter(
+            graph_info,
+            parent_op=func_op,
+            block=block,
+            context_cache=cc,
+            module_op=module_op,
+            module_cache=mc,
+            num_elements_threshold=num_elements_threshold,
+            params_scope=params_scope,
+        )
+        for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
+            imp._nv_map[node_name] = input_value
+        imp._populate_graph_attrs(func_op)
+        return imp
+
+    def import_initializer(
+        self, initializer: onnx.TensorProto, extern_name: Optional[str] = None
+    ) -> Value:
+        # If an explicitly specified name is given, use that; otherwise, pick
+        # up the name from the tensor proto itself.
+        initializer_name = extern_name if extern_name else initializer.name
+        dims = list(initializer.dims)
+        num_elements = 1
+        for d in dims:
+            num_elements = num_elements * d
+        if num_elements < self.num_elements_threshold:
+            imported_tensor = super().import_initializer(initializer)
+            self._nv_map[initializer_name] = imported_tensor
+            return imported_tensor
+
+        actual_symbol_name, tensor_type = self.create_tensor_global(initializer)
+        vtensor_type = self._cc.get_vtensor_type(
+            tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type)
+        )
+
+        with InsertionPoint(self._b), Location.name(initializer_name):
+            old_op = util.GlobalLoadOp(tensor_type, actual_symbol_name)
+            converted_value = Operation.create(
+                "torch_c.from_builtin_tensor",
+                results=[vtensor_type],
+                operands=[old_op.result],
+            ).result
+
+        self._nv_map[initializer_name] = converted_value
+        tensor_as_array = numpy_helper.to_array(initializer)
+        self.param_archive.add_buffer(actual_symbol_name, tensor_as_array)
+        return converted_value
+
+
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB)
+
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.INT64
+] = lambda: IntegerType.get_signless(64)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.INT32
+] = lambda: IntegerType.get_signless(32)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.INT16
+] = lambda: IntegerType.get_signless(16)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.INT8
+] = lambda: IntegerType.get_signless(8)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.INT4
+] = lambda: IntegerType.get_signless(4)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.UINT8
+] = lambda: IntegerType.get_signless(8)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.UINT4
+] = lambda: IntegerType.get_signless(4)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.UINT16
+] = lambda: IntegerType.get_signless(16)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.UINT64
+] = lambda: IntegerType.get_signless(64)
+ELEM_TYPE_TO_SIGNLESS_IR_TYPE[
+    onnx.TensorProto.DataType.UINT32
+] = lambda: IntegerType.get_signless(32)
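
The ParameterIndex that IREENodeImporter fills via add_buffer() and flushes with create_archive_file() is part of the iree.runtime parameter API. A minimal sketch of the round trip, assuming the standard iree.runtime calls load() and create_provider() and hypothetical parameter names; this is illustrative, not part of the commit:

    import numpy as np
    import iree.runtime as rt

    # Build and persist an archive the same way the importer does.
    index = rt.ParameterIndex()
    index.add_buffer("my_weight", np.zeros((4, 4), dtype=np.float32))
    index.create_archive_file("model_params.irpa")

    # At runtime, load the archive back and expose it under the 'model'
    # scope so #stream.parameter.named<"model"::"my_weight"> can resolve.
    loaded = rt.ParameterIndex()
    loaded.load("model_params.irpa")
    provider = loaded.create_provider(scope="model")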
