Skip to content

Commit 9482841

Browse files
authored
[NFC] Remove uses of deprecated GEN_PASS_CLASSES for Triton/Transforms (#6785)
Continuation of triton-lang/triton#3971 Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 6a3027a commit 9482841

File tree

9 files changed

+37
-77
lines changed

9 files changed

+37
-77
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void registerTestTritonAMDGPURangeAnalysis();
4242

4343
inline void registerTritonDialects(mlir::DialectRegistry &registry) {
4444
mlir::registerAllPasses();
45-
mlir::registerTritonPasses();
45+
mlir::triton::registerTritonPasses();
4646
mlir::triton::gpu::registerTritonGPUPasses();
4747
mlir::registerTritonNvidiaGPUPasses();
4848
mlir::test::registerTestAliasPass();

include/triton/Dialect/Triton/Transforms/Passes.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,14 @@
66
namespace mlir {
77
namespace triton {
88

9-
std::unique_ptr<Pass> createCombineOpsPass();
10-
11-
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
12-
std::unique_ptr<Pass> createReorderBroadcastPass();
13-
std::unique_ptr<Pass> createRewriteTensorPointerPass();
14-
std::unique_ptr<Pass> createLoopUnrollPass();
15-
16-
} // namespace triton
9+
// Generate the pass class declarations.
10+
#define GEN_PASS_DECL
11+
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
1712

1813
#define GEN_PASS_REGISTRATION
1914
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
2015

16+
} // namespace triton
2117
} // namespace mlir
2218

2319
#endif

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
1919
=> dot(x,y,splat(0))`
2020
}];
2121

22-
let constructor = "mlir::triton::createCombineOpsPass()";
23-
2422
let dependentDialects = ["mlir::arith::ArithDialect"];
2523
}
2624

@@ -33,7 +31,7 @@ def TritonReorderBroadcast : Pass</*cli-arg*/"triton-reorder-broadcast", /*Op*/"
3331
In the event of a match, the broadcast (or splat) operation is delayed
3432
and performed after the ElementWise operation.
3533
}];
36-
let constructor = "mlir::triton::createReorderBroadcastPass()";
34+
3735
let dependentDialects = ["mlir::triton::TritonDialect"];
3836
}
3937

@@ -45,8 +43,6 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
4543
the pointer/mask/other for each load/store.
4644
}];
4745

48-
let constructor = "mlir::triton::createRewriteTensorPointerPass()";
49-
5046
let dependentDialects = ["mlir::triton::TritonDialect"];
5147
}
5248

@@ -56,7 +52,7 @@ def TritonLoopUnroll : Pass</*cli-arg*/"triton-loop-unroll", /*Op*/"mlir::Module
5652
The pass unrolls a scf loop with tt.loop_unroll_factor attribute. The attribute specialises how many iterations
5753
the loop should be unrolled.
5854
}];
59-
let constructor = "mlir::triton::createLoopUnrollPass()";
55+
6056
let dependentDialects = ["mlir::triton::TritonDialect"];
6157
}
6258

@@ -68,7 +64,7 @@ def TritonLoopInvariantCodeMotion : Pass</*cli-arg*/"triton-licm", /*Op*/"mlir::
6864
generates a trip-count check. For scf.while loops, it clones the condition
6965
from the before body.
7066
}];
71-
let constructor = "mlir::triton::createLoopInvariantCodeMotionPass()";
67+
7268
let dependentDialects = ["mlir::triton::TritonDialect"];
7369
}
7470

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <memory>
2-
31
#include "mlir/IR/BuiltinAttributes.h"
42
#include "mlir/IR/Matchers.h"
53
#include "mlir/IR/PatternMatch.h"
@@ -10,10 +8,11 @@
108
#include "triton/Dialect/Triton/IR/Dialect.h"
119
#include "triton/Dialect/Triton/Transforms/Passes.h"
1210

13-
#define GEN_PASS_CLASSES
11+
namespace mlir::triton {
12+
13+
#define GEN_PASS_DEF_TRITONCOMBINEOPS
1414
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
1515

16-
namespace mlir::triton {
1716
namespace {
1817

1918
bool isZero(Value val) {
@@ -240,7 +239,9 @@ class RankedReduceDescriptorLoads : public mlir::OpRewritePattern<ReshapeOp> {
240239
}
241240
};
242241

243-
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
242+
} // anonymous namespace
243+
244+
class CombineOpsPass : public impl::TritonCombineOpsBase<CombineOpsPass> {
244245
public:
245246
void runOnOperation() override {
246247
MLIRContext *context = &getContext();
@@ -264,10 +265,4 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
264265
}
265266
};
266267

267-
} // anonymous namespace
268-
269-
std::unique_ptr<mlir::Pass> createCombineOpsPass() {
270-
return std::make_unique<CombineOpsPass>();
271-
}
272-
273268
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/LoopInvariantCodeMotion.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,18 @@
44
#include "triton/Dialect/Triton/Transforms/Passes.h"
55
#include "llvm/Support/Debug.h"
66

7-
#define GEN_PASS_CLASSES
7+
namespace mlir::triton {
8+
9+
#define GEN_PASS_DEF_TRITONLOOPINVARIANTCODEMOTION
810
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
911

1012
#define DEBUG_TYPE "triton-licm"
1113
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
1214
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
1315

14-
namespace mlir::triton {
15-
16-
namespace {
17-
1816
class LoopInvariantCodeMotionPass
19-
: public TritonLoopInvariantCodeMotionBase<LoopInvariantCodeMotionPass> {
17+
: public impl::TritonLoopInvariantCodeMotionBase<
18+
LoopInvariantCodeMotionPass> {
2019

2120
DenseMap<LoopLikeOpInterface, bool> isLoopMemoryEffectFreeOrOnlyRead;
2221

@@ -81,10 +80,4 @@ class LoopInvariantCodeMotionPass
8180
}
8281
};
8382

84-
} // anonymous namespace
85-
86-
std::unique_ptr<mlir::Pass> createLoopInvariantCodeMotionPass() {
87-
return std::make_unique<LoopInvariantCodeMotionPass>();
88-
}
89-
9083
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/LoopUnroll.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <memory>
2-
31
#include "mlir/Dialect/SCF/Utils/Utils.h"
42
#include "mlir/IR/BuiltinAttributes.h"
53
#include "mlir/IR/Matchers.h"
@@ -12,18 +10,16 @@
1210
#include "triton/Dialect/Triton/Transforms/Passes.h"
1311
#include "llvm/Support/Debug.h"
1412

15-
#define GEN_PASS_CLASSES
13+
namespace mlir::triton {
14+
15+
#define GEN_PASS_DEF_TRITONLOOPUNROLL
1616
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
1717

1818
#define DEBUG_TYPE "triton-loop-unroll"
1919
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
2020
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
2121

22-
namespace mlir::triton {
23-
24-
namespace {
25-
26-
class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
22+
class LoopUnrollPass : public impl::TritonLoopUnrollBase<LoopUnrollPass> {
2723

2824
int getUnrollFactorOrDefault(scf::ForOp forOp) {
2925
// Use the attribute attached to the loop if it exists otherwise set the
@@ -38,8 +34,6 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
3834
const char *pipelineStagesAttrName = "tt.num_stages";
3935

4036
public:
41-
LoopUnrollPass() = default;
42-
LoopUnrollPass(const LoopUnrollPass &) {}
4337
void runOnOperation() override {
4438
LDBG("Loop unroll pass");
4539
SmallVector<scf::ForOp, 4> loops;
@@ -65,10 +59,4 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
6559
}
6660
};
6761

68-
} // anonymous namespace
69-
70-
std::unique_ptr<mlir::Pass> createLoopUnrollPass() {
71-
return std::make_unique<LoopUnrollPass>();
72-
}
73-
7462
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
#include "triton/Dialect/Triton/IR/Dialect.h"
1111
#include "triton/Dialect/Triton/Transforms/Passes.h"
1212

13-
// TODO(jlebar): Move this and all other generatede code into namespace
14-
// mlir::triton.
13+
namespace mlir::triton {
14+
1515
#define GEN_PASS_DEF_TRITONREORDERBROADCAST
1616
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
1717

18-
namespace mlir::triton {
1918
namespace {
2019

2120
Operation *cloneWithNewArgsAndResultTypes(PatternRewriter &rewriter,
@@ -208,8 +207,10 @@ struct MoveBroadcastAfterElementwisePattern
208207
}
209208
};
210209

210+
} // namespace
211+
211212
class ReorderBroadcastPass
212-
: public ::impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
213+
: public impl::TritonReorderBroadcastBase<ReorderBroadcastPass> {
213214
public:
214215
void runOnOperation() override {
215216
MLIRContext *context = &getContext();
@@ -228,10 +229,4 @@ class ReorderBroadcastPass
228229
}
229230
};
230231

231-
} // namespace
232-
233-
std::unique_ptr<mlir::Pass> createReorderBroadcastPass() {
234-
return std::make_unique<ReorderBroadcastPass>();
235-
}
236-
237232
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include <memory>
21
#include <stack>
32

43
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -10,9 +9,9 @@
109
#include "triton/Dialect/Triton/IR/Utility.h"
1110
#include "triton/Dialect/Triton/Transforms/Passes.h"
1211

13-
using namespace mlir;
12+
namespace mlir::triton {
1413

15-
#define GEN_PASS_CLASSES
14+
#define GEN_PASS_DEF_TRITONREWRITETENSORPOINTER
1615
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
1716

1817
namespace {
@@ -196,7 +195,7 @@ struct RewritedInfo {
196195
// very fragile and to solve we should expose convert Ptr of tensor to a
197196
// structure containins all values and not only offsets.
198197
class RewriteTensorPointerPass
199-
: public TritonRewriteTensorPointerBase<RewriteTensorPointerPass> {
198+
: public impl::TritonRewriteTensorPointerBase<RewriteTensorPointerPass> {
200199
private:
201200
DenseMap<Value, RewritedInfo> rewritedInfo;
202201

@@ -560,6 +559,4 @@ class RewriteTensorPointerPass
560559
}
561560
};
562561

563-
std::unique_ptr<Pass> triton::createRewriteTensorPointerPass() {
564-
return std::make_unique<RewriteTensorPointerPass>();
565-
}
562+
} // namespace mlir::triton

python/src/passes.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ void init_triton_passes_common(py::module &&m) {
3636

3737
void init_triton_passes_ttir(py::module &&m) {
3838
using namespace mlir::triton;
39-
ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass);
40-
ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass);
39+
ADD_PASS_WRAPPER_0("add_combine", createTritonCombineOps);
40+
ADD_PASS_WRAPPER_0("add_reorder_broadcast", createTritonReorderBroadcast);
4141
ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
42-
createRewriteTensorPointerPass);
43-
ADD_PASS_WRAPPER_0("add_loop_unroll", createLoopUnrollPass);
44-
ADD_PASS_WRAPPER_0("add_triton_licm", createLoopInvariantCodeMotionPass);
42+
createTritonRewriteTensorPointer);
43+
ADD_PASS_WRAPPER_0("add_loop_unroll", createTritonLoopUnroll);
44+
ADD_PASS_WRAPPER_0("add_triton_licm", createTritonLoopInvariantCodeMotion);
4545
ADD_PASS_OPTION_WRAPPER_4("add_convert_to_ttgpuir",
4646
createConvertTritonToTritonGPU, const std::string &,
4747
int, int, int);

0 commit comments

Comments
 (0)