diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 3e99c1f717d09..1be5697955020 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -501,15 +501,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, // Operator Folders. //===----------------------------------------------------------------------===// -template +template DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType returnTy) { - if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - auto lETy = llvm::cast(lhs.getType()).getElementType(); - auto rETy = llvm::cast(rhs.getType()).getElementType(); - if (lETy != rETy) - return {}; + if (!rhs || !lhs) + return {}; + + auto lETy = llvm::cast(lhs.getType()).getElementType(); + auto rETy = llvm::cast(rhs.getType()).getElementType(); + if (lETy != rETy) + return {}; + + if (!lETy.isIntOrFloat()) + return {}; + if (rhs.isSplat() && lhs.isSplat()) { if (llvm::isa(lETy)) { APInt l = lhs.getSplatValue(); APInt r = rhs.getSplatValue(); @@ -525,9 +531,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, } } + auto lhsCount = lhs.getNumElements(); + auto rhsCount = rhs.getNumElements(); + if (lhsCount != rhsCount) + return {}; + + // to prevent long compile time, skip if too many elements + if (lhsCount > 128) + return {}; + + if (llvm::isa(lETy)) { + auto lvalues = lhs.getValues(); + auto rvalues = rhs.getValues(); + SmallVector results; + IntFolder intFolder{}; + for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) { + auto result = intFolder(l, r); + results.push_back(result); + } + return DenseElementsAttr::get(returnTy, results); + } + + if (llvm::isa(lETy)) { + auto lvalues = lhs.getValues(); + auto rvalues = rhs.getValues(); + // FloatFolder() may return either APFloat or APInt (comparison functions) + SmallVector results; + FloatFolder floatFolder{}; + for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) { + auto result = floatFolder(l, r); + results.push_back(result); + } + return DenseElementsAttr::get(returnTy, results); + } + return {}; } +template +DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs, + DenseElementsAttr rhs, + RankedTensorType returnTy) { + // comparison FloatFolder() functions return APInt values + return binaryFolder(lhs, rhs, returnTy); +} + +template +DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs, + DenseElementsAttr rhs, + RankedTensorType returnTy) { + // arithmetic FloatFolder() functions return APFloat values + return binaryFolder(lhs, rhs, returnTy); +} + static bool isSplatZero(Type elemType, DenseElementsAttr val) { if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); @@ -574,8 +630,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder, std::plus>(lhsAttr, rhsAttr, - resultTy); + return arithmeticBinaryFolder, std::plus>( + lhsAttr, rhsAttr, resultTy); } OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { @@ -632,32 +688,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { } namespace { + +// calculate lhs * rhs >> shift according to TOSA Spec +// return nullopt if result is not in range of int32_t when shift > 0 +std::optional mulInt(APInt lhs, APInt rhs, int32_t shift, + unsigned bitwidth) { + APInt result = lhs.sext(64) * rhs.sext(64); + + if (shift > 0) { + auto round = APInt(64, 1) << (shift - 1); + result += round; + result.ashrInPlace(shift); + // REQUIRE(product >= minimum_s() && product <= maximum_s()) + if (!(result.getSExtValue() >= INT32_MIN && + result.getSExtValue() <= INT32_MAX)) { + // REQUIRE failed + return std::nullopt; + } + } + + return result.trunc(bitwidth); +} + DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, RankedTensorType ty, int32_t shift) { - if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - if (llvm::isa(ty.getElementType())) { - APInt l = lhs.getSplatValue(); - APInt r = rhs.getSplatValue(); + if (!lhs || !rhs) + return {}; + + // REQUIRE(0 <= shift && shift <= 63); + if (!(0 <= shift && shift <= 63)) + return {}; + + auto elementType = ty.getElementType(); + if (!elementType.isIntOrFloat()) + return {}; - if (shift == 0) { - return DenseElementsAttr::get(ty, l * r); + unsigned bitwidth = elementType.getIntOrFloatBitWidth(); + // REQUIRE(in_t == int32_t || shift == 0); + if (!((llvm::isa(elementType) && bitwidth == 32) || shift == 0)) + return {}; + + if (rhs.isSplat() && lhs.isSplat()) { + if (llvm::isa(elementType)) { + auto l = lhs.getSplatValue(); + auto r = rhs.getSplatValue(); + + if (auto result = mulInt(l, r, shift, bitwidth)) { + return DenseElementsAttr::get(ty, result.value()); } + // mulInt failed + return {}; + } - auto bitwidth = ty.getElementType().getIntOrFloatBitWidth(); - l = l.sext(bitwidth * 2); - r = r.sext(bitwidth * 2); + if (llvm::isa(elementType)) { + auto l = lhs.getSplatValue(); + auto r = rhs.getSplatValue(); auto result = l * r; - result.lshrInPlace(shift); - result = result.trunc(bitwidth); return DenseElementsAttr::get(ty, result); } + } - if (llvm::isa(ty.getElementType())) { - APFloat l = lhs.getSplatValue(); - APFloat r = rhs.getSplatValue(); - APFloat result = l * r; - return DenseElementsAttr::get(ty, result); + if (llvm::isa(elementType)) { + auto lvalues = lhs.getValues(); + auto rvalues = rhs.getValues(); + if (lvalues.size() != rvalues.size()) { + return {}; + } + SmallVector results; + for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) { + if (auto result = mulInt(l, r, shift, bitwidth)) { + results.push_back(result.value()); + continue; + } + // mulInt failed + return {}; + } + return DenseElementsAttr::get(ty, results); + } + + if (llvm::isa(elementType)) { + auto lvalues = lhs.getValues(); + auto rvalues = rhs.getValues(); + if (lvalues.size() != rvalues.size()) { + return {}; } + SmallVector results; + for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) { + auto result = l * r; + results.push_back(result); + } + return DenseElementsAttr::get(ty, results); } return {}; @@ -732,8 +852,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder, std::minus>(lhsAttr, rhsAttr, - resultTy); + return arithmeticBinaryFolder, std::minus>( + lhsAttr, rhsAttr, resultTy); } namespace { @@ -774,7 +894,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder>>( + return comparisonBinaryFolder>>( lhsAttr, rhsAttr, resultTy); } @@ -788,8 +909,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder>>( + return comparisonBinaryFolder>>( lhsAttr, rhsAttr, resultTy); } @@ -813,9 +934,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder>, - ComparisonFold>>(lhsAttr, rhsAttr, - resultTy); + return comparisonBinaryFolder>, + ComparisonFold>>( + lhsAttr, rhsAttr, resultTy); } OpFoldResult CastOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 8ac1e177ae4d4..2f396519b4420 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -1082,11 +1082,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> { func.func @reduce_sum_constant() -> tensor<1x3xi32> { // CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> { - // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32> - // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32> - // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32> - // CHECK: return %[[VAL_3]] : tensor<1x3xi32> + // CHECK: %[[K:.*]] = "tosa.const"() <{values = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32> + // CHECK: return %[[K]] : tensor<1x3xi32> %arg0 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32> %arg1 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32> %arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index 9b6ccdb54c107..8ae8af75c3856 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt --test-constant-fold %s | FileCheck %s +// RUN: mlir-opt --split-input-file --canonicalize --test-constant-fold %s | FileCheck %s + +// ----- // CHECK-LABEL: func @test_const func.func @test_const(%arg0 : index) -> tensor<4xi32> { @@ -7,6 +9,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> { return %0 : tensor<4xi32> } +// ----- + // CHECK-LABEL: func @test_const_i64 func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { // CHECK: tosa.const @@ -14,10 +18,218 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> { return %0 : tensor<4xi64> } +// ----- + // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor -func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) { +func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<*xi1> { // CHECK: tosa.equal // CHECK-NEXT: return %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1> - return + return %0 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: test_mul_i32 +func.func @test_mul_i32() -> tensor<4xi32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[9, 36, 36, 81]> : tensor<4xi32>}> + // CHECK: return %[[VAL]] + %lhs = "tosa.const"() {values = dense<[1, 2, -2, -3]> : tensor<4xi32>} : () -> tensor<4xi32> + %rhs = "tosa.const"() {values = dense<3> : tensor<4xi32>} : () -> tensor<4xi32> + %shift = "tosa.const"() { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + + return %result : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: test_mul_i32_shift +func.func @test_mul_i32_shift() -> tensor<4xi32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[2550, 8100, 2, 2025]> : tensor<4xi32>}> + // CHECK: return %[[VAL]] + %lhs = "tosa.const"() {values = dense<[135, 240, -4, -120]> : tensor<4xi32>} : () -> tensor<4xi32> + %rhs = "tosa.const"() {values = dense<3> : tensor<4xi32>} : () -> tensor<4xi32> + %shift = "tosa.const"() { values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8> + %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + return %result : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: test_mul_f32 +func.func @test_mul_f32() -> tensor<4xf32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[2.304000e+01, 58.9824028, 1.6384002, 14.7456007]> : tensor<4xf32>}> + // CHECK: return %[[VAL]] + %lhs = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %rhs = "tosa.const"() {values = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32> + %shift = "tosa.const"() { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> + %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32> + %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32> + %result = tosa.mul %x, %y, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32> + return %result : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: test_add_f32 +func.func @test_add_f32() -> tensor<4xf32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[7.500000e+00, 9.300000e+00, 3.69999981, 2.100000e+00]> : tensor<4xf32>}> + // CHECK: return %[[VAL]] + %cst = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %splat1 = "tosa.const"() {values = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32> + %splat2 = "tosa.const"() {values = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32> + %x = tosa.add %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %y = tosa.add %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %result = tosa.add %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %result : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: test_add_i32 +func.func @test_add_i32() -> tensor<4xi32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[75, 93, 37, 21]> : tensor<4xi32>}> + // CHECK: return %[[VAL]] + %cst = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32> + %splat1 = "tosa.const"() {values = dense<32> : tensor<4xi32>} : () -> tensor<4xi32> + %splat2 = "tosa.const"() {values = dense<13> : tensor<4xi32>} : () -> tensor<4xi32> + %x = tosa.add %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %y = tosa.add %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %result = tosa.add %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %result : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: test_sub_f32 +func.func @test_sub_f32() -> tensor<4xf32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[-1.500000e+00, 0.300000191, -5.300000e+00, -6.900000e+00]> : tensor<4xf32>}> + // CHECK: return %[[VAL]] + %cst = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %splat1 = "tosa.const"() {values = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32> + %splat2 = "tosa.const"() {values = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32> + %x = tosa.sub %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %y = tosa.sub %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %result = tosa.sub %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %result : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: test_sub_i32 +func.func @test_sub_i32() -> tensor<4xi32> { + // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[-15, 3, -53, -69]> : tensor<4xi32>}> + // CHECK: return %[[VAL]] + %cst = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32> + %splat1 = "tosa.const"() {values = dense<32> : tensor<4xi32>} : () -> tensor<4xi32> + %splat2 = "tosa.const"() {values = dense<13> : tensor<4xi32>} : () -> tensor<4xi32> + %x = tosa.sub %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %y = tosa.sub %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %result = tosa.sub %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %result : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: test_greater_f32 +func.func @test_greater_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { + // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] + %cst1 = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %splat = "tosa.const"() {values = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32> + %cst2 = "tosa.const"() {values = dense<[1.7, 2.3, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32> + %x = tosa.greater %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %y = tosa.greater %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %z = tosa.greater %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: test_greater_i32 +func.func @test_greater_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { + // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] + %cst1 = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32> + %cst2 = "tosa.const"() {values = dense<[17, 23, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32> + %splat = "tosa.const"() {values = dense<15> : tensor<4xi32>} : () -> tensor<4xi32> + %x = tosa.greater %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %y = tosa.greater %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %z = tosa.greater %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: test_greater_equal_f32 +func.func @test_greater_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { + // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[true, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] + %cst1 = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %splat = "tosa.const"() {values = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32> + %cst2 = "tosa.const"() {values = dense<[1.4, 2.4, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32> + %x = tosa.greater_equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %y = tosa.greater_equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: test_greater_equal_i32 +func.func @test_greater_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { + // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] + %cst1 = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32> + %splat = "tosa.const"() {values = dense<16> : tensor<4xi32>} : () -> tensor<4xi32> + %cst2 = "tosa.const"() {values = dense<[14, 24, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32> + %x = tosa.greater_equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %y = tosa.greater_equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: test_equal_f32 +func.func @test_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { + // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] + %cst1 = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %splat = "tosa.const"() {values = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32> + %cst2 = "tosa.const"() {values = dense<[1.4, 2.4, -0.5, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32> + %x = tosa.equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %y = tosa.equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %z = tosa.equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: test_equal_i32 +func.func @test_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { + // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1> + // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]] + %cst1 = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32> + %splat = "tosa.const"() {values = dense<15> : tensor<4xi32>} : () -> tensor<4xi32> + %cst2 = "tosa.const"() {values = dense<[14, 24, -5, -12]> : tensor<4xi32>} : () -> tensor<4xi32> + %x = tosa.equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %y = tosa.equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %z = tosa.equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1> }