1- #include < memory>
2-
31#include " mlir/Dialect/SCF/Utils/Utils.h"
42#include " mlir/IR/BuiltinAttributes.h"
53#include " mlir/IR/Matchers.h"
1210#include " triton/Dialect/Triton/Transforms/Passes.h"
1311#include " llvm/Support/Debug.h"
1412
15- #define GEN_PASS_CLASSES
13+ namespace mlir ::triton {
14+
15+ #define GEN_PASS_DEF_TRITONLOOPUNROLL
1616#include " triton/Dialect/Triton/Transforms/Passes.h.inc"
1717
1818#define DEBUG_TYPE " triton-loop-unroll"
1919#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
2020#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
2121
22- namespace mlir ::triton {
23-
24- namespace {
25-
26- class LoopUnrollPass : public TritonLoopUnrollBase <LoopUnrollPass> {
22+ class LoopUnrollPass : public impl ::TritonLoopUnrollBase<LoopUnrollPass> {
2723
2824 int getUnrollFactorOrDefault (scf::ForOp forOp) {
2925 // Use the attribute attached to the loop if it exists otherwise set the
@@ -38,8 +34,6 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
3834 const char *pipelineStagesAttrName = " tt.num_stages" ;
3935
4036public:
41- LoopUnrollPass () = default ;
42- LoopUnrollPass (const LoopUnrollPass &) {}
4337 void runOnOperation () override {
4438 LDBG (" Loop unroll pass" );
4539 SmallVector<scf::ForOp, 4 > loops;
@@ -65,10 +59,4 @@ class LoopUnrollPass : public TritonLoopUnrollBase<LoopUnrollPass> {
6559 }
6660};
6761
68- } // anonymous namespace
69-
70- std::unique_ptr<mlir::Pass> createLoopUnrollPass () {
71- return std::make_unique<LoopUnrollPass>();
72- }
73-
7462} // namespace mlir::triton
0 commit comments