Skip to content

Commit 853711e

Browse files
committed
reviewd changes
1 parent 016e2c1 commit 853711e

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _ttsharedir_to_llir(ttsharedir: str):
8181
context = ir.context()
8282
triton_shared.ir.load_dialects(context)
8383
mod = ir.parse_mlir_module(ttshared_path, context)
84+
# TritonShared-MLIR to LLVM-MLIR
8485

8586
pm = ir.pass_manager(context)
8687
pm.enable_debug()

triton_shared.cc

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// pybind11
1+
// PyBind11
22
#include <pybind11/pybind11.h>
33
#include <pybind11/stl.h>
44
#include <pybind11/stl_bind.h>
@@ -55,8 +55,8 @@ namespace py = pybind11;
5555
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
5656
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
5757

58-
// llvm:Debug
59-
#include "llvm/Support/Debug.h" // 关键头文件
58+
// LLVM: Debug
59+
#include "llvm/Support/Debug.h" // Key header file
6060

6161
// MLIR: Top-level Transforms
6262
#include "mlir/Transforms/Passes.h"
@@ -71,29 +71,32 @@ namespace py = pybind11;
7171

7272
#define ADD_PASS_WRAPPER_0(name, builder) \
7373
m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); })
74+
7475
#define ADD_PASS_WRAPPER_1(name, builder, ty0) \
75-
m.def(name, \
76+
m.def(name, \
7677
[](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); })
7778

78-
#define ADD_PASS_WRAPPER_1_ARG(name, builder, ty0, arg0, val0) \
79-
m.def( \
80-
name, \
79+
#define ADD_PASS_WRAPPER_1_ARG(name, builder, ty0, arg0, val0) \
80+
m.def( \
81+
name, \
8182
[](mlir::PassManager &pm, ty0 arg0) { pm.addPass(builder(val0)); }, \
8283
py::arg("pm"), py::arg(#arg0) = val0);
8384

84-
// 一个函数,用于设置 MLIR/LLVM 的调试类型
85+
// Function to set MLIR/LLVM debug type
8586
void enable_mlir_debug(const std::string &debug_type) {
8687
::llvm::DebugFlag = true;
8788
llvm::setCurrentDebugType(debug_type.c_str());
8889
}
8990

9091
void init_to_llvm(py::module &&m) {
9192
using namespace mlir;
93+
9294
// Note: Linalg conversions may not be available in this MLIR version
9395
ADD_PASS_WRAPPER_0("add_convert_linalg_to_affine_loops",
9496
createConvertLinalgToAffineLoopsPass);
9597
ADD_PASS_WRAPPER_0("add_empty_tensor_to_alloc_tensor",
9698
bufferization::createEmptyTensorToAllocTensorPass);
99+
97100
ADD_PASS_WRAPPER_1_ARG(
98101
"add_one_shot_bufferize_with_options",
99102
[](bool allowReturnAllocsFromLoops) {
@@ -102,6 +105,7 @@ void init_to_llvm(py::module &&m) {
102105
return mlir::bufferization::createOneShotBufferizePass(options);
103106
},
104107
bool, allow_return_allocs_from_loops, true);
108+
105109
ADD_PASS_WRAPPER_0("add_one_shot_bufferize",
106110
bufferization::createOneShotBufferizePass);
107111
ADD_PASS_WRAPPER_0("add_lower_affine", createLowerAffinePass);
@@ -132,32 +136,39 @@ void init_to_llvm(py::module &&m) {
132136
void init_triton_shared_ir(py::module &&m) {
133137
m.def("load_dialects", [](mlir::MLIRContext &context) {
134138
mlir::DialectRegistry registry;
139+
140+
// Register core dialects
135141
registry.insert<
136142
::mlir::triton::TritonDialect,
137143
// ::mlir::triton::gpu::TritonGPUDialect,
138144
// ::mlir::triton::instrument::TritonInstrumentDialect,
139145
::mlir::linalg::LinalgDialect,
140146
::mlir::bufferization::BufferizationDialect,
141147
::mlir::tptr::TPtrDialect,
142-
::mlir::math::MathDialect, ::mlir::memref::MemRefDialect,
143-
::mlir::arith::ArithDialect, ::mlir::scf::SCFDialect,
144-
::mlir::vector::VectorDialect, ::mlir::cf::ControlFlowDialect,
145-
::mlir::triton::proton::ProtonDialect, ::mlir::LLVM::LLVMDialect,
146-
::mlir::ub::UBDialect, ::mlir::func::FuncDialect>();
148+
::mlir::math::MathDialect,
149+
::mlir::memref::MemRefDialect,
150+
::mlir::arith::ArithDialect,
151+
::mlir::scf::SCFDialect,
152+
::mlir::vector::VectorDialect,
153+
::mlir::cf::ControlFlowDialect,
154+
::mlir::triton::proton::ProtonDialect,
155+
::mlir::LLVM::LLVMDialect,
156+
::mlir::ub::UBDialect,
157+
::mlir::func::FuncDialect>();
158+
159+
// Register interfaces and translations
147160
// ::mlir::registerAllDialects(registry);
148161
::mlir::LLVM::registerInlinerInterface(registry);
149162
::mlir::registerBuiltinDialectTranslation(registry);
150163
::mlir::registerLLVMDialectTranslation(registry);
151164
::mlir::LLVM::registerInlinerInterface(registry);
152165

166+
// Register bufferizable op interface external models
153167
::mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
154168
::mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
155169
::mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry);
156170
::mlir::vector::registerBufferizableOpInterfaceExternalModels(registry);
157171
::mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
158-
//! didn't know if exists
159-
// ::mlir::memref::registerBufferizableOpInterfaceExternalModels(registry);
160-
// ::mlir::func::registerBufferizableOpInterfaceExternalModels(registry);
161172

162173
::mlir::bufferization::func_ext::
163174
registerBufferizableOpInterfaceExternalModels(registry);

0 commit comments

Comments
 (0)