Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,18 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
rhs = valueVec[1];
}
auto lhsType = mlir::cast<ShapedType>(lhs.getType());
auto elementType = lhsType.getElementType();
Type newValueType =
(!lhsType.hasRank())
? lhsType
: RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());
elementType);

Value shiftConst =
tosa::createMulShiftConst(rewriter(), loc(), /*shift=*/shift);
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter(), loc(), newValueType, lhs, rhs, shift);
rewriter(), loc(), newValueType, lhs, rhs, shiftConst);
}

Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ class ONNXDequantizeLinearOpLoweringToTOSA
rewriter, loc, adaptor.getXScale(), axis, resultType.getRank());
Value scaleFactorCast =
tosaBuilder.castToNewTensorElementType(scaleFactorConst, arithType);
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
Value mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter, loc, casted.getType(), casted, scaleFactorCast, 0)
rewriter, loc, casted.getType(), casted, scaleFactorCast, shiftConst)
.getResult();
Value castOp = tosaBuilder.castToNewTensorElementType(
mulOp, resultType.getElementType());
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ class ONNXQuantizeLinearOpLoweringToTOSA
Value recOp = tosa::CreateOpAndInfer<mlir::tosa::ReciprocalOp>(rewriter,
loc, expandedScaleFactorConst.getType(), expandedScaleFactorConst)
.getResult();
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
Value scaledResult = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter, loc, xType, x, recOp, 0)
rewriter, loc, xType, x, recOp, shiftConst)
.getResult();

// Quantization to i4/i8/16/ is particular since the intermediate result of
Expand Down
9 changes: 9 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,14 @@ mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,
loc, resultTy, tensor, newShape);
}

mlir::Value createMulShiftConst(
mlir::PatternRewriter &rewriter, mlir::Location loc, int32_t shift) {
assert(shift >= -128 && shift <= 127 && "TOSA shift must fit in i8");
auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type());
auto shiftAttr = DenseElementsAttr::get<int8_t>(
shiftType, llvm::ArrayRef<int8_t>{static_cast<int8_t>(shift)});
return rewriter.create<mlir::tosa::ConstOp>(loc, shiftType, shiftAttr);
}

} // namespace tosa
} // namespace onnx_mlir
3 changes: 3 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val);
mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Value tensor, size_t axis, size_t rank);

mlir::Value createMulShiftConst(
mlir::PatternRewriter &rewriter, mlir::Location loc, int32_t shift);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
Expand Down
62 changes: 34 additions & 28 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func.func @test_alpha(%arg0: tensor<3x6xf32>, %arg1: tensor<6x4xf32>, %arg2: ten
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 3, 6>} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 6, 4>} : (tensor<6x4xf32>) -> tensor<1x6x4xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.618000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 4>} : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_4_]], [[VAR_5_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
Expand All @@ -47,7 +47,7 @@ func.func @test_beta(%arg0: tensor<3x6xf32>, %arg1: tensor<6x6xf32>, %arg2: tens
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 6>} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_4_]] : (tensor<1x3x6xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array<i64: 3, 6>} : (tensor<1x3x6xf32>) -> tensor<3x6xf32>
Expand Down Expand Up @@ -87,7 +87,7 @@ func.func @test_transb(%arg0: tensor<3x6xf32>, %arg1: tensor<4x6xf32>, %arg2: te
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_2_]] : (tensor<1x4x6xf32>, tensor<3xi32>) -> tensor<1x6x4xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<1.184000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.matmul [[VAR_5_]], [[VAR_3_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 4>} : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: [[VAR_8_:%.+]] = tosa.add [[VAR_6_]], [[VAR_7_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
Expand Down Expand Up @@ -127,7 +127,7 @@ func.func @test_no_c_no_trans(%arg0: tensor<1x5xf32>, %arg1: tensor<5x6xf32>) ->
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 1, 5>} : (tensor<1x5xf32>) -> tensor<1x1x5xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 5, 6>} : (tensor<5x6xf32>) -> tensor<1x5x6xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x1x5xf32>) -> tensor<1x1x5xf32>
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x1x5xf32>, tensor<1xi8>) -> tensor<1x1x5xf32>
// CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_4_]], [[VAR_2_]] : (tensor<1x1x5xf32>, tensor<1x5x6xf32>) -> tensor<1x1x6xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.reshape [[VAR_5_]] {new_shape = array<i64: 1, 6>} : (tensor<1x1x6xf32>) -> tensor<1x6xf32>
// CHECK: return [[VAR_6_]] : tensor<1x6xf32>
Expand All @@ -151,11 +151,11 @@ func.func @test_mixed(%arg0: tensor<11x5xf32>, %arg1: tensor<3x11xf32>, %arg2: t
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_4_]] : (tensor<1x3x11xf32>, tensor<3xi32>) -> tensor<1x11x3xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<1.402000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x5x11xf32>) -> tensor<1x5x11xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x5x11xf32>, tensor<1xi8>) -> tensor<1x5x11xf32>
// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.998000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 5, 3>} : (tensor<5x3xf32>) -> tensor<1x5x3xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
// CHECK-DAG: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x5x3xf32>, tensor<1xi8>) -> tensor<1x5x3xf32>
// CHECK-DAG: [[VAR_11_:%.+]] = tosa.matmul [[VAR_7_]], [[VAR_5_]] : (tensor<1x5x11xf32>, tensor<1x11x3xf32>) -> tensor<1x5x3xf32>
// CHECK: [[VAR_12_:%.+]] = tosa.add [[VAR_11_]], [[VAR_10_]] : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
// CHECK: [[VAR_13_:%.+]] = tosa.reshape [[VAR_12_]] {new_shape = array<i64: 5, 3>} : (tensor<1x5x3xf32>) -> tensor<5x3xf32>
Expand Down
15 changes: 10 additions & 5 deletions test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ return %1 : tensor<2x5x1x1xf32>
// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[SHIFT_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]], %[[SHIFT_0]] : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x1x1xf32>
// CHECK: return %[[VAL_4]] : tensor<2x5x1x1xf32>
}

Expand All @@ -26,7 +27,8 @@ return %0 : tensor<1x1x1x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.00101010106> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[SHIFT_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]], %[[SHIFT_1]] : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x1x1x1xf32>
// CHECK: return %[[VAL_6]] : tensor<1x1x1x1xf32>
}

Expand All @@ -42,7 +44,8 @@ return %1 : tensor<2x5xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 2, 5>} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5xf32>, tensor<1x1xf32>) -> tensor<2x5xf32>
// CHECK: %[[SHIFT_2:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]], %[[SHIFT_2]] : (tensor<2x5xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x5xf32>
// CHECK: return %[[VAL_5]] : tensor<2x5xf32>
}

Expand All @@ -57,7 +60,8 @@ return %1 : tensor<2x5x1x1xf32>
// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[SHIFT_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]], %[[SHIFT_3]] : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x1x1xf32>
// CHECK: return %[[VAL_4]] : tensor<2x5x1x1xf32>
}

Expand All @@ -81,7 +85,8 @@ func.func @test_reducemeanV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x
// CHECK: [[VAR_0_:%.+]] = tosa.reduce_sum %arg0 {axis = 2 : i32}
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_sum [[VAR_0_]] {axis = 3 : i32}
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<7.97193861E-5> : tensor<1x1x1x1xf32>}>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_1_]], [[VAR_2_]] {shift = 0 : i8} : (tensor<1x32x1x1xf32>, tensor<1x1x1x1xf32>)
// CHECK: [[SHIFT_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_1_]], [[VAR_2_]], [[SHIFT_4_]] : (tensor<1x32x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>)
// CHECK: return [[VAR_3_]] : tensor<1x32x1x1xf32>
}

Expand Down
8 changes: 4 additions & 4 deletions test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ func.func @test_softmax_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 2 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], {{.*}}: (tensor<13x21x3xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}

// -----
Expand All @@ -23,7 +23,7 @@ func.func @test_softmax_v13_axis_one(%arg0: tensor<13x21x3xf32>) -> tensor<13x21
// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x1x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], {{.*}}: (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}

// -----
Expand All @@ -39,7 +39,7 @@ func.func @test_softmax_before_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<13x1x3xf32>) -> tensor<13x1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<13x1x1xf32>) -> tensor<13x1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x1xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]], {{.*}}: (tensor<13x21x3xf32>, tensor<13x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}

// -----
Expand All @@ -57,5 +57,5 @@ func.func @test_softmax_before_v13_axis_zero(%arg0: tensor<13x21x3xf32>) -> tens
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 1 : i32} : (tensor<1x21x3xf32>) -> tensor<1x1x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], {{.*}}: (tensor<13x21x3xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}
Loading