30 changes: 29 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -902,10 +902,38 @@ REDUCE_FOLDER(ReduceAllOp)
REDUCE_FOLDER(ReduceAnyOp)
REDUCE_FOLDER(ReduceMaxOp)
REDUCE_FOLDER(ReduceMinOp)
-REDUCE_FOLDER(ReduceProductOp)
REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

+OpFoldResult ReduceProductOp::fold(FoldAdaptor adaptor) {
+  auto inputTy = llvm::cast<ShapedType>(getInput().getType());
+  if (!inputTy.hasRank())
+    return {};
Contributor: Could we add a newline after the return here to aid readability?

Contributor Author: done

+  if (inputTy == getType() &&
+      (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1))
+    return getInput();
+
+  if (inputTy.getRank() != 1 || inputTy.getDimSize(0) != 2)
+    return {};
+
+  // inputTy has shape {2}: try folding reduce_product using mulBinaryFolder.
+  const auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!resultTy)
+    return {};
+
+  const auto elements =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
+  if (!elements)
+    return {};
+
+  const auto lhsAttr =
+      DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[0]});
+  const auto rhsAttr =
+      DenseElementsAttr::get(resultTy, {elements.getValues<Attribute>()[1]});
+  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, 0);
+}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
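
For illustration (not part of the patch): a hedged before/after sketch of the new length-2 special case, using hypothetical constant values. The fold evaluates the product at fold time via mulBinaryFolder (with shift 0), so the reduce_product collapses to a constant.

// Before canonicalization (hypothetical values):
%cst = "tosa.const"() <{values = dense<[3, 5]> : tensor<2xi32>}> : () -> tensor<2xi32>
%prod = "tosa.reduce_product"(%cst) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
// After folding, %prod becomes a constant (3 * 5 = 15):
%prod = "tosa.const"() <{values = dense<15> : tensor<1xi32>}> : () -> tensor<1xi32>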
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1040,3 +1040,36 @@ func.func @do_not_fold_int_div_division_by_0() -> tensor<1x24x2xi32> {
  %16 = tosa.int_div %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
  return %16 : tensor<1x24x2xi32>
}

+// -----
+
+// CHECK-LABEL: @fold_reduce_prod_is_mul
+func.func @fold_reduce_prod_is_mul() -> tensor<1xi32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<77> : tensor<1xi32>}> : () -> tensor<1xi32>
+  // CHECK: return %[[VAL_0]] : tensor<1xi32>
+  %0 = "tosa.const"() <{values = dense<[1, 77]> : tensor<2xi32>}> : () -> tensor<2xi32>
+  %1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<2xi32>) -> tensor<1xi32>
+  return %1 : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_reduce_prod_rank_2
+func.func @no_fold_reduce_prod_rank_2() -> tensor<1x1xi32> {
+  // Check that reduce_product folding does not happen for an input with rank > 1.
+  // CHECK: tosa.reduce_product
+  %0 = "tosa.const"() <{values = dense<[[1, 77]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
+  %1 = "tosa.reduce_product"(%0) <{axis = 1 : i32}> : (tensor<1x2xi32>) -> tensor<1x1xi32>
+  return %1 : tensor<1x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_reduce_prod_dim_3
+func.func @no_fold_reduce_prod_dim_3() -> tensor<1xi32> {
+  // Check that reduce_product folding does not happen for an input with dim[0] != 2.
+  // CHECK: tosa.reduce_product
+  %0 = "tosa.const"() <{values = dense<[1, 77, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
+  %1 = "tosa.reduce_product"(%0) <{axis = 0 : i32}> : (tensor<3xi32>) -> tensor<1xi32>
+  return %1 : tensor<1xi32>
+}