diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index fdbb20fbe18..a566b0fbfa7 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -6,7 +6,6 @@ # pyre-unsafe from typing import Any, List -import numpy as np import torch from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -333,21 +332,22 @@ def define_node( weight.dtype, ) shape = tosa_graph.addConst( - np.array(weight_post_shape).shape, + [len(weight_post_shape)], ts.DType.SHAPE, - np.array(weight_post_shape), + weight_post_shape, name=weight_reshaped.name + "_shape", ) - attr = ts.TosaSerializerAttribute() - attr.ReshapeAttribute() + reshape_attr = ts.TosaSerializerAttribute() + reshape_attr.ReshapeAttribute() tosa_graph.addOperator( ts.TosaOp.Op().RESHAPE, [weight.name, shape.name], [weight_reshaped.name], - attr, + reshape_attr, ) + attr = ts.TosaSerializerAttribute() tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D weight_name = weight_reshaped.name diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index e7a062bbf22..d8ac85ec63a 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -74,14 +74,14 @@ def define_node( tosa_graph = cast(ts.TosaSerializer, tosa_graph) if len(output.shape) != 0: - shape_len = len(output.shape) + shape_len = [len(output.shape)] shape_data = list(tosa_shape(output.shape, output.dim_order)) else: - shape_len = 1 - shape_data = [0] + shape_len = [] + shape_data = [] shape = tosa_graph.addConst( - [shape_len], + shape_len, ts.DType.SHAPE, shape_data, name=node.name + "_shape",