Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,13 @@ Value TosaBuilder::reshape(Value value, llvm::ArrayRef<int64_t> shape) {
rewriter(), loc(), newValueType, value, shapeAttr);
}

Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
// Integer-shift convenience overload of mul: the shift is materialized as a
// constant through tosa::createMulShiftConst and the call is then forwarded to
// the overload that takes the shift as a Value.
Value TosaBuilder::mul(Value &lhs, Value &rhs, int8_t shift) {
  return mul(lhs, rhs,
      tosa::createMulShiftConst(rewriter(), loc(), /*shift=*/shift));
}

Value TosaBuilder::mul(Value &lhs, Value &rhs, Value shift) {
if (needsRankBroadcast({lhs, rhs})) {
llvm::SmallVector<Value, 4> valueVec = equalizeRanks({lhs, rhs});
lhs = valueVec[0];
Expand All @@ -217,6 +223,7 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
: RankedTensorType::get(llvm::SmallVector<int64_t, 4>(
lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());

return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter(), loc(), newValueType, lhs, rhs, shift);
}
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct TosaBuilder : DialectBuilder {
int32_t axis);
template <typename T>
mlir::Value binaryOp(mlir::Value &lhs, mlir::Value &rhs);
mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, int32_t shift = 0);
mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, int8_t shift = 0);
mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, mlir::Value shift);
mlir::Value intdiv(mlir::Value &lhs, mlir::Value &rhs);

mlir::Value transpose(mlir::Value &value, llvm::ArrayRef<int32_t> perm);
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ class ONNXDequantizeLinearOpLoweringToTOSA
rewriter, loc, adaptor.getXScale(), axis, resultType.getRank());
Value scaleFactorCast =
tosaBuilder.castToNewTensorElementType(scaleFactorConst, arithType);
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
Value mulOp = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter, loc, casted.getType(), casted, scaleFactorCast, 0)
rewriter, loc, casted.getType(), casted, scaleFactorCast, shiftConst)
.getResult();
Value castOp = tosaBuilder.castToNewTensorElementType(
mulOp, resultType.getElementType());
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ class ONNXQuantizeLinearOpLoweringToTOSA
Value recOp = tosa::CreateOpAndInfer<mlir::tosa::ReciprocalOp>(rewriter,
loc, expandedScaleFactorConst.getType(), expandedScaleFactorConst)
.getResult();
Value shiftConst = tosa::createMulShiftConst(rewriter, loc, 0);
Value scaledResult = tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter, loc, xType, x, recOp, 0)
rewriter, loc, xType, x, recOp, shiftConst)
.getResult();

// Quantization to i4/i8/16/ is particular since the intermediate result of
Expand Down
9 changes: 9 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,14 @@ mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,
loc, resultTy, tensor, newShape);
}

// Materializes `shift` as a tosa.const whose type is a rank-1, single-element
// i8 tensor, and returns the result of that op.
mlir::Value createMulShiftConst(
    mlir::PatternRewriter &rewriter, mlir::Location loc, int32_t shift) {
  assert(shift >= -128 && shift <= 127 && "TOSA shift must fit in i8");
  // The assert above states the expected range; the value is stored as i8.
  const int8_t narrowedShift = static_cast<int8_t>(shift);
  auto constType = RankedTensorType::get({1}, rewriter.getI8Type());
  auto constAttr = DenseElementsAttr::get<int8_t>(
      constType, llvm::ArrayRef<int8_t>{narrowedShift});
  return rewriter.create<mlir::tosa::ConstOp>(loc, constType, constAttr);
}

} // namespace tosa
} // namespace onnx_mlir
3 changes: 3 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val);
mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Value tensor, size_t axis, size_t rank);

// Creates a tosa.const holding `shift` as a single-element i8 tensor
// (tensor<1xi8>) and returns it; callers pass the result as the shift operand
// of tosa.mul. `shift` is expected to fit in i8 (checked with an assert).
mlir::Value createMulShiftConst(
mlir::PatternRewriter &rewriter, mlir::Location loc, int32_t shift);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the framework to TOSA lowering.
template <typename TosaOp, typename... Args>
Expand Down
62 changes: 34 additions & 28 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func.func @test_alpha(%arg0: tensor<3x6xf32>, %arg1: tensor<6x4xf32>, %arg2: ten
// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 3, 6>} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 6, 4>} : (tensor<6x4xf32>) -> tensor<1x6x4xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.618000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 4>} : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_4_]], [[VAR_5_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
Expand All @@ -47,7 +47,7 @@ func.func @test_beta(%arg0: tensor<3x6xf32>, %arg1: tensor<6x6xf32>, %arg2: tens
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 6>} : (tensor<3x6xf32>) -> tensor<1x3x6xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_4_]] : (tensor<1x3x6xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array<i64: 3, 6>} : (tensor<1x3x6xf32>) -> tensor<3x6xf32>
Expand Down Expand Up @@ -87,7 +87,7 @@ func.func @test_transb(%arg0: tensor<3x6xf32>, %arg1: tensor<4x6xf32>, %arg2: te
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_2_]] : (tensor<1x4x6xf32>, tensor<3xi32>) -> tensor<1x6x4xf32>
// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<1.184000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32>
// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = tosa.matmul [[VAR_5_]], [[VAR_3_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 3, 4>} : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: [[VAR_8_:%.+]] = tosa.add [[VAR_6_]], [[VAR_7_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
Expand Down Expand Up @@ -127,7 +127,7 @@ func.func @test_no_c_no_trans(%arg0: tensor<1x5xf32>, %arg1: tensor<5x6xf32>) ->
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array<i64: 1, 1, 5>} : (tensor<1x5xf32>) -> tensor<1x1x5xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 5, 6>} : (tensor<5x6xf32>) -> tensor<1x5x6xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x1x5xf32>) -> tensor<1x1x5xf32>
// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x1x5xf32>, tensor<1xi8>) -> tensor<1x1x5xf32>
// CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_4_]], [[VAR_2_]] : (tensor<1x1x5xf32>, tensor<1x5x6xf32>) -> tensor<1x1x6xf32>
// CHECK: [[VAR_6_:%.+]] = tosa.reshape [[VAR_5_]] {new_shape = array<i64: 1, 6>} : (tensor<1x1x6xf32>) -> tensor<1x6xf32>
// CHECK: return [[VAR_6_]] : tensor<1x6xf32>
Expand All @@ -151,11 +151,11 @@ func.func @test_mixed(%arg0: tensor<11x5xf32>, %arg1: tensor<3x11xf32>, %arg2: t
// CHECK-DAG: [[VAR_5_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_4_]] : (tensor<1x3x11xf32>, tensor<3xi32>) -> tensor<1x11x3xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<1.402000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x5x11xf32>) -> tensor<1x5x11xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x5x11xf32>, tensor<1xi8>) -> tensor<1x5x11xf32>
// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.998000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array<i64: 1, 5, 3>} : (tensor<5x3xf32>) -> tensor<1x5x3xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
// CHECK-DAG: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]], {{.*}}: (tensor<1x1x1xf32>, tensor<1x5x3xf32>, tensor<1xi8>) -> tensor<1x5x3xf32>
// CHECK-DAG: [[VAR_11_:%.+]] = tosa.matmul [[VAR_7_]], [[VAR_5_]] : (tensor<1x5x11xf32>, tensor<1x11x3xf32>) -> tensor<1x5x3xf32>
// CHECK: [[VAR_12_:%.+]] = tosa.add [[VAR_11_]], [[VAR_10_]] : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
// CHECK: [[VAR_13_:%.+]] = tosa.reshape [[VAR_12_]] {new_shape = array<i64: 5, 3>} : (tensor<1x5x3xf32>) -> tensor<5x3xf32>
Expand Down
15 changes: 10 additions & 5 deletions test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ return %1 : tensor<2x5x1x1xf32>
// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[SHIFT_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]], %[[SHIFT_0]] : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x1x1xf32>
// CHECK: return %[[VAL_4]] : tensor<2x5x1x1xf32>
}

Expand All @@ -26,7 +27,8 @@ return %0 : tensor<1x1x1x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.00101010106> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
// CHECK: %[[SHIFT_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]], %[[SHIFT_1]] : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x1x1x1xf32>
// CHECK: return %[[VAL_6]] : tensor<1x1x1x1xf32>
}

Expand All @@ -42,7 +44,8 @@ return %1 : tensor<2x5xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array<i64: 2, 5>} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5xf32>, tensor<1x1xf32>) -> tensor<2x5xf32>
// CHECK: %[[SHIFT_2:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]], %[[SHIFT_2]] : (tensor<2x5xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x5xf32>
// CHECK: return %[[VAL_5]] : tensor<2x5xf32>
}

Expand All @@ -57,7 +60,8 @@ return %1 : tensor<2x5x1x1xf32>
// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32>
// CHECK: %[[SHIFT_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]], %[[SHIFT_3]] : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x1x1xf32>
// CHECK: return %[[VAL_4]] : tensor<2x5x1x1xf32>
}

Expand All @@ -81,7 +85,8 @@ func.func @test_reducemeanV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x
// CHECK: [[VAR_0_:%.+]] = tosa.reduce_sum %arg0 {axis = 2 : i32}
// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_sum [[VAR_0_]] {axis = 3 : i32}
// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<7.97193861E-5> : tensor<1x1x1x1xf32>}>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_1_]], [[VAR_2_]] {shift = 0 : i8} : (tensor<1x32x1x1xf32>, tensor<1x1x1x1xf32>)
// CHECK: [[SHIFT_4_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_1_]], [[VAR_2_]], [[SHIFT_4_]] : (tensor<1x32x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>)
// CHECK: return [[VAR_3_]] : tensor<1x32x1x1xf32>
}

Expand Down
8 changes: 4 additions & 4 deletions test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ func.func @test_softmax_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 2 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x1xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], {{.*}}: (tensor<13x21x3xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}

// -----
Expand All @@ -23,7 +23,7 @@ func.func @test_softmax_v13_axis_one(%arg0: tensor<13x21x3xf32>) -> tensor<13x21
// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x1x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], {{.*}}: (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}

// -----
Expand All @@ -39,7 +39,7 @@ func.func @test_softmax_before_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3
// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32>
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<13x1x3xf32>) -> tensor<13x1x1xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<13x1x1xf32>) -> tensor<13x1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x1xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]], {{.*}}: (tensor<13x21x3xf32>, tensor<13x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}

// -----
Expand All @@ -57,5 +57,5 @@ func.func @test_softmax_before_v13_axis_zero(%arg0: tensor<13x21x3xf32>) -> tens
// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 1 : i32} : (tensor<1x21x3xf32>) -> tensor<1x1x3xf32>
// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], {{.*}}: (tensor<13x21x3xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
}
Loading
Loading