Skip to content

Commit 114e05b

Browse files
committed
Add TosaBuilder.mul overload taking a shift value
1 parent ee9326e commit 114e05b

File tree

4 files changed

+21
-18
lines changed

4 files changed

+21
-18
lines changed

src/Conversion/ONNXToTOSA/DialectBuilder.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ Value TosaBuilder::reshape(Value value, llvm::ArrayRef<int64_t> shape) {
209209
}
210210

211211
Value TosaBuilder::mul(Value &lhs, Value &rhs, int8_t shift) {
212+
auto int8Type = rewriter().getI8Type();
213+
auto shiftValue =
214+
TosaBuilder::createConst(ArrayRef<int8_t>{shift}, {1}, int8Type);
215+
return TosaBuilder::mul(lhs, rhs, shiftValue);
216+
}
217+
218+
Value TosaBuilder::mul(Value &lhs, Value &rhs, Value &shift) {
212219
if (needsRankBroadcast({lhs, rhs})) {
213220
llvm::SmallVector<Value, 4> valueVec = equalizeRanks({lhs, rhs});
214221
lhs = valueVec[0];
@@ -222,11 +229,8 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int8_t shift) {
222229
lhsType.getRank(), ShapedType::kDynamic),
223230
lhsType.getElementType());
224231

225-
auto int8Type = rewriter().getI8Type();
226-
auto shiftValue =
227-
TosaBuilder::createConst(ArrayRef<int8_t>{shift}, {1}, int8Type);
228232
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
229-
rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
233+
rewriter(), loc(), newValueType, lhs, rhs, shift);
230234
}
231235

232236
Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {

src/Conversion/ONNXToTOSA/DialectBuilder.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct TosaBuilder : DialectBuilder {
4444
template <typename T>
4545
mlir::Value binaryOp(mlir::Value &lhs, mlir::Value &rhs);
4646
mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, int8_t shift = 0);
47+
mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, mlir::Value &shift);
4748
mlir::Value intdiv(mlir::Value &lhs, mlir::Value &rhs);
4849

4950
mlir::Value transpose(mlir::Value &value, llvm::ArrayRef<int32_t> perm);

test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x
266266
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
267267
// CHECK-LABEL: func.func @test_mul_rank_broadcast
268268
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
269-
// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 21, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
269+
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 21, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
270270
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[VAR_0_]] : (tensor<21x1xf32>, !tosa.shape<3>) -> tensor<1x21x1xf32>
271271
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
272272
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[VAR_2_]] : (tensor<13x21x1xf32>, tensor<1x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
@@ -280,7 +280,7 @@ func.func @test_mul_rank_broadcast2(%arg0: tensor<21x1xf32>, %arg1: tensor<13x21
280280
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
281281
// CHECK-LABEL: func.func @test_mul_rank_broadcast2
282282
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
283-
// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 21, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
283+
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 21, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
284284
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]], [[VAR_0_]] : (tensor<21x1xf32>, !tosa.shape<3>) -> tensor<1x21x1xf32>
285285
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
286286
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_1_]], [[PARAM_1_]], [[VAR_2_]] : (tensor<1x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
@@ -731,9 +731,9 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens
731731
// CHECK-LABEL: func @test_div_decomposed_broadcast
732732
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
733733
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<1xf32>) -> tensor<1xf32>
734-
// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
735-
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
736-
// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
734+
// CHECK-DAG: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
735+
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
736+
// CHECK-DAG: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
737737
// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
738738
}
739739

test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,10 @@ func.func @test_beta(%arg0: tensor<3x6xf32>, %arg1: tensor<6x6xf32>, %arg2: tens
5555
// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32>
5656
// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
5757
// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x6xf32>, !tosa.shape<3>) -> tensor<1x6x6xf32>
58-
// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
58+
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
5959
// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3>
60-
// CHECK: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32>
61-
// CHECK-NOT: separator of consecutive DAGs
62-
// CHECK: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
60+
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32>
61+
// CHECK-DAG: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
6362
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
6463
// CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32>
6564
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_4_]] : (tensor<1x3x6xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
@@ -185,11 +184,10 @@ func.func @test_mixed(%arg0: tensor<11x5xf32>, %arg1: tensor<3x11xf32>, %arg2: t
185184
// CHECK-DAG: [[ZERO_0:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
186185
// CHECK-NOT: separator of consecutive DAGs
187186
// CHECK: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]], [[ZERO_0]] : (tensor<1x1x1xf32>, tensor<1x5x11xf32>, tensor<1xi8>) -> tensor<1x5x11xf32>
188-
// CHECK: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.998000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
189-
// CHECK: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
190-
// CHECK: [[VAR_9_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<5x3xf32>, !tosa.shape<3>) -> tensor<1x5x3xf32>
191-
// CHECK-NOT: separator of consecutive DAGs
192-
// CHECK: [[ZERO_1:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
187+
// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.998000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
188+
// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
189+
// CHECK-DAG: [[VAR_9_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<5x3xf32>, !tosa.shape<3>) -> tensor<1x5x3xf32>
190+
// CHECK-DAG: [[ZERO_1:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
193191
// CHECK: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]], [[ZERO_1]] : (tensor<1x1x1xf32>, tensor<1x5x3xf32>, tensor<1xi8>) -> tensor<1x5x3xf32>
194192
// CHECK: [[VAR_11_:%.+]] = tosa.matmul [[VAR_7_]], [[VAR_5_]] : (tensor<1x5x11xf32>, tensor<1x11x3xf32>) -> tensor<1x5x3xf32>
195193
// CHECK: [[VAR_12_:%.+]] = tosa.add [[VAR_11_]], [[VAR_10_]] : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>

0 commit comments

Comments
 (0)