diff --git a/backends/xnnpack/_passes/TARGETS b/backends/xnnpack/_passes/TARGETS
index a199e1aab01..972980570ec 100644
--- a/backends/xnnpack/_passes/TARGETS
+++ b/backends/xnnpack/_passes/TARGETS
@@ -19,5 +19,6 @@ python_library(
         "//executorch/exir/passes:const_prop_pass",
         "//executorch/exir/passes:memory_format_ops_pass",
         "//executorch/exir/program:program",
+        "//executorch/backends/transforms:utils",
     ],
 )
diff --git a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
index b0f4779eb4c..6f31fe698ba 100644
--- a/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
+++ b/backends/xnnpack/_passes/fuse_batch_norm_with_conv.py
@@ -7,13 +7,22 @@
 import operator
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
-from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
+from executorch.backends.xnnpack.utils.utils import (
+    get_param_tensor,
+    get_tensor_name,
+    is_param_node,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
+from torch.export.graph_signature import InputKind
 
 from torch.nn.utils.fusion import fuse_conv_bn_weights
@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
 
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        counter = 0
+        constant_placeholders_to_delete = set()
         for conv in graph.nodes:
             # We want to discover a chain of conv -> batch_norm.
             # Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@
             assert len(conv.args) == 9
 
             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
             assert conv_weight is not None
 
             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
+            conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])
 
             # Get the parameters from the batchnorm op
             assert (
@@ -95,23 +106,43 @@
                 bn_bias,
                 is_transpose,
             )
+            fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
+            if conv_bias_name == "":
+                fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
+                    ".", "_"
+                )
+            else:
+                fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")
 
             # Modify the graph by updating the weight and bias of conv op
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the conv op.
-            with graph.inserting_before(conv):
-                fused_weight_name = f"_fused_with_bn_weight_{counter}"
-                graph_module.register_parameter(fused_weight_name, fused_weight)
-                fused_weight_node = graph.get_attr(fused_weight_name)
-                fused_bias_name = f"_fused_with_bn_bias_{counter}"
-                graph_module.register_parameter(fused_bias_name, fused_bias)
-                fused_bias_node = graph.get_attr(fused_bias_name)
-
-                # Update the weight and bias of conv op
-                conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
-                conv_args[1] = fused_weight_node
-                conv_args[2] = fused_bias_node
-                conv.args = tuple(conv_args)
+            with graph.inserting_before(conv.args[1]):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=fused_weight_name,
+                    data=fused_weight,
+                )
+                if fused_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=fused_bias_name,
+                        data=fused_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )
+
             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
                 assert user.target == operator.getitem
@@ -119,8 +150,13 @@
                 graph.erase_node(user)
 
             graph.erase_node(bn)
+            constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])
 
-            counter += 1
+        if len(constant_placeholders_to_delete) > 0:
+            graph_module.graph.eliminate_dead_code()
+            for node in constant_placeholders_to_delete:
+                if (node is not None) and (len(node.users) == 0):
+                    delete_constant_placeholder(self.exported_program, node)
 
         graph_module.recompile()
         # To Regenerate meta data and shape information, retrace module
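For context, `fuse_conv_bn_weights` folds the batch-norm affine transform into the convolution's weight and bias. A minimal sketch of the arithmetic it performs, assuming a standard (non-transposed) NCHW convolution; the function name and shapes here are illustrative, not part of this diff:

import torch

def fuse_conv_bn_reference(conv_w, conv_b, bn_mean, bn_var, bn_eps, bn_w, bn_b):
    # Batch norm computes y = (x - mean) / sqrt(var + eps) * gamma + beta.
    # Folding it into the conv rescales each output channel's weights and
    # shifts the bias, so the conv alone reproduces conv followed by BN.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_mean)
    scale = bn_w / torch.sqrt(bn_var + bn_eps)       # per-output-channel scale
    fused_w = conv_w * scale.reshape(-1, 1, 1, 1)    # broadcast over (C_in, kH, kW)
    fused_b = (conv_b - bn_mean) * scale + bn_b
    return fused_w, fused_b

This is also why the pass must create fresh constant placeholders: the fused tensors are new values with new names, and the old weight, bias, and BN statistics become deletion candidates.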
diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index 0a825a94bef..ec39d287346 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -34,11 +34,16 @@
     check_or_raise,
     get_input_node,
     get_param_tensor,
+    get_tensor_name,
     is_param_node,
     PERM_NCHW_TO_NHWC,
 )
 
-from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
+from executorch.backends.xnnpack.utils.xnnpack_constants import (
+    UINT64_MAX,
+    XNN_INVALID_VALUE_ID,
+)
+from executorch.exir._serialize._named_data_store import NamedDataStore
 from torch.export import ExportedProgram
 
 XNN_TYPE_MAP = {
@@ -46,8 +51,6 @@
 }
 
 from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
-    _aligned_size,
-    _pad_to,
     CONSTANT_TENSOR_ALIGNMENT,
 )
 
@@ -86,11 +89,11 @@ def __init__(
         self,
         exported_program: ExportedProgram,
         external_ids: Dict,
-        constant_data_bytes: bytearray,
+        named_data_store: NamedDataStore,
    ) -> None:
         self._external_ids = external_ids or {}
         self._exported_program = exported_program or None
-        self._constant_data_bytes = constant_data_bytes
+        self._named_data_store = named_data_store
 
     @property
     def external_ids(self) -> Dict:
@@ -579,11 +582,16 @@ def get_serialized_buffer_index(
             ctypes.POINTER(array_type),
         ).contents
 
-        offset = len(self._constant_data_bytes)
+        named_key = get_tensor_name(self.exported_program, get_attr_node)
+        if named_key == "":
+            raise ValueError(f"Tensor from node: {get_attr_node} has no name")
+
         size = const_val.untyped_storage().nbytes()
-        xnn_graph.constant_data.append(ConstantDataOffset(offset=offset, size=size))
-        self._constant_data_bytes.extend(
-            _pad_to(bytes(array), _aligned_size(size, CONSTANT_TENSOR_ALIGNMENT))
+        xnn_graph.constant_data.append(
+            ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
+        )
+        self._named_data_store.add_named_data(
+            named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
         )
 
         return buffer_idx
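The hunk above swaps inline constant packing for the named data store. As a rough before/after sketch (the helper name here is ours; only `add_named_data` and `ConstantDataOffset` come from the diff), the old scheme appended padded bytes to one shared blob and recorded the byte offset:

def append_constant_inline(blob: bytearray, data: bytes, alignment: int) -> int:
    # Old scheme (illustrative): pad each tensor to the alignment boundary
    # and address it by its byte offset into a single shared blob.
    offset = len(blob)
    padded_len = len(data) + (-len(data)) % alignment
    blob.extend(data.ljust(padded_len, b"\x00"))
    return offset

In the new scheme the bytes live in the NamedDataStore under the tensor's canonical name, and the flatbuffer entry carries only the lookup key and size, with the offset set to the UINT64_MAX sentinel.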
diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs
index 5a43481b98d..193656c30b1 100644
--- a/backends/xnnpack/serialization/schema.fbs
+++ b/backends/xnnpack/serialization/schema.fbs
@@ -316,11 +316,20 @@ table XNNLeakyReLU {
 table ConstantDataOffset {
   // Constant data offsets are relative to the constant data base offset provided
   // in the XNNPACKHeader.
+  // named_key and offset are mutually exclusive, meaning only one of them is valid.
+  // If named_key is a non-empty string, then offset must be UINT64_MAX.
+  // If offset is not UINT64_MAX, then named_key must be an empty string.
   offset: uint64;
 
   // The size in bytes of valid data starting at the offset. The constant data
   // may be followed by padding before the next piece of constant data
   size: uint64;
+
+  // Unique string id used to query the offset from the named data store.
+  // named_key and offset are mutually exclusive, meaning only one of them is valid.
+  // If named_key is a non-empty string, then offset must be UINT64_MAX.
+  // If offset is not UINT64_MAX, then named_key must be an empty string.
+  named_key: string;
 }
 
 table XNNGraph {
diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py
index 3276dac7869..3cb572c66ef 100644
--- a/backends/xnnpack/serialization/xnnpack_graph_schema.py
+++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py
@@ -470,6 +470,7 @@ class XValue:
 class ConstantDataOffset:
     offset: int
     size: int
+    named_key: str = ""
 
 
 @dataclass
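The schema comments above describe a mutual-exclusion invariant between the two addressing modes. A hypothetical checker makes it concrete (this validator is our own sketch, not part of the serializer):

UINT64_MAX = 2**64 - 1  # sentinel meaning "resolve the data via named_key instead"

def validate_constant_data_offset(offset: int, named_key: str) -> None:
    # Exactly one addressing mode is allowed: a named-data key (offset must
    # then be the UINT64_MAX sentinel) or a raw offset (named_key must then
    # be empty).
    if named_key and offset != UINT64_MAX:
        raise ValueError("named_key is set, so offset must be UINT64_MAX")
    if not named_key and offset == UINT64_MAX:
        raise ValueError("offset is the sentinel, so named_key must be set")

Defaulting `named_key` to `""` in the Python dataclass keeps older offset-addressed payloads valid under the same schema.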
diff --git a/backends/xnnpack/utils/gen_xnnpack_constants.sh b/backends/xnnpack/utils/gen_xnnpack_constants.sh
index 6be9d4519f3..5fa92e5b038 100644
--- a/backends/xnnpack/utils/gen_xnnpack_constants.sh
+++ b/backends/xnnpack/utils/gen_xnnpack_constants.sh
@@ -26,5 +26,6 @@
 } > xnnpack_constants.py
 
 echo UINT32_MAX = 4294967295 >> xnnpack_constants.py
+echo UINT64_MAX = 18446744073709551615 >> xnnpack_constants.py
 awk '/^#define\s+XNN_/ { print $2,"=",$3} ' "$1"/include/xnnpack.h >> xnnpack_constants.py
 if ! grep -qc "^XNN_" xnnpack_constants.py; then false; fi
diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py
index b802d73c16b..fab95618807 100644
--- a/backends/xnnpack/utils/utils.py
+++ b/backends/xnnpack/utils/utils.py
@@ -131,6 +131,20 @@ def get_param_tensor(
     raise RuntimeError(f"unsupported param type, {node.op}.")
 
 
+def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str:
+    if node is None:
+        return ""
+    if is_param(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_parameters[node.name]
+    elif is_buffer(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_buffers[node.name]
+    elif is_lifted_tensor_constant(exp_prog, node):
+        return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
+    else:
+        assert isinstance(node.target, str)
+        return node.target
+
+
 def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
     """
     Returns the source fn of the given node, return None if something goes wrong
diff --git a/backends/xnnpack/utils/xnnpack_constants.py b/backends/xnnpack/utils/xnnpack_constants.py
index 351cc8ad897..364819a2435 100644
--- a/backends/xnnpack/utils/xnnpack_constants.py
+++ b/backends/xnnpack/utils/xnnpack_constants.py
@@ -6,8 +6,11 @@
 
 # Auto-generated by gen_xnnpack_constants.sh script. Do not modify
 UINT32_MAX = 4294967295
+UINT64_MAX = 18446744073709551615
+XNN_EXTRA_BYTES = 128
 XNN_EXTRA_BYTES = 16
 XNN_MAX_TENSOR_DIMS = 6
+XNN_INVALID_VALUE_ID = UINT32_MAX
 XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001
 XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002
 XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004
@@ -26,7 +29,8 @@
 XNN_FLAG_YIELD_WORKERS = 0x00000010
 XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020
 XNN_FLAG_KEEP_DIMS = 0x00000040
-XNN_EXTRA_QUANTIZATION_PARAMS = 8
+XNN_EXTRA_QUANTIZATION_PARAMS = 10
+XNN_MIN_BLOCKSIZE = 32
 XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001
 XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002
 XNN_VALUE_FLAG_PERSISTENT = 0x00000004
diff --git a/backends/xnnpack/xnnpack_preprocess.py b/backends/xnnpack/xnnpack_preprocess.py
index 4548de4940a..84cdfd69a48 100644
--- a/backends/xnnpack/xnnpack_preprocess.py
+++ b/backends/xnnpack/xnnpack_preprocess.py
@@ -31,6 +31,7 @@
     XNN_VALUE_FLAG_EXTERNAL_INPUT,
     XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
 )
+from executorch.exir._serialize._named_data_store import NamedDataStore
 
 from executorch.exir.backend.backend_details import (
     BackendDetails,
@@ -103,7 +104,7 @@ def preprocess(
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> PreprocessResult:
-
+        named_data_store = NamedDataStore()
         xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()
 
         # Need to wrap EP here because xnnpack does addmm to linear
@@ -162,7 +163,7 @@
         )
 
         constant_data_bytes = bytearray()
-        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)
+        node_visitors = get_node_visitors(ep, node_to_external_map, named_data_store)
 
         for node in graph_module.graph.nodes:
             if node.op == "call_function":
@@ -191,4 +192,5 @@
                 xnnpack_graph, constant_data_bytes
             ),
             debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
         )
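To see where the named keys come from, note that `get_tensor_name` simply reads the exported program's graph signature. A small, self-contained demonstration of that mapping (the model and printed names are hypothetical, and the exact placeholder spelling varies by torch version):

import torch
from torch.export import export

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

ep = export(Tiny(), (torch.randn(1, 3, 16, 16),))
# Each lifted parameter placeholder maps back to its canonical FQN, which
# is what get_tensor_name returns and what keys the NamedDataStore.
for placeholder, fqn in ep.graph_signature.inputs_to_parameters.items():
    print(placeholder, "->", fqn)  # e.g. "p_conv_weight -> conv.weight"

With this in place, `preprocess` leaves `constant_data_bytes` empty, and the payload travels out of the backend through `data_store_output` instead of the inline constant segment.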