Skip to content

Commit 6aada16

Browse files
committed
Address review comments
1 parent 45f1dbb commit 6aada16

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -502,12 +502,14 @@ namespace {
502502
/// reduction dimension:
503503
/// ```mlir
504504
/// "tt.reduce"(%0, ...) <{axis = N}> ({...})
505-
/// : (tensor<S0x...xSN-1x1xSN+1x...>, ...) -> tensor<S0x...xSN-1xSN+1x...>
505+
/// : (tensor<S0 x ... x SN-1 x 1 x SN+1 x ...>, ...) ->
506+
/// (tensor<S0 x ... x SN-1 x SN+1 x ...>, ...)
506507
/// ```
507508
/// With equivalent reshape operations (one per operand):
508509
/// ```mlir
509510
/// tt.reshape %0 allow_reorder
510-
/// : tensor<S0x...xSN-1x1xSN+1x...> -> tensor<S0x...xSN-1xSN+1x...>
511+
/// : tensor<S0 x ... x SN-1 x 1 x SN+1 x ...> ->
512+
/// tensor<S0 x ... x SN-1 x SN+1 x ...>
511513
/// ```
512514
struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern<ReduceOp> {
513515
using OpRewritePattern<ReduceOp>::OpRewritePattern;

test/Triton/canonicalize.mlir

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,13 +57,14 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
5757
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x16xf32>,
5858
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x1x16xf16>
5959
tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) {
60-
// CHECK: tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
61-
// CHECK: tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
60+
// CHECK: %[[VAL0:.*]] = tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
61+
// CHECK: %[[VAL1:.*]] = tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
6262
%0:2 = "tt.reduce"(%arg0, %arg1) <{axis=1 : i32}> ({
6363
^bb0(%acc0: f32, %acc1: f16, %curr0: f32, %curr1: f16):
6464
%1 = arith.addf %acc0, %curr0 : f32
6565
%2 = arith.mulf %acc1, %curr1 : f16
6666
tt.reduce.return %1, %2 : f32, f16
6767
}) : (tensor<2x1x16xf32>, tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>)
68+
// CHECK: tt.return %[[VAL0]], %[[VAL1]] : tensor<2x16xf32>, tensor<2x16xf16>
6869
tt.return %0#0, %0#1 : tensor<2x16xf32>, tensor<2x16xf16>
6970
}

0 commit comments

Comments (0)