Commit 8bf8f42

masahi and Mogball authored
[WS] Reimplement PartitionLoops pass supporting general control flow (triton-lang#7415)
`PartitionLoops` is now implemented using mutual recursion between the partitioning of for ops, if ops, and so on, an idea pioneered by the Meta WS branch. It should be ready for nested loops, although that has not been tested. If we agree that this work is on the right track, I'll manually create a lit test for nested loops.

The code is mostly ported from [the code split implementation in the NVWS branch](https://github.com/triton-lang/triton/blob/aref_auto_ws/third_party/nvidia/lib/Dialect/NVWS/Transforms/ArefCodeSplit.cpp). As discussed between NV and @jeffniu-openai, we now emit the `nvws.warp_group` op as the result of code splitting. To make the new implementation a drop-in replacement for the existing one, the `nvws.warp_group` -> `ttg.warp_specialize` conversion pass by @mbrookhart is run immediately afterwards. Thus, all lit tests for `tritongpu-partition-loops` pass as is, and no e2e workload should break as a result of this refactoring.

Depending on the future direction, we may move the code splitting part earlier in the pipeline, or retire `nvws.warp_group` and just emit `ttg.warp_specialize` directly.

For now, I'm adding a dependency on the NVWS dialect to TritonGPU, so that `PartitionLoops.cpp` can use the `nvws.warp_group` op and `nvws-lower-warp-group`. I'll move them to TritonGPU if the new implementation looks good.

cc @3gx @htyu @manman-ren

---------

Co-authored-by: Jeff Niu <[email protected]>
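The mutual recursion mentioned above can be illustrated with a minimal sketch. It is not the code in `PartitionLoops.cpp`: it uses a hypothetical toy IR (`Op`, `Block`, `Kind`) instead of MLIR operations, and the names `splitBlock`, `splitFor`, `splitIf`, `render`, and `buildExample` are invented for illustration. It assumes each leaf op carries exactly one partition id, and that a control-flow op is cloned into a partition only if something nested inside it belongs to that partition.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy IR, purely illustrative: these are NOT Triton's or MLIR's data types.
enum class Kind { Leaf, For, If };

struct Op;
using Block = std::vector<Op>;

struct Op {
  Kind kind;
  std::string name;
  int partition;              // only meaningful for Leaf ops
  std::vector<Block> regions; // For: {body}; If: {then, else}
};

Block splitBlock(const Block &block, int partition);

// Clone a loop for `partition`, keeping only the nested ops that belong to it.
void splitFor(const Op &forOp, int partition, Block &out) {
  Block body = splitBlock(forOp.regions[0], partition);
  if (body.empty())
    return; // this partition does no work inside the loop
  Op clone{Kind::For, forOp.name, -1, {}};
  clone.regions.push_back(body);
  out.push_back(clone);
}

// Clone a conditional for `partition`, recursing into both branches.
void splitIf(const Op &ifOp, int partition, Block &out) {
  Block thenBlock = splitBlock(ifOp.regions[0], partition);
  Block elseBlock = splitBlock(ifOp.regions[1], partition);
  if (thenBlock.empty() && elseBlock.empty())
    return;
  Op clone{Kind::If, ifOp.name, -1, {}};
  clone.regions.push_back(thenBlock);
  clone.regions.push_back(elseBlock);
  out.push_back(clone);
}

// Walk a block and dispatch on the op kind. splitFor and splitIf call back
// into splitBlock, which is what makes the recursion mutual.
Block splitBlock(const Block &block, int partition) {
  Block result;
  for (const Op &op : block) {
    if (op.kind == Kind::For)
      splitFor(op, partition, result);
    else if (op.kind == Kind::If)
      splitIf(op, partition, result);
    else if (op.partition == partition)
      result.push_back(op);
  }
  return result;
}

// Render a block as text, e.g. "for{load} epilogue".
std::string render(const Block &block) {
  std::string s;
  for (const Op &op : block) {
    if (!s.empty())
      s += " ";
    s += op.name;
    if (op.kind == Kind::Leaf)
      continue;
    s += "{";
    for (std::size_t i = 0; i < op.regions.size(); ++i) {
      if (i > 0)
        s += "|";
      s += render(op.regions[i]);
    }
    s += "}";
  }
  return s;
}

// Hypothetical input: for { load(p0); if { mma(p1) } else {}; store(p1) }; epilogue(p0)
Block buildExample() {
  Op load{Kind::Leaf, "load", 0, {}};
  Op mma{Kind::Leaf, "mma", 1, {}};
  Op store{Kind::Leaf, "store", 1, {}};
  Op epilogue{Kind::Leaf, "epilogue", 0, {}};
  Op ifOp{Kind::If, "if", -1, {}};
  ifOp.regions.push_back(Block{mma}); // then branch
  ifOp.regions.push_back(Block{});    // else branch
  Op forOp{Kind::For, "for", -1, {}};
  forOp.regions.push_back(Block{load, ifOp, store});
  return Block{forOp, epilogue};
}
```

With this input, `render(splitBlock(buildExample(), 0))` yields `for{load} epilogue` and `render(splitBlock(buildExample(), 1))` yields `for{if{mma|} store}`. Because the recursion bottoms out in `splitBlock` at every nesting level, nested loops need no special handling in this sketch. The real pass additionally has to deal with values that cross partitions, which the toy model ignores.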
1 parent 5bb1890 commit 8bf8f42

9 files changed, +528 -210 lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PASSES_H_

 #include "mlir/Pass/Pass.h"
+#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 namespace mlir {

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 6 additions & 2 deletions
@@ -106,7 +106,8 @@ def TritonGPUAutomaticWarpSpecialization : Pass<"tritongpu-automatic-warp-specia
 "mlir::triton::gpu::TritonGPUDialect",
 "mlir::scf::SCFDialect",
 "mlir::arith::ArithDialect",
-"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
+"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+"triton::nvws::NVWSDialect"
 ];

 let options = [
@@ -143,7 +144,10 @@ def TritonGPUPartitionLoops : Pass<"tritongpu-partition-loops", "mlir::ModuleOp"
 between any of the partitions.
 }];

-let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+let dependentDialects = [
+"mlir::triton::gpu::TritonGPUDialect",
+"triton::nvws::NVWSDialect"
+];
 }

 def TritonGPUOptimizePartitionWarps : Pass<"tritongpu-optimize-partition-warps", "mlir::ModuleOp"> {

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ add_triton_library(TritonGPUTransforms
 TritonTransforms
 TritonGPUIR
 TritonNvidiaGPUIR
+NVWSIR
+NVWSTransforms
 TritonToTritonGPU
 TritonInstrumentIR
 MLIRTransformUtils
