diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 24a2bd3c513..4744845dc2a 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -35,7 +35,10 @@ def _transpose_impl(*args, **kwargs): # Validate length of dim_order array dim = args[1] - assert len(dim) in (4, 5) + if len(dim) != 4 and len(dim) != 5: + raise ValueError( + f"Dim order length must be either 4 or 5, got {len(dim)}: {dim}" + ) # Pass-through in edge-IR return args[0] diff --git a/backends/arm/_passes/convert_split_to_slice.py b/backends/arm/_passes/convert_split_to_slice.py index 06104811be5..67bd9d73e81 100644 --- a/backends/arm/_passes/convert_split_to_slice.py +++ b/backends/arm/_passes/convert_split_to_slice.py @@ -41,9 +41,14 @@ def call(self, graph_module: torch.fx.GraphModule): dim = split_node.args[2] if len(split_node.args) > 2 else 0 dim = (dim + rank) % rank - assert ( - sum(split_lengths) == shape[dim] - ), "Given split lengths don't sum up to the size of the dimension." + # Validate that split lengths cover the entire dimension + length_sum = sum(split_lengths) + dim_size = shape[dim] + if length_sum != dim_size: + raise ValueError( + f"Split sizes {split_lengths} sum to {length_sum}, " + f"but dimension {dim} has size {dim_size}" + ) # Convert split argument 'split_lengths' to slice arguments start and end. starts = [0] * len(split_lengths) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 4d265534f97..63c57e1bedd 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -120,7 +120,9 @@ def fold_and_annotate_arg( if input_qparams is not None: node.meta["input_qparams"][i] = input_qparams for n in nodes_to_remove: - assert n.target == dq_op + if n.target != dq_op: + raise RuntimeError(f"Expected {dq_op} dq_op, got {n.target}") + n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type] graph_module.graph.erase_node(n) @@ -136,14 +138,16 @@ def call(self, graph_module: GraphModule) -> PassResult: continue # Make sure we haven't already set qparams meta information on the node - assert "input_qparams" not in n.meta, ( - f'Unexpected key "input_qparams" found in meta for node {n}. ' - "input_qparams should not have been set at this point" - ) - assert "output_qparams" not in n.meta, ( - f'Unexpected key "output_qparams" found in meta for node {n}. ' - "output_qparams should not have been set at this point" - ) + if "input_qparams" in n.meta: + raise RuntimeError( + f'Unexpected key "input_qparams" found in meta for node {n}. ' + "input_qparams should not have been set at this point" + ) + if "output_qparams" in n.meta: + raise RuntimeError( + f'Unexpected key "output_qparams" found in meta for node {n}. ' + "output_qparams should not have been set at this point" + ) # for the inputs and outputs search the graph for quantization info and # store the information in a dict with order of the _tensor_ inputs as key, diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 6bb1ce7dce1..aeb9d3bc5eb 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -240,8 +240,17 @@ def call(self, graph_module: GraphModule) -> PassResult: args=(node.args[0],), ) output_node = table_node - assert len(input_qparams) == 1 - assert len(output_qparams) == 1 + # Expect exactly one quantization parameter for input and output + if len(input_qparams) != 1: + raise ValueError( + f"InsertTableOpsPass expected exactly one input quantization parameter, " + f"got {len(input_qparams)} for node {node.name}" + ) + if len(output_qparams) != 1: + raise ValueError( + f"InsertTableOpsPass expected exactly one output quantization parameter, " + f"got {len(output_qparams)} for node {node.name}" + ) # Generate table buffer and how much to lshift the table output. buffer, lshift = self.generate_table_values( diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_clone_pass.py index 9542a4097af..a2822c7378e 100644 --- a/backends/arm/_passes/remove_clone_pass.py +++ b/backends/arm/_passes/remove_clone_pass.py @@ -17,5 +17,8 @@ def call_operator(self, op, args, kwargs, meta): if op != exir_ops.edge.aten.clone.default: return super().call_operator(op, args, kwargs, meta) - assert len(args) == 1 + if len(args) != 1: + raise ValueError( + f"clone operator expects exactly one argument, got {len(args)}" + ) return args[0]