From 34e009ffbc644ec99c94024e37f4a31761ce1008 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 13 Feb 2025 14:17:10 +0100 Subject: [PATCH] Improve error handling in ArmBackend Change asserts to exceptions, Print more information on failure. Signed-off-by: Erik Lundell Change-Id: If9fbf827458c2f8c35b998ea562d6047347ae10b --- .../tosa_supported_operators.py | 4 +- backends/arm/operators/op_slice.py | 6 +- backends/arm/process_node.py | 75 ++++++++++++++----- backends/arm/tosa_mapping.py | 25 ++++--- backends/arm/tosa_utils.py | 39 ++++++---- 5 files changed, 100 insertions(+), 49 deletions(-) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index dd092968764..2120a8870be 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -64,7 +64,9 @@ def get_registered_tosa_support_checks( ) -> list[Type[SupportedTOSAOperatorCheck]]: if tosa_spec not in _tosa_spec_support: - raise RuntimeError + raise RuntimeError( + f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}" + ) return _tosa_spec_support[tosa_spec] diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 7f4804af587..fe4f850b01f 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -41,10 +41,10 @@ def define_node( shape = input_node.shape dim = dim.number if end.number < 0: - end = end.number % shape[dim] + end_index = end.number % shape[dim] else: - end = min(end.number, shape[dim]) - size = end - start.number + end_index = min(end.number, shape[dim]) + size = end_index - start.number assert size > 0 assert size <= shape[dim] diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index a83ead987ed..377f8c17c4c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -14,7 +14,11 @@ from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import ( + get_node_debug_info, + getNodeArgs, + tosa_shape, +) from torch.export.exported_program import ExportedProgram @@ -28,8 +32,13 @@ def process_call_function( inputs = getNodeArgs(node) # Convert output (this node itself) - output = TosaArg(node) - + try: + output = TosaArg(node) + except ValueError as e: + raise ValueError( + f"Failed processing call_function:\n{get_node_debug_info(node)}" + "Is the original torch function supported?" + ) from e tosa_graph.currRegion.currBasicBlock.addTensor( output.name, tosa_shape(output.shape, output.dim_order), output.dtype ) @@ -61,15 +70,21 @@ def process_inputs( f"Arm backend only supports contiguous memory format for inputs. " f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}" ) - inputs = [TosaArg(node)] - input_shape = inputs[0].shape - input_dim_order = inputs[0].dim_order + try: + tosa_arg = TosaArg(node) + except ValueError as e: + raise ValueError( + f"Failed processing input placeholder:\n{get_node_debug_info(node)}" + "Is the original torch function supported?" + ) from e + input_shape = tosa_arg.shape + input_dim_order = tosa_arg.dim_order tensor = ts.TosaSerializerTensor( - inputs[0].name, + tosa_arg.name, tosa_shape(input_shape, input_dim_order), - inputs[0].dtype, + tosa_arg.dtype, data=None, - placeholderFilename=inputs[0].name + ".npy", + placeholderFilename=tosa_arg.name + ".npy", ) tosa_graph.addInputTensor(tensor) @@ -81,20 +96,26 @@ def process_inputs_to_parameters( tosa_spec: TosaSpecification, ): """Serialize bias and non-quantized weights""" - inputs = [TosaArg(node)] - parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name] + try: + tosa_arg = TosaArg(node) + except ValueError as e: + raise ValueError( + f"Failed processing parameter placeholder:\n{get_node_debug_info(node)}" + "Is the original torch function supported?" + ) from e + parameter_name = edge_program.graph_signature.inputs_to_parameters[tosa_arg.name] parameter_data = edge_program.state_dict[parameter_name] assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor" parameter_values = parameter_data.detach().numpy() - if inputs[0].dtype == torch.float32: + if tosa_arg.dtype == torch.float32: assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" - parameter_values = np.transpose(parameter_values, inputs[0].dim_order) + parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) tosa_graph.addConst( - parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name + parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name ) @@ -104,7 +125,13 @@ def process_inputs_to_buffers( edge_program: ExportedProgram, ): """Serialize quantized weights""" - inputs = [TosaArg(node)] + try: + tosa_arg = TosaArg(node) + except ValueError as e: + raise ValueError( + f"Failed processing buffer placeholder:\n{get_node_debug_info(node)}" + "Is the original torch function supported?" + ) from e buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] buffer_data = edge_program.state_dict[buffer_name] @@ -114,10 +141,10 @@ def process_inputs_to_buffers( # TODO: fragile code for temporary fix # the mean and var tensors are also stored here but they have shape (1, ) # we only transpose weights here - buffer_values = np.transpose(buffer_values, inputs[0].dim_order) + buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) tosa_graph.addConst( - buffer_values.shape, inputs[0].dtype, buffer_values, name=node.name + buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name ) @@ -126,14 +153,22 @@ def process_inputs_to_lifted_tensor_constants( tosa_graph: ts.TosaSerializer, edge_program: ExportedProgram, ): - arg = TosaArg(node) + try: + tosa_arg = TosaArg(node) + except ValueError as e: + raise ValueError( + f"Failed processing lifted tensor constant placeholder:\n{get_node_debug_info(node)}" + "Is the original torch function supported?" + ) from e tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[ - arg.name + tosa_arg.name ] tensor = edge_program.tensor_constants[tensor_name] tensor_data = tensor.detach().numpy() - tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name) + tosa_graph.addConst( + tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name + ) def process_placeholder( diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py index 75d82f2a4b6..13eb53dfa82 100644 --- a/backends/arm/tosa_mapping.py +++ b/backends/arm/tosa_mapping.py @@ -43,8 +43,10 @@ def map_dtype(data_type): - assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}" - assert data_type in DTYPE_MAP, f"Unknown type: {data_type}" + if data_type in UNSUPPORTED_DTYPES: + raise ValueError(f"Unsupported type: {data_type}") + if data_type not in DTYPE_MAP: + raise ValueError(f"Unknown type: {data_type}") return DTYPE_MAP[data_type] @@ -58,7 +60,10 @@ def extract_tensor_meta(meta): # TODO: should use first concrete representation val = val[0] - assert torch._subclasses.fake_tensor.FakeTensor == type(val) + if not isinstance(val, torch._subclasses.fake_tensor.FakeTensor): + raise ValueError( + f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}" + ) dtype = map_dtype(val.dtype) shape = tuple(val.size()) @@ -71,19 +76,18 @@ def extract_tensor_meta(meta): # Class to capture arguments and turn into tensor references for TOSA OPs class TosaArg: - def __process_node(self, argument): - assert isinstance(argument, torch.fx.node.Node) + def __process_node(self, argument: torch.fx.Node): self.name = argument.name self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta) def __process_list(self, argument): self.special = list(argument) - def __process_number(self, argument): + def __process_number(self, argument: float | int): self.number = argument def __init__(self, argument) -> None: - self.name = None + self.name = None # type: ignore[assignment] self.dtype = None self.shape = None self.dim_order = None @@ -92,16 +96,13 @@ def __init__(self, argument) -> None: if argument is None: return - if isinstance(argument, torch.fx.node.Node): + if isinstance(argument, torch.fx.Node): self.__process_node(argument) return if isinstance(argument, list): self.__process_list(argument) return - if isinstance(argument, int): - self.__process_number(argument) - return - if isinstance(argument, float): + if isinstance(argument, (int, float)): self.__process_number(argument) return diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 15d29b57482..0d4aeba2d55 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -25,20 +25,28 @@ logger.setLevel(logging.INFO) -def dbg_node(node): +def dbg_node(node: torch.fx.Node): # Debug output of node information - logger.info("OP") - logger.info(f" op is {node.op}") - logger.info(f" name is {node.name}") - logger.info(f" node target is {node.target}") - logger.info(f" node args is {node.args}") - logger.info(f" node kwargs is {node.kwargs}") - logger.info(" node.meta = ") + logger.info(get_node_debug_info(node)) + + +def get_node_debug_info(node: torch.fx.Node) -> str: + output = ( + "-- NODE DEBUG INFO --\n" + f" Op is {node.op}\n" + f" Name is {node.name}\n" + f" Node target is {node.target}\n" + f" Node args is {node.args}\n" + f" Node kwargs is {node.kwargs}\n" + f" Node users is {node.users}\n" + " Node.meta = \n" + ) for k, v in node.meta.items(): - logger.info(f" '{k}' = {v}") + output += f" '{k}' = {v}\n" if isinstance(v, list): for i in v: - logger.info(f" {i} ") + output += f" {i}\n" + return output # Output TOSA flatbuffer and test harness file @@ -65,14 +73,19 @@ def dbg_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""): def dbg_fail(node, tosa_graph, path): dbg_tosa_dump(tosa_graph, path) - logger.warn("Internal error due to poorly handled node:") + logger.warning("Internal error due to poorly handled node:") dbg_node(node) - logger.warn(f"Debug output captured in '{path}'.") + logger.warning(f"Debug output captured in '{path}'.") raise RuntimeError("TOSA Internal Error on node, enable logging for further info.") def getNodeArgs(node: Node) -> list[TosaArg]: - return [TosaArg(arg) for arg in node.args] + try: + return [TosaArg(arg) for arg in node.args] + except ValueError as e: + raise ValueError( + f"Failed processing args to op:\n{get_node_debug_info(node)}" + ) from e def get_output_node(node: Node) -> Node: