From a8fd688428cc8751b658dcbc56726d01e9338ced Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 4 Nov 2024 08:54:36 +0000 Subject: [PATCH] [mlir][tosa] Add folder for multiply like reduce_prod operation This commit uses mulBinaryFolder for reduce_prod operations that have a constant 1D input of two values. Change-Id: Icb234282c70898189083231506ed38a3ab40efb2 Signed-off-by: Luke Hutton --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 30 ++++++++++++++++- mlir/test/Dialect/Tosa/canonicalize.mlir | 33 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 3e99c1f717d09..05c90a3371bb5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -902,10 +902,38 @@ REDUCE_FOLDER(ReduceAllOp) REDUCE_FOLDER(ReduceAnyOp) REDUCE_FOLDER(ReduceMaxOp) REDUCE_FOLDER(ReduceMinOp) -REDUCE_FOLDER(ReduceProductOp) REDUCE_FOLDER(ReduceSumOp) #undef REDUCE_FOLDER +OpFoldResult ReduceProductOp::fold(FoldAdaptor adaptor) { + auto inputTy = llvm::cast(getInput().getType()); + if (!inputTy.hasRank()) + return {}; + + if (inputTy == getType() && + (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1)) + return getInput(); + + if (inputTy.getRank() != 1 || inputTy.getDimSize(0) != 2) + return {}; + + // inputTy has shape { 2 } : try folding reduce_product using mulBinaryFolder + const auto resultTy = llvm::dyn_cast(getType()); + if (!resultTy) + return {}; + + const auto elements = + llvm::dyn_cast_if_present(adaptor.getInput()); + if (!elements) + return {}; + + const auto lhsAttr = + DenseElementsAttr::get(resultTy, {elements.getValues()[0]}); + const auto rhsAttr = + DenseElementsAttr::get(resultTy, {elements.getValues()[1]}); + return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, 0); +} + OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { auto inputTy = llvm::dyn_cast(getInput1().getType()); auto outputTy = llvm::dyn_cast(getType()); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 4242f68609634..c1537c1267042 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1040,3 +1040,36 @@ func.func @do_not_fold_int_div_division_by_0() -> tensor<1x24x2xi32> { %16 = tosa.int_div %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32> return %16 : tensor<1x24x2xi32> } + +// ----- + +// CHECK-LABEL: @fold_reduce_prod_is_mul +func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> { + // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32> + // CHECK: return %[[VAL_0]] : tensor<1xi32> + %0 = "tosa.const"() <{values = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32> + %1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32> + return %1 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: @no_fold_reduce_prod_rank_2 +func.func @no_fold_reduce_prod_rank_2() -> tensor<1x1xi32> { + // check that reduce_product folding does not happen for input with rank > 1 + // CHECK: tosa.reduce_product + %0 = "tosa.const"() <{values = dense<[[1, 77]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32> + %1 = "tosa.reduce_product"(%0) <{axis = 1 : i32}> : (tensor<1x2xi32>) -> tensor<1x1xi32> + return %1 : tensor<1x1xi32> +} + +// ----- + +// CHECK-LABEL: @no_fold_reduce_prod_dim_3 +func.func @no_fold_reduce_prod_dim_3() -> tensor<1xi32> { + // check that reduce_product folding does not happen for input with dim[0] != 2 + // CHECK: tosa.reduce_product + %0 = "tosa.const"() <{values = dense<[1, 77, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<3xi32>) -> tensor<1xi32> + return %1 : tensor<1xi32> +}