Skip to content

Commit 62b2409

Browse files
lhutton1Tai78641
authored andcommitted
[mlir][TOSA] Add folder for multiply like reduce_prod operation
This commit uses mulBinaryFolder for reduce_prod operations that have a constant 1D input of two values. Change-Id: Icb234282c70898189083231506ed38a3ab40efb2 Signed-off-by: Luke Hutton <[email protected]>
1 parent d1dde17 commit 62b2409

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,10 +963,38 @@ REDUCE_FOLDER(ReduceAllOp)
963963
REDUCE_FOLDER(ReduceAnyOp)
964964
REDUCE_FOLDER(ReduceMaxOp)
965965
REDUCE_FOLDER(ReduceMinOp)
966-
REDUCE_FOLDER(ReduceProdOp)
967966
REDUCE_FOLDER(ReduceSumOp)
968967
#undef REDUCE_FOLDER
969968

969+
OpFoldResult ReduceProdOp::fold(FoldAdaptor adaptor) {
970+
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());
971+
if (!inputTy.hasRank())
972+
return {};
973+
if (inputTy == getType() &&
974+
(inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1))
975+
return getInput();
976+
977+
// Fold multiply like reduce_prod operators using mulBinaryFolder
978+
if (inputTy.getRank() == 1 && inputTy.getDimSize(0) == 2) {
979+
const auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
980+
if (!resultTy)
981+
return {};
982+
983+
const auto elements =
984+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
985+
if (!elements)
986+
return {};
987+
988+
const auto lhsAttr =
989+
DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[0]});
990+
const auto rhsAttr =
991+
DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[1]});
992+
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, 0);
993+
}
994+
995+
return {};
996+
}
997+
970998
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
971999
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
9721000
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,3 +1012,14 @@ func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> {
10121012
%2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32>
10131013
return %2 : tensor<3x600x1200xi32>
10141014
}
1015+
1016+
// -----
1017+
1018+
// CHECK-LABEL: @fold_reduce_prod_is_mul
1019+
func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> {
1020+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32>
1021+
// CHECK: return %[[VAL_0]] : tensor<1xi32>
1022+
%0 = "tosa.const"() <{value = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32>
1023+
%1 = "tosa.reduce_prod"(%0) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
1024+
return %1 : tensor<1xi32>
1025+
}

0 commit comments

Comments
 (0)