Skip to content

Commit b4dcd8e

Browse files
authored
[Backend][Hopper] Add a skeleton third-party warp specialization pass (#6624)
Adding a skeleton third-party warp specialization pass. It currently doesn't do anything but serves as a placeholder for upcoming changes.
1 parent 784b550 commit b4dcd8e

File tree

10 files changed

+96
-0
lines changed

10 files changed

+96
-0
lines changed

bin/RegisterTritonDialects.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
1818
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
1919

20+
#include "nvidia/hopper/include/Transforms/Passes.h"
2021
#include "nvidia/include/Dialect/NVWS/Transforms/Passes.h"
2122
#include "nvidia/include/NVGPUToLLVM/Passes.h"
2223
#include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h"
@@ -81,6 +82,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8182
// NVWS passes
8283
mlir::registerNVWSTransformsPasses();
8384

85+
// NVGPU transform passes
86+
mlir::registerNVHopperTransformsPasses();
87+
8488
registry.insert<
8589
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
8690
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,

third_party/nvidia/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ endif()
99
if(TRITON_BUILD_UT)
1010
add_subdirectory(unittest)
1111
endif()
12+
add_subdirectory(hopper)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
add_subdirectory(include)
2+
add_subdirectory(lib)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(Transforms)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVHopperTransforms)
3+
add_public_tablegen_target(NVHopperTransformsIncGen)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
#ifndef DIALECT_NV_TRANSFORMS_PASSES_H_
3+
#define DIALECT_NV_TRANSFORMS_PASSES_H_
4+
5+
#include "mlir/Pass/Pass.h"
6+
7+
namespace mlir {
8+
9+
// Generate the pass class declarations.
10+
#define GEN_PASS_DECL
11+
#include "nvidia/hopper/include/Transforms/Passes.h.inc"
12+
13+
/// Generate the code for registering passes.
14+
#define GEN_PASS_REGISTRATION
15+
#include "nvidia/hopper/include/Transforms/Passes.h.inc"
16+
17+
} // namespace mlir
18+
#endif // DIALECT_NV_TRANSFORMS_PASSES_H_
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef NV_TRANSFORMS_PASSES
2+
#define NV_TRANSFORMS_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
// Declarative definition of the `nvgpu-warp-specialization` pass.
// The pass is anchored on `mlir::ModuleOp` and declares a dependency on the
// TritonGPU dialect.
def NVGPUWarpSpecialization : Pass<"nvgpu-warp-specialization", "mlir::ModuleOp"> {
  let summary = "Automatic warp specialization for NVIDIA GPU";

  let description = [{
    This pass automatically partitions user-defined kernels into
    warp-specialized kernels, enabling finer-grained scheduling
    and improved utilization of hardware resources.
  }];

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
  let options = [
    // Defaults to 0. The pass implementation in WarpSpecialization.cpp
    // returns early for any function when this value is <= 0.
    Option<"numWarpGroups", "num-warp-groups",
           "int32_t", /*default*/"0",
           "number of warp groups for warp specialization">
  ];
}
22+
23+
#endif // NV_TRANSFORMS_PASSES
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(Transforms)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
add_triton_library(NVHopperTransforms
2+
WarpSpecialization.cpp
3+
4+
DEPENDS
5+
NVHopperTransformsIncGen
6+
7+
LINK_LIBS PUBLIC
8+
TritonIR
9+
TritonGPUIR
10+
MLIRTransformUtils
11+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "mlir/Pass/Pass.h"
2+
#include "mlir/Pass/PassManager.h"
3+
#include "mlir/Transforms/Passes.h"
4+
#include "nvidia/hopper/include/Transforms/Passes.h"
5+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
6+
7+
#define DEBUG_TYPE "nvgpu-warp-specialization"
8+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
9+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
10+
11+
namespace mlir {
12+
13+
#define GEN_PASS_DEF_NVGPUWARPSPECIALIZATION
14+
#include "nvidia/hopper/include/Transforms/Passes.h.inc"
15+
16+
class NVGPUWarpSpecializationPass
17+
: public impl::NVGPUWarpSpecializationBase<NVGPUWarpSpecializationPass> {
18+
public:
19+
using impl::NVGPUWarpSpecializationBase<
20+
NVGPUWarpSpecializationPass>::NVGPUWarpSpecializationBase;
21+
22+
void runOnFuncOp(triton::FuncOp funcOp) {
23+
if (numWarpGroups <= 0)
24+
return;
25+
}
26+
27+
void runOnOperation() override {
28+
getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); });
29+
}
30+
};
31+
32+
} // namespace mlir

0 commit comments

Comments
 (0)