
Commit 80e2abd

[BACKEND] Remove decomposition of splat -> shared conversion (#5450)
1 parent: a52c88a

4 files changed: 0 additions & 33 deletions

include/triton/Conversion/TritonGPUToLLVM/Patterns.h

Lines changed: 0 additions & 4 deletions
@@ -13,10 +13,6 @@ namespace triton::gpu {
 /// |module| op because the codegen doesn't handle `blocked -> dot_op` directly.
 void decomposeBlockedToDotLayoutConversion(ModuleOp module);

-/// Replaces `splat -> shared` with `splat -> blocked -> shared` in the given
-/// |module| op.
-void decomposeSplatOpToSharedLayoutConversion(ModuleOp module);
-
 /// Replaces `mma/mfma -> dot_op` with `mma/mfma -> blocked -> dot_op` in the
 /// given |module| op, but bypass the decomposition if |shortcutFn| returns
 /// true.

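For context on the removed API: the doc comment above states that the helper replaced `splat -> shared` with `splat -> blocked -> shared`. A hypothetical Triton GPU IR sketch of that rewrite (op syntax and encoding names are illustrative, not taken from this commit):

    // Before the decomposition: a splat materializing a shared-encoded tensor.
    %0 = tt.splat %x : f32 -> tensor<32x32xf32, #shared>

    // After the decomposition: splat into a blocked layout, then convert.
    %1 = tt.splat %x : f32 -> tensor<32x32xf32, #blocked>
    %2 = triton_gpu.convert_layout %1 : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #shared>

This commit deletes the rewrite, so a `splat -> shared` conversion is no longer split apart before lowering.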
lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 26 deletions
@@ -18,32 +18,6 @@ static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {

 namespace mlir::triton::gpu {

-void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) {
-  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
-  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
-  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
-  module.walk([&](triton::SplatOp splatOp) -> void {
-    auto dstType = cast<RankedTensorType>(splatOp.getType());
-    auto shared =
-        dyn_cast<triton::gpu::SharedEncodingAttr>(dstType.getEncoding());
-    if (shared) {
-      OpBuilder builder(splatOp);
-      SmallVector<unsigned, 4> sizePerThread(dstType.getRank(), 1);
-      auto newType = RankedTensorType::get(
-          dstType.getShape(), dstType.getElementType(),
-          triton::gpu::BlockedEncodingAttr::get(
-              module.getContext(), dstType.getShape(), sizePerThread,
-              getOrder(shared), numWarps, threadsPerWarp, numCTAs));
-      auto newSplat = builder.create<triton::SplatOp>(splatOp.getLoc(), newType,
-                                                      splatOp.getSrc());
-      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
-          splatOp.getLoc(), dstType, newSplat.getResult());
-      splatOp.replaceAllUsesWith(newConvert.getResult());
-      splatOp.erase();
-    }
-  });
-}
-
 void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
                                               ShortcutFn shortcutFn) {
   int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);

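As a worked example of the intermediate layout the removed walk built (attribute values are hypothetical; the actual distribution comes from `BlockedEncodingAttr::get`): per the code above, `sizePerThread` is 1 in every dimension, the order is copied from the shared encoding, and the warp/thread/CTA counts come from the module, so a 32x32 splat with 4 warps of 32 threads and 1 CTA might get encodings like:

    // Hypothetical attributes for a 32x32 tensor with order = [1, 0],
    // numWarps = 4, threadsPerWarp = 32, numCTAs = 1.
    #shared  = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
    #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32],
                                    warpsPerCTA = [4, 1], order = [1, 0]}>
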
third_party/amd/lib/TritonAMDGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 2 deletions
@@ -34,8 +34,6 @@ struct DecomposeUnsupportedAMDConversions
     int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

-    triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);
-
     auto isShortcut =
         mlir::triton::gpu::ShortcutFn(std::not_fn(cvtNeedsSharedMemory));

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ struct DecomposeUnsupportedConversions
     auto nvidiaShortCutFn = [&](RankedTensorType srcTy,
                                 RankedTensorType dstTy) { return true; };
     ModuleOp mod = getOperation();
-    triton::gpu::decomposeSplatOpToSharedLayoutConversion(mod);
     triton::gpu::decomposeTensorCoreToDotLayoutConversion(mod,
                                                           nvidiaShortCutFn);
     triton::gpu::decomposeBlockedToDotLayoutConversion(mod);
