Skip to content

Commit 44e2d2f

Browse files
authored
[stablehlo][tosa] Fix for tosa dialect updates (#2756)
This fixes up legalization passes for converting to/from tosa for tosa dialect changes - tosa rescale has multiplier, shift, input_zp, output_zp changed from attribute to operands with specific type requirements - tosa rescale attribute "double_round" changed to "rounding_mode" - tosa negate has additional input_zp and output_zp operands - tosa const attribute name changed from "value" to "values" Signed-off-by: Tai Ly <[email protected]>
1 parent 5bf0fef commit 44e2d2f

File tree

8 files changed

+226
-81
lines changed

8 files changed

+226
-81
lines changed

stablehlo/conversions/tosa/tests/binary.mlir

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ func.func @divide(%arg0 : tensor<10xi32>, %arg1 : tensor<10xi32>) -> tensor<10xi
5252

5353
// CHECK-LABEL: @dot_vector_vector
5454
func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor<f32> {
55-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>
56-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
55+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>
56+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
5757
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
58-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
58+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
5959
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
6060
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
6161
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -65,10 +65,10 @@ func.func @dot_vector_vector(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> te
6565

6666
// CHECK-LABEL: @dot_vector_matrix
6767
func.func @dot_vector_matrix(%arg0 : tensor<2xf32>, %arg1 : tensor<2x3xf32>) -> tensor<3xf32> {
68-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
69-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
68+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
69+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
7070
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
71-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
71+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
7272
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
7373
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
7474
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -78,10 +78,10 @@ func.func @dot_vector_matrix(%arg0 : tensor<2xf32>, %arg1 : tensor<2x3xf32>) ->
7878

7979
// CHECK-LABEL: @dot_matrix_vector
8080
func.func @dot_matrix_vector(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3xf32>) -> tensor<2xf32> {
81-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
82-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
81+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
82+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
8383
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
84-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
84+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
8585
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
8686
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
8787
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -91,10 +91,10 @@ func.func @dot_matrix_vector(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3xf32>) ->
9191

9292
// CHECK-LABEL: @dot_matrix_matrix
9393
func.func @dot_matrix_matrix(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>) -> tensor<2x4xf32> {
94-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
95-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
94+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
95+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
9696
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
97-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
97+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
9898
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
9999
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
100100
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -104,10 +104,10 @@ func.func @dot_matrix_matrix(%arg0 : tensor<2x3xf32>, %arg1 : tensor<3x4xf32>) -
104104

105105
// CHECK-LABEL: @dot_general_vector_vector
106106
func.func @dot_general_vector_vector(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<f32> {
107-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>
108-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
107+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} : () -> !tosa.shape<0>
108+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
109109
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
110-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
110+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
111111
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
112112
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
113113
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -117,10 +117,10 @@ func.func @dot_general_vector_vector(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>)
117117

118118
// CHECK-LABEL: @dot_general_vector_matrix
119119
func.func @dot_general_vector_matrix(%arg0: tensor<2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3xf32> {
120-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
121-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
120+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
121+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
122122
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
123-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
123+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
124124
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
125125
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
126126
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -130,10 +130,10 @@ func.func @dot_general_vector_matrix(%arg0: tensor<2xf32>, %arg1: tensor<2x3xf32
130130

131131
// CHECK-LABEL: @dot_general_matrix_vector
132132
func.func @dot_general_matrix_vector(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2xf32> {
133-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
134-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
133+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
134+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
135135
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
136-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
136+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
137137
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
138138
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
139139
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]
@@ -143,10 +143,10 @@ func.func @dot_general_matrix_vector(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32
143143

144144
// CHECK-LABEL: @dot_general_matrix_matrix
145145
func.func @dot_general_matrix_matrix(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
146-
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
147-
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
146+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
147+
// CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
148148
// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR1]]
149-
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
149+
// CHECK-DAG: %[[VAR3:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
150150
// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg1, %[[VAR3]]
151151
// CHECK-DAG: %[[VAR5:.*]] = tosa.matmul %[[VAR2]], %[[VAR4]]
152152
// CHECK-DAG: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[VAR0]]

0 commit comments

Comments
 (0)