Skip to content

Commit 21c1984

Browse files
committed
Solve merge conflicts and fixed tests after bump
Adjust ONNX-to-TOSA mul legalization to pass the mandatory shift operand and update conversion tests to expect the additional argument.
1 parent d2d6ae3 commit 21c1984

File tree

10 files changed

+74
-79
lines changed

10 files changed

+74
-79
lines changed

src/Conversion/ONNXToTOSA/DialectBuilder.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,10 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
219219
lhsType.getRank(), ShapedType::kDynamic),
220220
elementType);
221221

222-
if (isa<IntegerType>(elementType)) {
223-
Value shiftConst = tosa::createMulShiftConst(rewriter(), loc(), shift);
224-
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
225-
rewriter(), loc(), newValueType, lhs, rhs, shiftConst);
226-
}
227-
222+
Value shiftConst =
223+
tosa::createMulShiftConst(rewriter(), loc(), /*shift=*/shift);
228224
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
229-
rewriter(), loc(), newValueType, lhs, rhs, Value());
225+
rewriter(), loc(), newValueType, lhs, rhs, shiftConst);
230226
}
231227

232228
Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {

src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,10 @@ class ONNXDequantizeLinearOpLoweringToTOSA
9595
rewriter, loc, adaptor.getXScale(), axis, resultType.getRank());
9696
Value scaleFactorCast =
9797
tosaBuilder.castToNewTensorElementType(scaleFactorConst, arithType);
98-
Value mulOp;
99-
auto castedType = mlir::cast<ShapedType>(casted.getType());
100-
if (isa<IntegerType>(castedType.getElementType())) {
101-
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
102-
mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
103-
rewriter, loc, casted.getType(), casted, scaleFactorCast, shiftConst)
104-
.getResult();
105-
} else {
106-
mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
107-
rewriter, loc, casted.getType(), casted, scaleFactorCast, Value())
108-
.getResult();
109-
}
98+
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
99+
Value mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
100+
rewriter, loc, casted.getType(), casted, scaleFactorCast, shiftConst)
101+
.getResult();
110102
Value castOp = tosaBuilder.castToNewTensorElementType(
111103
mulOp, resultType.getElementType());
112104

src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,10 @@ class ONNXQuantizeLinearOpLoweringToTOSA
9191
Value recOp = tosa::CreateOpAndInfer<mlir::tosa::ReciprocalOp>(rewriter,
9292
loc, expandedScaleFactorConst.getType(), expandedScaleFactorConst)
9393
.getResult();
94-
Value scaledResult;
95-
auto xShapedType = mlir::cast<ShapedType>(xType);
96-
if (isa<IntegerType>(xShapedType.getElementType())) {
97-
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
98-
scaledResult = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
99-
rewriter, loc, xType, x, recOp, shiftConst)
100-
.getResult();
101-
} else {
102-
scaledResult = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
103-
rewriter, loc, xType, x, recOp, Value())
104-
.getResult();
105-
}
94+
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
95+
Value scaledResult = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
96+
rewriter, loc, xType, x, recOp, shiftConst)
97+
.getResult();
10698

10799
// Quantization to i4/i8/16/ is particular since the intermediate result of
108100
// (x / y_scale) must round to the nearest even. This is particularly

test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> t
239239
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
240240
// CHECK-LABEL: func @test_mul
241241
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
242-
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
242+
// CHECK-NEXT: [[SHIFT_0_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
243+
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]], [[SHIFT_0_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
243244
}
244245

245246
// -----
@@ -249,7 +250,8 @@ func.func @test_mul_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?xf32>)
249250
"func.return"(%0) : (tensor<13x?x?xf32>) -> ()
250251
// CHECK-LABEL: func @test_mul_dynamic
251252
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
252-
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
253+
// CHECK-NEXT: [[SHIFT_1_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
254+
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]], [[SHIFT_1_]] : (tensor<?x?x?xf32>, tensor<13x?x?xf32>, tensor<1xi8>) -> tensor<13x?x?xf32>
253255
}
254256

255257
// -----
@@ -260,7 +262,8 @@ func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x
260262
// CHECK-LABEL: func @test_mul_rank_broadcast
261263
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
262264
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 21, 1>} : (tensor<21x1xf32>) -> tensor<1x21x1xf32>
263-
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x21x1xf32>) -> tensor<13x21x1xf32>
265+
// CHECK-NEXT: [[SHIFT_2_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
266+
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[SHIFT_2_]] : (tensor<13x21x1xf32>, tensor<1x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
264267
}
265268

266269
// -----
@@ -271,7 +274,8 @@ func.func @test_mul_rank_broadcast2(%arg0: tensor<21x1xf32>, %arg1: tensor<13x21
271274
// CHECK-LABEL: func @test_mul_rank_broadcast2
272275
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
273276
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 21, 1>} : (tensor<21x1xf32>) -> tensor<1x21x1xf32>
274-
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[VAR_0_]], [[PARAM_1_]] : (tensor<1x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
277+
// CHECK-NEXT: [[SHIFT_3_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
278+
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[VAR_0_]], [[PARAM_1_]], [[SHIFT_3_]] : (tensor<1x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
275279
}
276280

