Skip to content

Commit 7ff2b0b

Browse files
authored
[NFC] Remove uses of deprecated GEN_PASS_CLASSES for TritonNvidiaGPU/Transforms (#6898)
Continuation of #6785 and #3971 --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 7568a4d commit 7ff2b0b

File tree

14 files changed

+270
-309
lines changed

14 files changed

+270
-309
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
4444
mlir::registerAllPasses();
4545
mlir::triton::registerTritonPasses();
4646
mlir::triton::gpu::registerTritonGPUPasses();
47-
mlir::registerTritonNvidiaGPUPasses();
47+
mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
4848
mlir::test::registerTestAliasPass();
4949
mlir::test::registerTestAlignmentPass();
5050
mlir::test::registerTestAllocationPass();

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,19 @@ struct ClusterInfo {
3838
int clusterDimZ;
3939
};
4040

41-
} // namespace nvidia_gpu
42-
} // namespace triton
43-
} // namespace mlir
44-
45-
namespace mlir {
46-
4741
std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
4842
mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);
4943

50-
std::unique_ptr<Pass>
51-
createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
52-
53-
std::unique_ptr<Pass> createTritonNvidiaGPUTMALoweringPass();
54-
55-
std::unique_ptr<Pass> createTensorMemoryAllocationPass();
56-
57-
std::unique_ptr<Pass> createTritonNvidiaGPUMMALoweringPass();
58-
59-
std::unique_ptr<Pass> createTritonNvidiaGPUPromoteLHSToTMemPass();
60-
61-
std::unique_ptr<Pass> createTritonNvidiaGPURemoveTMEMTokensPass();
62-
63-
std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
64-
65-
std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemLayoutsPass();
66-
67-
std::unique_ptr<Pass> createTritonNvidiaGPUInterleaveTMemPass();
44+
#define GEN_PASS_DECL
45+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
6846

6947
/// Generate the code for registering passes.
7048
#define GEN_PASS_REGISTRATION
7149
#define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS
7250
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
7351

52+
} // namespace nvidia_gpu
53+
} // namespace triton
7454
} // namespace mlir
55+
7556
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp">
3232
and StoreLikeOps operations.
3333
}];
3434

35-
let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";
35+
let constructor = "mlir::triton::nvidia_gpu::createTritonNvidiaGPUPlanCTAPass()";
3636

3737
let dependentDialects = [
3838
"mlir::triton::gpu::TritonGPUDialect",
@@ -48,8 +48,6 @@ def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::M
4848
properly ordered across generic and async operations.
4949
}];
5050

51-
let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";
52-
5351
let dependentDialects = [
5452
"mlir::triton::gpu::TritonGPUDialect",
5553
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
@@ -69,22 +67,18 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::M
6967
Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect.
7068
}];
7169

72-
let constructor = "mlir::createTritonNvidiaGPUTMALoweringPass()";
73-
7470
let dependentDialects = [
7571
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
7672
];
7773
}
7874

79-
def TritionTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
75+
def TritonTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
8076
let summary = "Assign tensor memory allocation";
8177

8278
let description = [{
8379
Decide on tensor memory allocation and assign attributes to each allocation.
8480
}];
8581

86-
let constructor = "mlir::createTensorMemoryAllocationPass()";
87-
8882
let dependentDialects = [
8983
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
9084
];
@@ -97,8 +91,6 @@ def TritonNvidiaGPUMMALoweringPass : Pass<"triton-nvidia-mma-lowering", "mlir::M
9791
Lower MMA ops to prepare for conversion to LLVM.
9892
}];
9993

100-
let constructor = "mlir::createTritonNvidiaGPUMMALoweringPass()";
101-
10294
let dependentDialects = [
10395
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
10496
];
@@ -111,8 +103,6 @@ def TritonNvidiaGPUPromoteLHSToTMemPass : Pass<"tritongpu-promote-lhs-to-tmem",
111103
Promote LHS operand of MMAv5 op to Tensor Memory.
112104
}];
113105

114-
let constructor = "mlir::createTritonNvidiaGPUPromoteLHSToTMemPass()";
115-
116106
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
117107
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
118108
"mlir::triton::TritonDialect"];

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,21 @@
1313
//
1414
//===----------------------------------------------------------------------===//
1515

