Skip to content

Commit 489eea2

Browse files
Arm backend: TOSA 1.0 fixes (#11284)
- Fix rank reshapes - Fix invalid TosaSerializerAttribute() usage. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 2dd4e11 commit 489eea2

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

backends/arm/operators/op_conv2d.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# pyre-unsafe
77
from typing import Any, List
88

9-
import numpy as np
109
import torch
1110

1211
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
@@ -333,21 +332,22 @@ def define_node(
333332
weight.dtype,
334333
)
335334
shape = tosa_graph.addConst(
336-
np.array(weight_post_shape).shape,
335+
[len(weight_post_shape)],
337336
ts.DType.SHAPE,
338-
np.array(weight_post_shape),
337+
weight_post_shape,
339338
name=weight_reshaped.name + "_shape",
340339
)
341340

342-
attr = ts.TosaSerializerAttribute()
343-
attr.ReshapeAttribute()
341+
reshape_attr = ts.TosaSerializerAttribute()
342+
reshape_attr.ReshapeAttribute()
344343
tosa_graph.addOperator(
345344
ts.TosaOp.Op().RESHAPE,
346345
[weight.name, shape.name],
347346
[weight_reshaped.name],
348-
attr,
347+
reshape_attr,
349348
)
350349

350+
attr = ts.TosaSerializerAttribute()
351351
tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D
352352
weight_name = weight_reshaped.name
353353

backends/arm/operators/op_view.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ def define_node(
7474
tosa_graph = cast(ts.TosaSerializer, tosa_graph)
7575

7676
if len(output.shape) != 0:
77-
shape_len = len(output.shape)
77+
shape_len = [len(output.shape)]
7878
shape_data = list(tosa_shape(output.shape, output.dim_order))
7979
else:
80-
shape_len = 1
81-
shape_data = [0]
80+
shape_len = []
81+
shape_data = []
8282

8383
shape = tosa_graph.addConst(
84-
[shape_len],
84+
shape_len,
8585
ts.DType.SHAPE,
8686
shape_data,
8787
name=node.name + "_shape",

0 commit comments

Comments
 (0)