@@ -56,7 +56,7 @@ namespace py = pybind11;
5656#include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
5757
5858// LLVM: Debug
59- #include " llvm/Support/Debug.h" // Key header file
59+ #include " llvm/Support/Debug.h" // Key header file
6060
6161// MLIR: Top-level Transforms
6262#include " mlir/Transforms/Passes.h"
@@ -69,17 +69,17 @@ namespace py = pybind11;
6969#include " triton/Dialect/TritonGPU/IR/Dialect.h"
7070#include " triton/Dialect/TritonInstrument/IR/Dialect.h"
7171
72- #define ADD_PASS_WRAPPER_0 (name, builder ) \
72+ #define ADD_PASS_WRAPPER_0 (name, builder ) \
7373 m.def(name, [](mlir::PassManager &pm) { pm.addPass (builder ()); })
7474
75- #define ADD_PASS_WRAPPER_1 (name, builder, ty0 ) \
76- m.def(name, \
75+ #define ADD_PASS_WRAPPER_1 (name, builder, ty0 ) \
76+ m.def(name, \
7777 [](mlir::PassManager &pm, ty0 val0) { pm.addPass (builder (val0)); })
7878
79- #define ADD_PASS_WRAPPER_1_ARG (name, builder, ty0, arg0, val0 ) \
80- m.def( \
81- name, \
82- [](mlir::PassManager &pm, ty0 arg0) { pm.addPass (builder (val0)); }, \
79+ #define ADD_PASS_WRAPPER_1_ARG (name, builder, ty0, arg0, val0 ) \
80+ m.def( \
81+ name, \
82+ [](mlir::PassManager &pm, ty0 arg0) { pm.addPass (builder (val0)); }, \
8383 py::arg (" pm" ), py::arg(#arg0) = val0);
8484
8585// Function to set MLIR/LLVM debug type
@@ -91,7 +91,8 @@ void enable_mlir_debug(const std::string &debug_type) {
9191void init_to_llvm (py::module &&m) {
9292 using namespace mlir ;
9393
94- // Note: Linalg conversions may not be available in this MLIR version
94+ ADD_PASS_WRAPPER_0 (" add_eliminate_empty_tensors" ,
95+ bufferization::createEmptyTensorEliminationPass);
9596 ADD_PASS_WRAPPER_0 (" add_convert_linalg_to_affine_loops" ,
9697 createConvertLinalgToAffineLoopsPass);
9798 ADD_PASS_WRAPPER_0 (" add_empty_tensor_to_alloc_tensor" ,
@@ -143,18 +144,12 @@ void init_triton_shared_ir(py::module &&m) {
143144 // ::mlir::triton::gpu::TritonGPUDialect,
144145 // ::mlir::triton::instrument::TritonInstrumentDialect,
145146 ::mlir::linalg::LinalgDialect,
146- ::mlir::bufferization::BufferizationDialect,
147- ::mlir::tptr::TPtrDialect,
148- ::mlir::math::MathDialect,
149- ::mlir::memref::MemRefDialect,
150- ::mlir::arith::ArithDialect,
151- ::mlir::scf::SCFDialect,
152- ::mlir::vector::VectorDialect,
153- ::mlir::cf::ControlFlowDialect,
154- ::mlir::triton::proton::ProtonDialect,
155- ::mlir::LLVM::LLVMDialect,
156- ::mlir::ub::UBDialect,
157- ::mlir::func::FuncDialect>();
147+ ::mlir::bufferization::BufferizationDialect, ::mlir::tptr::TPtrDialect,
148+ ::mlir::math::MathDialect, ::mlir::memref::MemRefDialect,
149+ ::mlir::arith::ArithDialect, ::mlir::scf::SCFDialect,
150+ ::mlir::vector::VectorDialect, ::mlir::cf::ControlFlowDialect,
151+ ::mlir::triton::proton::ProtonDialect, ::mlir::LLVM::LLVMDialect,
152+ ::mlir::ub::UBDialect, ::mlir::func::FuncDialect>();
158153
159154 // Register interfaces and translations
160155 // ::mlir::registerAllDialects(registry);
0 commit comments