Skip to content

Commit b6c4829

Browse files
authored
[Triton] Verify all tt.reduce operands have the same shape (#4957)
Add `SameOperandsShape` to `tt.reduce` to verify all operands have the same shape. This matches `triton.language.reduce` (and similar) semantics. This change may enable further optimizations and even may help simplify the code dealing with this operation. Followup PRs will tackle this. The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [X] I am not making a trivial change, such as fixing a typo in a comment. - [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [X] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [X] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) Signed-off-by: victor-eds <[email protected]>
1 parent 4ddebd2 commit b6c4829

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
710710
//
711711
def TT_ReduceOp: TT_Op<"reduce",
712712
[Pure,
713+
SameOperandsShape,
713714
SameOperandsEncoding,
714715
SingleBlock,
715716
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {

test/Triton/invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,19 @@ tt.func public @fn(%v: tensor<4x128xf64>) {
108108

109109
// -----
110110

111+
tt.func @reduce_different_input_shapes(%arg0: tensor<32x32x64xf32>, %arg1: tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>) {
112+
// expected-error @below {{op requires the same shape for all operands}}
113+
%0:2 = "tt.reduce" (%arg0, %arg1) <{axis = 1 : i32}> ({
114+
^bb0(%acc0: f32, %acc1: f32, %cur0: f32, %cur1: f32):
115+
%1 = arith.addf %acc0, %cur0 : f32
116+
%2 = arith.addf %acc1, %cur1 : f32
117+
tt.reduce.return %1, %2 : f32, f32
118+
}) : (tensor<32x32x64xf32>, tensor<16x32x64xf32>) -> (tensor<32x64xf32>, tensor<16x64xf32>)
119+
tt.return %0#0, %0#1 : tensor<32x64xf32>, tensor<16x64xf32>
120+
}
121+
122+
// -----
123+
111124
tt.func public @fn(%v: tensor<4x128xf32>) {
112125
// expected-error @+1 {{requires the same shape}}
113126
%a = "tt.scan" (%v) ({

0 commit comments

Comments
 (0)