1- // pybind11
1+ // PyBind11
22#include < pybind11/pybind11.h>
33#include < pybind11/stl.h>
44#include < pybind11/stl_bind.h>
@@ -55,8 +55,8 @@ namespace py = pybind11;
5555#include " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
5656#include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
5757
58- // llvm: Debug
59- #include " llvm/Support/Debug.h" // 关键头文件
58+ // LLVM: Debug
59+ #include " llvm/Support/Debug.h" // Key header file
6060
6161// MLIR: Top-level Transforms
6262#include " mlir/Transforms/Passes.h"
@@ -71,29 +71,32 @@ namespace py = pybind11;
7171
7272#define ADD_PASS_WRAPPER_0 (name, builder ) \
7373 m.def(name, [](mlir::PassManager &pm) { pm.addPass (builder ()); })
74+
7475#define ADD_PASS_WRAPPER_1 (name, builder, ty0 ) \
75- m.def(name, \
76+ m.def(name, \
7677 [](mlir::PassManager &pm, ty0 val0) { pm.addPass (builder (val0)); })
7778
78- #define ADD_PASS_WRAPPER_1_ARG (name, builder, ty0, arg0, val0 ) \
79- m.def( \
80- name, \
79+ #define ADD_PASS_WRAPPER_1_ARG (name, builder, ty0, arg0, val0 ) \
80+ m.def( \
81+ name, \
8182 [](mlir::PassManager &pm, ty0 arg0) { pm.addPass (builder (val0)); }, \
8283 py::arg (" pm" ), py::arg(#arg0) = val0);
8384
84- // 一个函数,用于设置 MLIR/LLVM 的调试类型
85+ // Function to set MLIR/LLVM debug type
8586void enable_mlir_debug (const std::string &debug_type) {
8687 ::llvm::DebugFlag = true ;
8788 llvm::setCurrentDebugType (debug_type.c_str ());
8889}
8990
9091void init_to_llvm (py::module &&m) {
9192 using namespace mlir ;
93+
9294 // Note: Linalg conversions may not be available in this MLIR version
9395 ADD_PASS_WRAPPER_0 (" add_convert_linalg_to_affine_loops" ,
9496 createConvertLinalgToAffineLoopsPass);
9597 ADD_PASS_WRAPPER_0 (" add_empty_tensor_to_alloc_tensor" ,
9698 bufferization::createEmptyTensorToAllocTensorPass);
99+
97100 ADD_PASS_WRAPPER_1_ARG (
98101 " add_one_shot_bufferize_with_options" ,
99102 [](bool allowReturnAllocsFromLoops) {
@@ -102,6 +105,7 @@ void init_to_llvm(py::module &&m) {
102105 return mlir::bufferization::createOneShotBufferizePass (options);
103106 },
104107 bool , allow_return_allocs_from_loops, true );
108+
105109 ADD_PASS_WRAPPER_0 (" add_one_shot_bufferize" ,
106110 bufferization::createOneShotBufferizePass);
107111 ADD_PASS_WRAPPER_0 (" add_lower_affine" , createLowerAffinePass);
@@ -132,32 +136,39 @@ void init_to_llvm(py::module &&m) {
132136void init_triton_shared_ir (py::module &&m) {
133137 m.def (" load_dialects" , [](mlir::MLIRContext &context) {
134138 mlir::DialectRegistry registry;
139+
140+ // Register core dialects
135141 registry.insert <
136142 ::mlir::triton::TritonDialect,
137143 // ::mlir::triton::gpu::TritonGPUDialect,
138144 // ::mlir::triton::instrument::TritonInstrumentDialect,
139145 ::mlir::linalg::LinalgDialect,
140146 ::mlir::bufferization::BufferizationDialect,
141147 ::mlir::tptr::TPtrDialect,
142- ::mlir::math::MathDialect, ::mlir::memref::MemRefDialect,
143- ::mlir::arith::ArithDialect, ::mlir::scf::SCFDialect,
144- ::mlir::vector::VectorDialect, ::mlir::cf::ControlFlowDialect,
145- ::mlir::triton::proton::ProtonDialect, ::mlir::LLVM::LLVMDialect,
146- ::mlir::ub::UBDialect, ::mlir::func::FuncDialect>();
148+ ::mlir::math::MathDialect,
149+ ::mlir::memref::MemRefDialect,
150+ ::mlir::arith::ArithDialect,
151+ ::mlir::scf::SCFDialect,
152+ ::mlir::vector::VectorDialect,
153+ ::mlir::cf::ControlFlowDialect,
154+ ::mlir::triton::proton::ProtonDialect,
155+ ::mlir::LLVM::LLVMDialect,
156+ ::mlir::ub::UBDialect,
157+ ::mlir::func::FuncDialect>();
158+
159+ // Register interfaces and translations
147160 // ::mlir::registerAllDialects(registry);
148161 ::mlir::LLVM::registerInlinerInterface (registry);
149162 ::mlir::registerBuiltinDialectTranslation (registry);
150163 ::mlir::registerLLVMDialectTranslation (registry);
151164 ::mlir::LLVM::registerInlinerInterface (registry);
152165
166+ // Register bufferizable op interface external models
153167 ::mlir::arith::registerBufferizableOpInterfaceExternalModels (registry);
154168 ::mlir::scf::registerBufferizableOpInterfaceExternalModels (registry);
155169 ::mlir::linalg::registerBufferizableOpInterfaceExternalModels (registry);
156170 ::mlir::vector::registerBufferizableOpInterfaceExternalModels (registry);
157171 ::mlir::tensor::registerBufferizableOpInterfaceExternalModels (registry);
158- // ! didn't know if exists
159- // ::mlir::memref::registerBufferizableOpInterfaceExternalModels(registry);
160- // ::mlir::func::registerBufferizableOpInterfaceExternalModels(registry);
161172
162173 ::mlir::bufferization::func_ext::
163174 registerBufferizableOpInterfaceExternalModels (registry);
0 commit comments