Skip to content

Commit b2026c2

Browse files
lhutton1Tai78641
authored andcommitted
[mlir][tosa] Add folder for multiply like reduce_prod operation
This commit uses mulBinaryFolder for reduce_prod operations that have a constant 1D input of two values. Change-Id: Icb234282c70898189083231506ed38a3ab40efb2 Signed-off-by: Luke Hutton <[email protected]>
1 parent b19ed9c commit b2026c2

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,10 +902,38 @@ REDUCE_FOLDER(ReduceAllOp)
902902
REDUCE_FOLDER(ReduceAnyOp)
903903
REDUCE_FOLDER(ReduceMaxOp)
904904
REDUCE_FOLDER(ReduceMinOp)
905-
REDUCE_FOLDER(ReduceProductOp)
906905
REDUCE_FOLDER(ReduceSumOp)
907906
#undef REDUCE_FOLDER
908907

908+
OpFoldResult ReduceProductOp::fold(FoldAdaptor adaptor) {
909+
auto inputTy = llvm::cast<ShapedType>(getInput().getType());
910+
if (!inputTy.hasRank())
911+
return {};
912+
913+
if (inputTy == getType() &&
914+
(inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1))
915+
return getInput();
916+
917+
if (inputTy.getRank() != 1 || inputTy.getDimSize(0) != 2)
918+
return {};
919+
920+
// inputTy has shape { 2 } : try folding reduce_product using mulBinaryFolder
921+
const auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
922+
if (!resultTy)
923+
return {};
924+
925+
const auto elements =
926+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
927+
if (!elements)
928+
return {};
929+
930+
const auto lhsAttr =
931+
DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[0]});
932+
const auto rhsAttr =
933+
DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[1]});
934+
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, 0);
935+
}
936+
909937
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
910938
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
911939
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,3 +1040,36 @@ func.func @do_not_fold_int_div_division_by_0() -> tensor<1x24x2xi32> {
10401040
%16 = tosa.int_div %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
10411041
return %16 : tensor<1x24x2xi32>
10421042
}
1043+
1044+
// -----
1045+
1046+
// CHECK-LABEL: @fold_reduce_prod_is_mul
1047+
func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> {
1048+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32>
1049+
// CHECK: return %[[VAL_0]] : tensor<1xi32>
1050+
%0 = "tosa.const"() <{values = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32>
1051+
%1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
1052+
return %1 : tensor<1xi32>
1053+
}
1054+
1055+
// -----
1056+
1057+
// CHECK-LABEL: @no_fold_reduce_prod_rank_2
1058+
func.func @no_fold_reduce_prod_rank_2() -> tensor<1x1xi32> {
1059+
// check that reduce_product folding does not happen for input with rank > 1
1060+
// CHECK: tosa.reduce_product
1061+
%0 = "tosa.const"() <{values = dense<[[1, 77]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
1062+
%1 = "tosa.reduce_product"(%0) <{axis = 1 : i32}> : (tensor<1x2xi32>) -> tensor<1x1xi32>
1063+
return %1 : tensor<1x1xi32>
1064+
}
1065+
1066+
// -----
1067+
1068+
// CHECK-LABEL: @no_fold_reduce_prod_dim_3
1069+
func.func @no_fold_reduce_prod_dim_3() -> tensor<1xi32> {
1070+
// check that reduce_product folding does not happen for input with dim[0] != 2
1071+
// CHECK: tosa.reduce_product
1072+
%0 = "tosa.const"() <{values = dense<[1, 77, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
1073+
%1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<3xi32>) -> tensor<1xi32>
1074+
return %1 : tensor<1xi32>
1075+
}

0 commit comments

Comments
 (0)