1- //  pybind11 
1+ //  PyBind11 
22#include  < pybind11/pybind11.h> 
33#include  < pybind11/stl.h> 
44#include  < pybind11/stl_bind.h> 
@@ -55,8 +55,8 @@ namespace py = pybind11;
5555#include  " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" 
5656#include  " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 
5757
58- //  llvm: Debug
59- #include  " llvm/Support/Debug.h"    //  关键头文件 
58+ //  LLVM:  Debug
59+ #include  " llvm/Support/Debug.h"    //  Key header file 
6060
6161//  MLIR: Top-level Transforms
6262#include  " mlir/Transforms/Passes.h" 
@@ -71,29 +71,32 @@ namespace py = pybind11;
7171
7272#define  ADD_PASS_WRAPPER_0 (name, builder ) \
7373  m.def(name, [](mlir::PassManager &pm) { pm.addPass (builder ()); })
74+ 
7475#define  ADD_PASS_WRAPPER_1 (name, builder, ty0 ) \
75-   m.def(name,                                   \
76+   m.def(name, \
7677        [](mlir::PassManager &pm, ty0 val0) { pm.addPass (builder (val0)); })
7778
78- #define  ADD_PASS_WRAPPER_1_ARG (name, builder, ty0, arg0, val0 )             \
79-   m.def(                                                                   \
80-       name,                                                                \
79+ #define  ADD_PASS_WRAPPER_1_ARG (name, builder, ty0, arg0, val0 ) \
80+   m.def( \
81+       name, \
8182      [](mlir::PassManager &pm, ty0 arg0) { pm.addPass (builder (val0)); }, \
8283      py::arg (" pm"  ), py::arg(#arg0) = val0);
8384
84- //  一个函数,用于设置  MLIR/LLVM 的调试类型 
85+ //  Function to set  MLIR/LLVM debug type 
8586void  enable_mlir_debug (const  std::string &debug_type) {
8687  ::llvm::DebugFlag = true ;
8788  llvm::setCurrentDebugType (debug_type.c_str ());
8889}
8990
9091void  init_to_llvm (py::module  &&m) {
9192  using  namespace  mlir ; 
93+ 
9294  //  Note: Linalg conversions may not be available in this MLIR version
9395  ADD_PASS_WRAPPER_0 (" add_convert_linalg_to_affine_loops"  ,
9496                     createConvertLinalgToAffineLoopsPass);
9597  ADD_PASS_WRAPPER_0 (" add_empty_tensor_to_alloc_tensor"  ,
9698                     bufferization::createEmptyTensorToAllocTensorPass);
99+ 
97100  ADD_PASS_WRAPPER_1_ARG (
98101      " add_one_shot_bufferize_with_options"  ,
99102      [](bool  allowReturnAllocsFromLoops) {
@@ -102,6 +105,7 @@ void init_to_llvm(py::module &&m) {
102105        return  mlir::bufferization::createOneShotBufferizePass (options);
103106      },
104107      bool , allow_return_allocs_from_loops, true );
108+ 
105109  ADD_PASS_WRAPPER_0 (" add_one_shot_bufferize"  ,
106110                     bufferization::createOneShotBufferizePass);
107111  ADD_PASS_WRAPPER_0 (" add_lower_affine"  , createLowerAffinePass);
@@ -132,32 +136,39 @@ void init_to_llvm(py::module &&m) {
132136void  init_triton_shared_ir (py::module  &&m) {
133137  m.def (" load_dialects"  , [](mlir::MLIRContext &context) {
134138    mlir::DialectRegistry registry;
139+ 
140+     //  Register core dialects
135141    registry.insert <
136142        ::mlir::triton::TritonDialect,
137143        //  ::mlir::triton::gpu::TritonGPUDialect,
138144        //  ::mlir::triton::instrument::TritonInstrumentDialect,
139145        ::mlir::linalg::LinalgDialect,
140146        ::mlir::bufferization::BufferizationDialect,
141147        ::mlir::tptr::TPtrDialect,
142-         ::mlir::math::MathDialect, ::mlir::memref::MemRefDialect,
143-         ::mlir::arith::ArithDialect, ::mlir::scf::SCFDialect,
144-         ::mlir::vector::VectorDialect, ::mlir::cf::ControlFlowDialect,
145-         ::mlir::triton::proton::ProtonDialect, ::mlir::LLVM::LLVMDialect,
146-         ::mlir::ub::UBDialect, ::mlir::func::FuncDialect>();
148+         ::mlir::math::MathDialect,
149+         ::mlir::memref::MemRefDialect,
150+         ::mlir::arith::ArithDialect,
151+         ::mlir::scf::SCFDialect,
152+         ::mlir::vector::VectorDialect,
153+         ::mlir::cf::ControlFlowDialect,
154+         ::mlir::triton::proton::ProtonDialect,
155+         ::mlir::LLVM::LLVMDialect,
156+         ::mlir::ub::UBDialect,
157+         ::mlir::func::FuncDialect>();
158+ 
159+     //  Register interfaces and translations
147160    //  ::mlir::registerAllDialects(registry);
148161    ::mlir::LLVM::registerInlinerInterface (registry);
149162    ::mlir::registerBuiltinDialectTranslation (registry);
150163    ::mlir::registerLLVMDialectTranslation (registry);
151164    ::mlir::LLVM::registerInlinerInterface (registry);
152165
166+     //  Register bufferizable op interface external models
153167    ::mlir::arith::registerBufferizableOpInterfaceExternalModels (registry);
154168    ::mlir::scf::registerBufferizableOpInterfaceExternalModels (registry);
155169    ::mlir::linalg::registerBufferizableOpInterfaceExternalModels (registry);
156170    ::mlir::vector::registerBufferizableOpInterfaceExternalModels (registry);
157171    ::mlir::tensor::registerBufferizableOpInterfaceExternalModels (registry);
158-     // ! didn't know if exists
159-     //  ::mlir::memref::registerBufferizableOpInterfaceExternalModels(registry);
160-     //  ::mlir::func::registerBufferizableOpInterfaceExternalModels(registry);
161172
162173    ::mlir::bufferization::func_ext::
163174        registerBufferizableOpInterfaceExternalModels (registry);
0 commit comments