Skip to content

Commit 53e3e6a

Browse files
authored
[AMD] Added a canonicalizer to ConcatOp (#7273)
Added a canonicalization pattern to `concatOp`. The pattern removes a `concatOp` if it has a single input operand. This scenario can potentially happen as a result of ops refinement. A corresponding lit-test is included
1 parent 7f94609 commit 53e3e6a

File tree

3 files changed

+33
-0
lines changed

3 files changed

+33
-0
lines changed

test/TritonGPU/amd/amd-canonicalize-extract-slice.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,16 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
2222
tt.return %2 : tensor<32x64xf32, #blocked>
2323
}
2424
}
25+
26+
// -----
27+
28+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
29+
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
30+
tt.func @canonicalize_singleton_concat(%arg0: tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked> {
31+
// CHECK-LABEL: tt.func @canonicalize_singleton_concat
32+
33+
%1 = amdgpu.concat %arg0: tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #blocked>
34+
// CHECK: tt.return %arg0 : tensor<128x128xf32, #blocked>
35+
tt.return %1 : tensor<128x128xf32, #blocked>
36+
}
37+
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def ConcatOp : TT_AMDGPU_Op<"concat", [Pure]> {
234234
}];
235235

236236
let hasVerifier = 1;
237+
let hasCanonicalizer = 1;
237238
}
238239

239240
//===----------------------------------------------------------------------===//

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,4 +446,23 @@ LogicalResult ConcatOp::verify() {
446446

447447
return success();
448448
}
449+
450+
// This pattern removes a concatOp if it has a single input operand.
451+
// This scenario can potentially happen as a result of ops refinement.
452+
mlir::LogicalResult foldConcatOpFromSingleSource(amdgpu::ConcatOp op,
453+
PatternRewriter &rewriter) {
454+
auto sources = op.getSources();
455+
if (sources.size() == 1) {
456+
auto source = sources.front();
457+
auto result = op.getResult();
458+
result.replaceAllUsesWith(source);
459+
return success();
460+
}
461+
return failure();
462+
}
463+
464+
void ConcatOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
465+
mlir::MLIRContext *context) {
466+
patterns.add(foldConcatOpFromSingleSource);
467+
}
449468
} // namespace mlir::triton::amdgpu

0 commit comments

Comments
 (0)