277281
// -----
@@ -302,7 +306,7 @@ func.func @test_div_dynamic_float(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?
302306
// CHECK-LABEL: func.func @test_div_dynamic_float
303307
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
304308
// CHECK: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
305-
// CHECK: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
309+
// CHECK: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], {{.*}}: (tensor<?x?x?xf32>, tensor<13x?x?xf32>, tensor<1xi8>) -> tensor<13x?x?xf32>
306310
// CHECK: return [[VAR_1_]] : tensor<13x?x?xf32>
307311
// CHECK: }
308312
}
@@ -336,7 +340,8 @@ func.func @test_div_decomposed(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1
336340
// CHECK-LABEL: func @test_div_decomposed
337341
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
338342
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
339-
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
343+
// CHECK-NEXT: [[SHIFT_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
344+
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[SHIFT_4_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
340345
}
341346

342347
// -----
@@ -442,11 +447,11 @@ func.func @test_selu_dynamic(%arg0: tensor<?x4x?xf32>) -> tensor<?x4x?xf32> {
442447
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
443448
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
444449
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
445-
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]] : (tensor<?x4x?xf32>, tensor<1x1x1xf32>) -> tensor<?x4x?xf32>
450+
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]], {{.*}}: (tensor<?x4x?xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<?x4x?xf32>
446451
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.sub [[VAR_4_]], [[VAR_0_]] : (tensor<?x4x?xf32>, tensor<1x1x1xf32>) -> tensor<?x4x?xf32>
447452
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_2_]] : (tensor<?x4x?xf32>, tensor<1x1x1xf32>) -> tensor<?x4x?xi1>
448453
// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<?x4x?xi1>, tensor<?x4x?xf32>, tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
449-
// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]] : (tensor<?x4x?xf32>, tensor<1x1x1xf32>) -> tensor<?x4x?xf32>
454+
// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]], {{.*}}: (tensor<?x4x?xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<?x4x?xf32>
450455
// CHECK: return [[VAR_8_]] : tensor<?x4x?xf32>
451456
}
452457

@@ -697,7 +702,8 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens
697702
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
698703
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<1xf32>) -> tensor<1xf32>
699704
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
700-
// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
705+
// CHECK-NEXT: [[SHIFT_5_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
706+
// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[SHIFT_5_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
701707
}
702708

703709
// -----
@@ -846,7 +852,7 @@ func.func @test_hardsigmoid_default_values_f32(%arg0: tensor<3xf32>) -> tensor<3
846852
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
847853
// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
848854
// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf32>) -> tensor<3xf32>
849-
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
855+
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<3xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<3xf32>
850856
// CHECK: return [[VAR_4_]] : tensor<3xf32>
851857
}
852858

@@ -859,7 +865,7 @@ func.func @test_hardsigmoid_default_values_f16(%arg0: tensor<3xf16>) -> tensor<3
859865
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.999510e-01> : tensor<1xf16>}> : () -> tensor<1xf16>
860866
// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
861867
// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf16>) -> tensor<3xf16>
862-
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
868+
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<3xf16>, tensor<1xf16>, tensor<1xi8>) -> tensor<3xf16>
863869
// CHECK: return [[VAR_4_]] : tensor<3xf16>
864870
}
865871

@@ -873,7 +879,7 @@ func.func @test_hardsigmoid_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
873879
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<1xf32>}> : () -> tensor<1xf32>
874880
// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
875881
// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf32>) -> tensor<3xf32>
876-
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
882+
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<3xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<3xf32>
877883
// CHECK: return [[VAR_4_]] : tensor<3xf32>
878884
}
879885

@@ -886,7 +892,7 @@ func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
886892
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1xf16>}> : () -> tensor<1xf16>
887893
// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
888894
// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf16>) -> tensor<3xf16>
889-
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
895+
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<3xf16>, tensor<1xf16>, tensor<1xi8>) -> tensor<3xf16>
890896
// CHECK: return [[VAR_4_]] : tensor<3xf16>
891897
}
892898

@@ -901,7 +907,7 @@ func.func @test_hardsigmoid_dynamic(%arg0: tensor<?x3x?xf16>) -> tensor<?x3x?xf1
901907
// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16>
902908
// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<?x3x?xf16>, tensor<1x1x1xf16>) -> tensor<?x3x?xf16>
903909
// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x3x?xf16>) -> tensor<?x3x?xf16>
904-
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] : (tensor<?x3x?xf16>, tensor<1x1x1xf16>) -> tensor<?x3x?xf16>
910+
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<?x3x?xf16>, tensor<1x1x1xf16>, tensor<1xi8>) -> tensor<?x3x?xf16>
905911
// CHECK: return [[VAR_4_]] : tensor<?x3x?xf16>
906912
}
907913

@@ -928,7 +934,7 @@ func.func @test_elu_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> {
928934
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
929935
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<3xf32>) -> tensor<3xf32>
930936
// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
931-
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
937+
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]], {{.*}}: (tensor<3xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<3xf32>
932938
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
933939
// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<3xi1>, tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
934940
// CHECK: return [[VAR_7_]]
@@ -944,7 +950,7 @@ func.func @test_elu_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> {
944950
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
945951
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<3xf16>) -> tensor<3xf16>
946952
// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
947-
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16>
953+
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]], {{.*}}: (tensor<3xf16>, tensor<1xf16>, tensor<1xi8>) -> tensor<3xf16>
948954
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xi1>
949955
// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<3xi1>, tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16>
950956
// CHECK: return [[VAR_7_]]
@@ -962,7 +968,7 @@ func.func @test_elu_unranked(%arg0: tensor<*xf32>) -> tensor<3xf32> {
962968
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
963969
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<*xf32>) -> tensor<3xf32>
964970
// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
965-
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32>
971+
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]], {{.*}}: (tensor<3xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<3xf32>
966972
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xi1>
967973
// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<*xi1>, tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32>
968974
// CHECK: return [[VAR_7_]] : tensor<3xf32>

0 commit comments

Comments
 (0)