|
| 1 | +# Copyright 2023 The IREE Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +# See https://llvm.org/LICENSE.txt for license information. |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +import copy |
| 8 | +import random |
| 9 | +import string |
| 10 | +import iree.runtime as rt |
| 11 | + |
| 12 | +from ...dialects import util |
| 13 | +from typing import Optional, Tuple, Any |
| 14 | + |
try:
    import onnx
except ModuleNotFoundError as e:
    # `sys` is only needed to build the hint below and is not imported at the
    # top of this module, so import it here. Without it, formatting the
    # message would raise NameError and hide the real problem.
    import sys

    raise ModuleNotFoundError(
        f"iree-import-onnx requires that the `onnx` Python package is installed "
        f"(typically `{sys.executable} -m pip install onnx`)"
    ) from e
| 22 | + |
| 23 | +try: |
| 24 | + from ...extras import onnx_importer |
| 25 | +except ModuleNotFoundError as e: |
| 26 | + raise ModuleNotFoundError( |
| 27 | + "iree-import-onnx is only available if IREE was built with Torch support" |
| 28 | + ) from e |
| 29 | + |
| 30 | +from onnx import numpy_helper |
| 31 | + |
| 32 | +from ...ir import ( |
| 33 | + Context, |
| 34 | + Type as IrType, |
| 35 | + TypeAttr, |
| 36 | + RankedTensorType, |
| 37 | + StringAttr, |
| 38 | + Attribute, |
| 39 | + Operation, |
| 40 | + Location, |
| 41 | + InsertionPoint, |
| 42 | + Value, |
| 43 | + SymbolTable, |
| 44 | + IntegerType, |
| 45 | +) |
| 46 | + |
| 47 | + |
class IREENodeImporter(onnx_importer.NodeImporter):
    """ONNX node importer that externalizes large initializers as parameters.

    Initializers with at least ``num_elements_threshold`` elements are emitted
    as private ``util.global`` ops whose initial value is a
    ``#stream.parameter.named`` reference, and their data is recorded in
    ``param_archive``. Smaller initializers are delegated to the base importer.
    """

    def __init__(
        self,
        graph_info: onnx_importer.GraphInfo,
        *,
        parent_op: Operation,
        block: onnx_importer.Block,
        context_cache: "onnx_importer.ContextCache",
        module_op: Operation,
        module_cache: "onnx_importer.ModuleCache",
        num_elements_threshold: int,
        params_scope: str,
    ):
        """Initialize the importer.

        Args:
            graph_info: Graph description forwarded to the base importer.
            parent_op: Operation being populated; it is also inserted into
                this importer's symbol table.
            block: Block forwarded to the base importer.
            context_cache: Per-context cache forwarded to the base importer.
            module_op: Module operation that receives the created globals.
            module_cache: Per-module cache forwarded to the base importer.
            num_elements_threshold: Initializers with fewer elements than this
                are imported by the base class instead of being externalized.
            params_scope: Scope string used in the generated
                ``#stream.parameter.named`` attributes.
        """
        super().__init__(
            graph_info,
            parent_op=parent_op,
            block=block,
            context_cache=context_cache,
            module_op=module_op,
            module_cache=module_cache,
        )
        # Most recently created global; used to keep declaration order.
        self.last_global_op = None
        self.symbol_table = SymbolTable(module_op)
        self.symbol_table.insert(parent_op)
        self.num_elements_threshold = num_elements_threshold
        self.param_archive = rt.ParameterIndex()
        self.params_scope = params_scope

    def sanitize_name(self, name: str) -> str:
        """Return a version of ``name`` usable as a symbol name.

        There are often some initializers in the models that have no name
        labels, or contain substrings like '::', which can cause conflicts
        and invalid symbol names for symbolic references. Every ':' is
        replaced with '_'. An empty name is replaced with a random placeholder
        of the form ``<number in 1..999>__<lowercase letter>``.
        """
        new_name = name.replace(":", "_")
        if not new_name:
            ch = random.choice(string.ascii_lowercase)
            new_name = str(random.randrange(1, 1000)) + "__" + ch
        return new_name

    def _named_parameter_attr(self, key: str, tensor_type: IrType) -> Attribute:
        """Build the ``#stream.parameter.named`` attribute for ``key``.

        Relies on the ambient MLIR context, so it must be called inside an
        active context manager.
        """
        scope_attr = StringAttr.get(self.params_scope)
        key_attr = StringAttr.get(key)
        return Attribute.parse(
            f"#stream.parameter.named<{scope_attr}::{key_attr}> : {tensor_type}"
        )

    def create_tensor_global(
        self,
        t: onnx.TensorProto,
    ) -> Tuple[str, IrType]:
        """Create a private ``util.global`` backed by a named parameter.

        Args:
            t: Tensor proto supplying the name, dims and element type.

        Returns:
            A tuple of the final symbol name of the global (after symbol table
            insertion) and the signless builtin tensor type of the global.
        """
        # Always create globals at the top. Then after created, if there was
        # a prior one, move the new one to after it to maintain declaration
        # order.
        name = self.sanitize_name(t.name)
        with InsertionPoint.at_block_begin(
            self._m.regions[0].blocks[0]
        ), Location.unknown():
            # After lowering to linalg-on-tensors, the data type needs to be
            # signless. So, we construct the globals to have signless types,
            # and use torch_c.from_builtin_tensor to convert to the correct
            # frontend type.
            tensor_type = RankedTensorType.get(
                tuple(t.dims), ELEM_TYPE_TO_SIGNLESS_IR_TYPE[t.data_type]()
            )
            global_op = util.GlobalOp(
                StringAttr.get(name),
                TypeAttr.get(tensor_type),
                sym_visibility=StringAttr.get("private"),
                initial_value=self._named_parameter_attr(name, tensor_type),
            )
            self.symbol_table.insert(global_op)
            actual_symbol_name = StringAttr(global_op.attributes["sym_name"]).value
            if actual_symbol_name != name:
                # The symbol table renamed the global to resolve a collision.
                # The parameter data is archived under the final symbol name,
                # so the parameter reference has to use the same key.
                global_op.attributes["initial_value"] = self._named_parameter_attr(
                    actual_symbol_name, tensor_type
                )
            if self.last_global_op is not None:
                global_op.move_after(self.last_global_op)
            self.last_global_op = global_op
        return actual_symbol_name, tensor_type

    @classmethod
    def define_function(
        cls,
        graph_info: onnx_importer.GraphInfo,
        module_op: Operation,
        num_elements_threshold: int,
        params_scope: str,
        context_cache: Optional["onnx_importer.ContextCache"] = None,
        module_cache: Optional["onnx_importer.ModuleCache"] = None,
        private: bool = False,
    ) -> "IREENodeImporter":
        """Create a function for the graph and return an importer bound to it.

        Args:
            graph_info: Graph whose inputs and outputs define the signature.
            module_op: Module operation that receives the new function.
            num_elements_threshold: Forwarded to the importer constructor.
            params_scope: Forwarded to the importer constructor.
            context_cache: Existing per-context cache to reuse; a new one is
                created from ``module_op.context`` when omitted.
            module_cache: Existing per-module cache to reuse; a new one is
                created when omitted.
            private: When true, the function gets "private" visibility.

        Returns:
            An importer whose name-to-value map is seeded with the entry block
            arguments of the new function.
        """
        # Recover per-context caches of various attributes.
        # Allows modifications in the same context without
        # loss of current state.
        cc = (
            context_cache
            if context_cache is not None
            else onnx_importer.ContextCache(module_op.context)
        )
        # Recover per-module caches of various attributes.
        # Allows modification in the same module_op without
        # loss of current state.
        mc = (
            module_cache
            if module_cache is not None
            else onnx_importer.ModuleCache(module_op, cc)
        )
        with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"):
            body = module_op.regions[0].blocks[0]
            func_name = graph_info.graph_proto.name
            input_types = [
                cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values()
            ]
            output_types = [
                cc.type_proto_to_type(out.type)
                for out in graph_info.output_map.values()
            ]
            ftype = onnx_importer.FunctionType.get(input_types, output_types)
            func_op = onnx_importer.func_dialect.FuncOp(
                func_name,
                ftype,
                ip=InsertionPoint(body),
                visibility="private" if private else None,
            )
            block = func_op.add_entry_block(
                [Location.name(k) for k in graph_info.input_map.keys()]
            )
        # Instantiate through `cls` so subclasses get instances of themselves.
        imp = cls(
            graph_info,
            parent_op=func_op,
            block=block,
            context_cache=cc,
            module_op=module_op,
            module_cache=mc,
            num_elements_threshold=num_elements_threshold,
            params_scope=params_scope,
        )
        for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
            imp._nv_map[node_name] = input_value
        imp._populate_graph_attrs(func_op)
        return imp

    def import_initializer(
        self, initializer: onnx.TensorProto, extern_name: Optional[str] = None
    ) -> Value:
        """Import an initializer, externalizing it when it is large enough.

        Args:
            initializer: Tensor proto to import.
            extern_name: Name to register the value under; falls back to the
                name stored in the tensor proto when empty or omitted.

        Returns:
            The value registered for the initializer in the name-to-value map.
        """
        # If an explicitly specified name is given, use that; otherwise, pick
        # up the name from the tensor proto itself
        initializer_name = extern_name if extern_name else initializer.name
        # An initializer without dims counts as a single element.
        num_elements = 1
        for dim in initializer.dims:
            num_elements *= dim
        if num_elements < self.num_elements_threshold:
            # NOTE(review): extern_name is not forwarded to the base
            # implementation here; confirm whether that is intended.
            imported_tensor = super().import_initializer(initializer)
            self._nv_map[initializer_name] = imported_tensor
            return imported_tensor

        actual_symbol_name, tensor_type = self.create_tensor_global(initializer)
        vtensor_type = self._cc.get_vtensor_type(
            tuple(initializer.dims), self._cc.tensor_element_type(initializer.data_type)
        )

        with InsertionPoint(self._b), Location.name(initializer_name):
            # Load the signless global, then convert it to the frontend tensor
            # type expected by the rest of the imported graph.
            old_op = util.GlobalLoadOp(tensor_type, actual_symbol_name)
            converted_value = Operation.create(
                "torch_c.from_builtin_tensor",
                results=[vtensor_type],
                operands=[old_op.result],
            ).result

        self._nv_map[initializer_name] = converted_value
        # Archive the data under the same key the global's parameter
        # reference uses.
        tensor_as_array = numpy_helper.to_array(initializer)
        self.param_archive.add_buffer(actual_symbol_name, tensor_as_array)
        return converted_value
