Skip to content

Commit 835c8db

Browse files
lhutton1Tai78641
authored andcommitted
[mlir][tosa] Add folder for multiply like reduce_prod operation
This commit uses mulBinaryFolder for reduce_prod operations that have a constant 1D input of two values. Change-Id: Icb234282c70898189083231506ed38a3ab40efb2 Signed-off-by: Luke Hutton <[email protected]>
1 parent 1c4e0f6 commit 835c8db

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,10 +949,38 @@ REDUCE_FOLDER(ReduceAllOp)
949949
REDUCE_FOLDER(ReduceAnyOp)
950950
REDUCE_FOLDER(ReduceMaxOp)
951951
REDUCE_FOLDER(ReduceMinOp)
952-
REDUCE_FOLDER(ReduceProductOp)
953952
REDUCE_FOLDER(ReduceSumOp)
954953
#undef REDUCE_FOLDER
955954

955+
OpFoldResult ReduceProductOp::fold(FoldAdaptor adaptor) {
956+
auto inputTy = llvm::cast<ShapedType>(getInput().getType());
957+
if (!inputTy.hasRank())
958+
return {};
959+
960+
if (inputTy == getType() &&
961+
(inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1))
962+
return getInput();
963+
964+
if (inputTy.getRank() != 1 || inputTy.getDimSize(0) != 2)
965+
return {};
966+
967+
// inputTy has shape { 2 } : try folding reduce_product using mulBinaryFolder
968+
const auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
969+
if (!resultTy)
970+
return {};
971+
972+
const auto elements =
973+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
974+
if (!elements)
975+
return {};
976+
977+
const auto lhsAttr =
978+
DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[0]});
979+
const auto rhsAttr =
980+
DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[1]});
981+
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, 0);
982+
}
983+
956984
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
957985
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
958986
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,3 +1034,36 @@ func.func @do_not_fold_int_div_division_by_0() -> tensor<1x24x2xi32> {
10341034
%16 = tosa.int_div %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
10351035
return %16 : tensor<1x24x2xi32>
10361036
}
1037+
1038+
// -----
1039+
1040+
// CHECK-LABEL: @fold_reduce_prod_is_mul
1041+
func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> {
1042+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32>
1043+
// CHECK: return %[[VAL_0]] : tensor<1xi32>
1044+
%0 = "tosa.const"() <{value = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32>
1045+
%1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
1046+
return %1 : tensor<1xi32>
1047+
}
1048+
1049+
// -----
1050+
1051+
// CHECK-LABEL: @no_fold_reduce_prod_rank_2
1052+
func.func @no_fold_reduce_prod_rank_2() -> tensor<1x1xi32> {
1053+
// check that reduce_product folding does not happen for input with rank > 1
1054+
// CHECK: tosa.reduce_product
1055+
%0 = "tosa.const"() <{value = dense<[[1, 77]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
1056+
%1 = "tosa.reduce_product"(%0) <{axis = 1 : i32}> : (tensor<1x2xi32>) -> tensor<1x1xi32>
1057+
return %1 : tensor<1x1xi32>
1058+
}
1059+
1060+
// -----
1061+
1062+
// CHECK-LABEL: @no_fold_reduce_prod_dim_3
1063+
func.func @no_fold_reduce_prod_dim_3() -> tensor<1xi32> {
1064+
// check that reduce_product folding does not happen for input with dim[0] != 2
1065+
// CHECK: tosa.reduce_product
1066+
%0 = "tosa.const"() <{value = dense<[1, 77, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
1067+
%1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<3xi32>) -> tensor<1xi32>
1068+
return %1 : tensor<1xi32>
1069+
}

0 commit comments

Comments
 (0)