16-
using namespace mlir;
17-
namespace tt = ::mlir::triton;
18-
namespace ttg = ::mlir::triton::gpu;
19-
namespace ttng = ::mlir::triton::nvidia_gpu;
16+
namespace ttg = mlir::triton::gpu;
2017

21-
#define GEN_PASS_CLASSES
22-
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
18+
namespace mlir {
19+
namespace triton {
20+
namespace nvidia_gpu {
2321

24-
namespace {
22+
#define GEN_PASS_DEF_TRITONGPUFENCEINSERTION
23+
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
2524

2625
struct FenceInsertionPass
27-
: public TritonGPUFenceInsertionBase<FenceInsertionPass> {
26+
: public impl::TritonGPUFenceInsertionBase<FenceInsertionPass> {
2827

2928
public:
30-
FenceInsertionPass() = default;
31-
FenceInsertionPass(int computeCapability) {
32-
this->computeCapability = computeCapability;
33-
}
29+
using impl::TritonGPUFenceInsertionBase<
30+
FenceInsertionPass>::TritonGPUFenceInsertionBase;
3431
// TODO: support more general patterns to insert fences. e.g. any op(generic)
3532
// to shared in use-def chain which refers by async proxy. We have generic(
3633
// convertlayout with sts/stmatrix) + fence + async(wgmma) up to now
@@ -39,7 +36,7 @@ struct FenceInsertionPass
3936
if (computeCapability < 90)
4037
return;
4138
ModuleOp mod = getOperation();
42-
mod.walk([&](tt::DotOpInterface dotOp) {
39+
mod.walk([&](DotOpInterface dotOp) {
4340
Value a = dotOp.getA();
4441
Value b = dotOp.getB();
4542
bool aDependsOnShared = dependOnCopyRegToShared(a);
@@ -48,8 +45,8 @@ struct FenceInsertionPass
4845
return WalkResult::advance();
4946

5047
OpBuilder builder(dotOp);
51-
auto fence = builder.create<ttng::FenceAsyncSharedOp>(dotOp.getLoc(),
52-
/*bCluster=*/false);
48+
auto fence = builder.create<FenceAsyncSharedOp>(dotOp.getLoc(),
49+
/*bCluster=*/false);
5350
// If all of the dependencies are outside of the loop, try to hoist
5451
// the fence.
5552
while (auto loopOp = fence->getParentOfType<LoopLikeOpInterface>()) {
@@ -63,8 +60,8 @@ struct FenceInsertionPass
6360
}
6461

6562
// If the previous op is already a fence, this one isn't needed.
66-
if (auto lastFence = dyn_cast_or_null<ttng::FenceAsyncSharedOp>(
67-
fence->getPrevNode())) {
63+
if (auto lastFence =
64+
dyn_cast_or_null<FenceAsyncSharedOp>(fence->getPrevNode())) {
6865
if (lastFence.getBCluster() == fence.getBCluster())
6966
fence.erase();
7067
}
@@ -129,9 +126,7 @@ struct FenceInsertionPass
129126
return true;
130127
}
131128
};
132-
} // namespace
133129

134-
std::unique_ptr<Pass>
135-
mlir::createTritonNvidiaGPUFenceInsertionPass(int computeCapability) {
136-
return std::make_unique<FenceInsertionPass>(computeCapability);
137-
}
130+
} // namespace nvidia_gpu
131+
} // namespace triton
132+
} // namespace mlir

lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
55
#include "llvm/ADT/AddressRanges.h"
66

7-
namespace {
8-
9-
using namespace mlir;
7+
namespace ttg = mlir::triton::gpu;
108

11-
namespace ttng = triton::nvidia_gpu;
12-
namespace ttg = triton::gpu;
13-
namespace tt = triton;
9+
namespace mlir {
10+
namespace triton {
11+
namespace nvidia_gpu {
1412

15-
#define GEN_PASS_CLASSES
13+
#define GEN_PASS_DEF_TRITONNVIDIAGPUINTERLEAVETMEMPASS
1614
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
1715

16+
namespace {
17+
1818
// If we don't know the effects of the op, we add all possible effects.
1919
void addAllValuelessEffects(
2020
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
@@ -77,7 +77,7 @@ std::pair<Value, AccessRange> findBufferAccess(Value a) {
7777

7878
Operation *defOp = a.getDefiningOp();
7979
// Accessing the alloc accesses the whole buffer.
80-
if (auto alloc = dyn_cast<ttng::TMEMAllocOp>(defOp)) {
80+
if (auto alloc = dyn_cast<TMEMAllocOp>(defOp)) {
8181
AccessRange access;
8282
for (uint64_t dim : alloc.getType().getShape())
8383
access.ranges.push_back({{0, dim}});
@@ -128,7 +128,7 @@ std::pair<Value, AccessRange> findBufferAccess(Value a) {
128128
}
129129

130130
// Subslice is a subview only on the N dimension.
131-
if (auto subslice = dyn_cast<ttng::TMEMSubSliceOp>(defOp)) {
131+
if (auto subslice = dyn_cast<TMEMSubSliceOp>(defOp)) {
132132
auto [alloc, parentAccess] = findBufferAccess(subslice.getSrc());
133133
if (!alloc)
134134
return {};
@@ -186,7 +186,7 @@ bool sinkOps(Value buffer, ArrayRef<Operation *> useChain) {
186186
}
187187
// Don't sink past barrier signals, since they may guard the liverange
188188
// of the buffer.
189-
if (isa<ttng::ArriveBarrierOp>(next))
189+
if (isa<ArriveBarrierOp>(next))
190190
break;
191191
if (!isMemoryEffectFree(next)) {
192192
SmallVector<MemoryEffects::EffectInstance> effects;
@@ -199,7 +199,7 @@ bool sinkOps(Value buffer, ArrayRef<Operation *> useChain) {
199199
dep = true;
200200
break;
201201
}
202-
if (isa<ttng::TensorMemory>(effect.getResource()) &&
202+
if (isa<TensorMemory>(effect.getResource()) &&
203203
(!effect.getValue() || tmemMayAlias(effect.getValue(), buffer))) {
204204
dep = true;
205205
break;
@@ -229,20 +229,22 @@ bool trySinkOp(Operation *op, Value buffer) {
229229
return sinkOps(buffer, useChain);
230230
}
231231

232+
} // anonymous namespace
233+
232234
struct TritonNvidiaGPUInterleaveTMemPass
233-
: public TritonNvidiaGPUInterleaveTMemPassBase<
235+
: public impl::TritonNvidiaGPUInterleaveTMemPassBase<
234236
TritonNvidiaGPUInterleaveTMemPass> {
235-
using TritonNvidiaGPUInterleaveTMemPassBase::
236-
TritonNvidiaGPUInterleaveTMemPassBase;
237+
using impl::TritonNvidiaGPUInterleaveTMemPassBase<
238+
TritonNvidiaGPUInterleaveTMemPass>::TritonNvidiaGPUInterleaveTMemPassBase;
237239

238240
void runOnOperation() override {
239241
MLIRContext *context = &getContext();
240242
ModuleOp m = getOperation();
241243
SmallVector<std::pair<Operation *, Value>> opsToSink;
242244
m.walk([&](Operation *op) {
243-
if (auto load = dyn_cast<ttng::TMEMLoadOp>(op))
245+
if (auto load = dyn_cast<TMEMLoadOp>(op))
244246
opsToSink.emplace_back(load, load.getSrc());
245-
else if (auto alloc = dyn_cast<ttng::TMEMAllocOp>(op))
247+
else if (auto alloc = dyn_cast<TMEMAllocOp>(op))
246248
opsToSink.emplace_back(alloc, alloc.getResult());
247249
});
248250
for (auto [op, buffer] : opsToSink) {
@@ -253,8 +255,6 @@ struct TritonNvidiaGPUInterleaveTMemPass
253255
}
254256
};
255257

256-
} // namespace
257-
258-
std::unique_ptr<Pass> mlir::createTritonNvidiaGPUInterleaveTMemPass() {
259-
return std::make_unique<TritonNvidiaGPUInterleaveTMemPass>();
260-
}
258+
} // namespace nvidia_gpu
259+
} // namespace triton
260+
} // namespace mlir

0 commit comments

Comments
 (0)