| 230 | + |
| 231 | + |
def _signless_int_type_cb(bit_width: int):
    """Return a zero-argument callback creating a signless integer type."""
    return lambda: IntegerType.get_signless(bit_width)


# Element-type table used for globals: starts as a copy of the importer's
# table, with every signed and unsigned integer entry replaced by a signless
# integer type of the same bit width.
ELEM_TYPE_TO_SIGNLESS_IR_TYPE = copy.deepcopy(onnx_importer.ELEM_TYPE_TO_IR_TYPE_CB)
ELEM_TYPE_TO_SIGNLESS_IR_TYPE.update(
    {
        onnx.TensorProto.DataType.INT64: _signless_int_type_cb(64),
        onnx.TensorProto.DataType.INT32: _signless_int_type_cb(32),
        onnx.TensorProto.DataType.INT16: _signless_int_type_cb(16),
        onnx.TensorProto.DataType.INT8: _signless_int_type_cb(8),
        onnx.TensorProto.DataType.INT4: _signless_int_type_cb(4),
        onnx.TensorProto.DataType.UINT8: _signless_int_type_cb(8),
        onnx.TensorProto.DataType.UINT4: _signless_int_type_cb(4),
        onnx.TensorProto.DataType.UINT16: _signless_int_type_cb(16),
        onnx.TensorProto.DataType.UINT64: _signless_int_type_cb(64),
        onnx.TensorProto.DataType.UINT32: _signless_int_type_cb(32),
    }
)
0 commit comments