Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ def TT_ReduceOp: TT_Op<"reduce",
];
let hasVerifier = 1;
let hasRegionVerifier = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
llvm::SmallVector<RankedTensorType> getInputTypes();
llvm::SmallVector<Type> getElementTypes();
Expand Down
51 changes: 51 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,57 @@ LogicalResult ReduceOp::verifyRegions() {
return verifyRegionsImpl<ReduceReturnOp>(*this);
}

namespace {
/// Canonicalization pattern: turn a `tt.reduce` over a unit-sized axis into
/// reshapes.
///
/// A reduction whose input tensors have size 1 along the reduction dimension:
/// ```mlir
/// "tt.reduce"(%0, ...) <{axis = N}> ({...})
///     : (tensor<S0x...xSN-1x1xSN+1x...>, ...) -> tensor<S0x...xSN-1xSN+1x...>
/// ```
/// is replaced by one `tt.reshape` per operand, each dropping that dimension:
/// ```mlir
/// tt.reshape %0 allow_reorder
///     : tensor<S0x...xSN-1x1xSN+1x...> -> tensor<S0x...xSN-1xSN+1x...>
/// ```
struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern<ReduceOp> {
  using OpRewritePattern<ReduceOp>::OpRewritePattern;

  /// Matches `reduceOp` when its first result is a ranked tensor and its first
  /// input has extent 1 along the reduction axis; on match, replaces the op
  /// with one `tt.reshape` per source operand. Returns `failure()` otherwise.
  LogicalResult matchAndRewrite(ReduceOp reduceOp,
                                PatternRewriter &rewriter) const final {
    // `tensor<NxTy> -> Ty` case: `tt.reshape` does not support scalar result
    // types, so such reductions are left untouched.
    Type firstResultType = reduceOp->getResultTypes().front();
    if (!isa<RankedTensorType>(firstResultType))
      return failure();

    // Only a reduction axis of extent 1 qualifies for the rewrite.
    RankedTensorType firstInputType = reduceOp.getInputTypes().front();
    int32_t axis = reduceOp.getAxis();
    if (firstInputType.getShape()[axis] != 1)
      return failure();

    // Build one reshape per (source, result type) pair, in operand order.
    auto loc = reduceOp.getLoc();
    SmallVector<Value> replacements;
    replacements.reserve(reduceOp.getNumOperands());
    for (auto [src, dstType] :
         llvm::zip_equal(reduceOp.getSrcs(), reduceOp->getResultTypes())) {
      Value reshaped = rewriter.create<ReshapeOp>(loc, dstType, src,
                                                  /*allow_reorder=*/true);
      replacements.push_back(reshaped);
    }

    rewriter.replaceOp(reduceOp, replacements);
    return success();
  }
};
} // namespace

/// Populates `results` with the canonicalization patterns of `ReduceOp`.
///
/// Currently this adds a single pattern, `CanonicalizeReshapeReduceOpPattern`,
/// constructed with `context`.
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeReshapeReduceOpPattern>(context);
}

/// Returns the ranked tensor types produced by `getInputTypesImpl` for this
/// op's operands.
/// NOTE(review): `getInputTypesImpl` is defined outside this view; presumably
/// it yields one type per operand in operand order -- confirm.
llvm::SmallVector<RankedTensorType> ReduceOp::getInputTypes() {
return getInputTypesImpl(this->getOperands());
}
Expand Down
17 changes: 17 additions & 0 deletions test/Triton/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,20 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
tt.return %b : tensor<32x1xf32, #blocked0>
}
} // end module

// -----

// A reduction along an axis of extent 1 must canonicalize into one reshape
// per operand: the reduce op itself has to disappear and the function must
// return the reshape results.
// CHECK-LABEL: tt.func @reduce(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x16xf32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x1x16xf16>
tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) {
  // CHECK: %[[RES0:.*]] = tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
  // CHECK: %[[RES1:.*]] = tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
  // CHECK-NOT: tt.reduce
  // CHECK: tt.return %[[RES0]], %[[RES1]] : tensor<2x16xf32>, tensor<2x16xf16>
  %0:2 = "tt.reduce"(%arg0, %arg1) <{axis=1 : i32}> ({
  ^bb0(%acc0: f32, %acc1: f16, %curr0: f32, %curr1: f16):
    %1 = arith.addf %acc0, %curr0 : f32
    %2 = arith.mulf %acc1, %curr1 : f16
    tt.reduce.return %1, %2 : f32, f16
  }) : (tensor<2x1x16xf32>, tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>)
  tt.return %0#0, %0#1 : tensor<2x16xf32>, tensor<2x16xf16>
}