Skip to content

Commit 0edb877

Browse files
committed
using pass_manager to lower ttsharedir
1 parent eaeb554 commit 0edb877

File tree

4 files changed

+244
-44
lines changed

4 files changed

+244
-44
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
.cache
33
compile_commands.json
44
build/*
5-
.vscode/*
5+
.vscode/*
6+
.clangd/*

CMakeLists.txt

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,40 @@ add_subdirectory(lib)
1313
add_subdirectory(test)
1414
add_subdirectory(tools/triton-shared-opt)
1515

16-
if (TRITON_SHARED_BUILD_CPU_BACKEND)
16+
if(TRITON_SHARED_BUILD_CPU_BACKEND)
1717
add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonToLinalg TritonTilingExtIR)
18-
target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers ${Python3_LIBRARIES})
18+
target_link_libraries(TritonShared
19+
PUBLIC
20+
MLIRAffineToStandard
21+
MLIRReconcileUnrealizedCasts
22+
MLIRSCFToControlFlow
23+
24+
# ! transforms
25+
MLIRBufferizationTransforms
26+
MLIRArithTransforms
27+
MLIRMemRefTransforms
28+
MLIRSCFTransforms
29+
MLIRFuncTransforms
30+
MLIRVectorTransforms
31+
MLIRTensorTransforms
32+
33+
MLIRArithToLLVM
34+
MLIRIndexToLLVM
35+
MLIRMathToLLVM
36+
MLIRComplexToLLVM
37+
MLIRVectorToLLVM
38+
MLIRVectorToLLVMPass
39+
MLIRFuncToLLVM
40+
MLIRControlFlowToLLVM
41+
MLIRMemRefToLLVM
42+
MLIRVectorToSCF
43+
MLIRUBToLLVM
44+
# MLIRMathToLibm
45+
46+
47+
PRIVATE
48+
Python3::Module
49+
pybind11::headers
50+
${Python3_LIBRARIES}
51+
)
1952
endif()

backend/compiler.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from triton.backends.compiler import BaseBackend, GPUTarget
2-
from triton._C.libtriton import ir, passes
2+
from triton._C.libtriton import ir, passes, triton_shared
33
from dataclasses import dataclass
44
from typing import Any, Dict, Tuple
55
from types import ModuleType
@@ -78,41 +78,36 @@ def _ttsharedir_to_llir(ttsharedir: str):
7878
llmlir_path = os.path.join(tmpdir, "ll.mlir")
7979
llir_path = os.path.join(tmpdir, "ll.ir")
8080
Path(ttshared_path).write_text(ttsharedir)
81-
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
82-
# TritonShared-MLIR to LLVM-MLIR
83-
subprocess.check_call([mlir_opt_path, ttshared_path,
84-
"--convert-linalg-to-affine-loops",
85-
# Note: eliminate-empty-tensors fails when there are multiple func.return ops
86-
# in a single kernel which are the results of early returns.
87-
# See python/examples/test_early_return.py for examples.
88-
# We disable this pass for now since performance on CPU isn't the main
89-
# focus at the moment.
90-
# "--eliminate-empty-tensors",
91-
"--empty-tensor-to-alloc-tensor",
92-
"--one-shot-bufferize=allow-return-allocs-from-loops=true",
93-
"--lower-affine",
94-
"--convert-linalg-to-loops",
95-
"--expand-strided-metadata",
96-
"--convert-scf-to-cf",
97-
"--convert-arith-to-llvm",
98-
"--convert-math-to-llvm",
99-
"--convert-complex-to-llvm",
100-
"--convert-vector-to-llvm",
101-
"--convert-index-to-llvm",
102-
"--memref-expand",
103-
"--finalize-memref-to-llvm",
104-
"--convert-func-to-llvm",
105-
"--convert-cf-to-llvm",
106-
# Lowering memrefs creates more affine.apply ops.
107-
# Lowering these affine ops again creates further arith ops,
108-
# so we have to run these two passes again here.
109-
"--lower-affine",
110-
"--convert-arith-to-llvm",
111-
# Remove all unrealized casts created
112-
"--reconcile-unrealized-casts",
113-
"--mlir-print-debuginfo",
114-
"-o",
115-
llmlir_path])
81+
context = ir.context()
82+
triton_shared.ir.load_dialects(context)
83+
mod = ir.parse_mlir_module(ttshared_path, context)
84+
85+
pm = ir.pass_manager(context)
86+
pm.enable_debug()
87+
triton_shared.to_llir.add_convert_linalg_to_affine_loops(pm)
88+
triton_shared.to_llir.add_empty_tensor_to_alloc_tensor(pm)
89+
triton_shared.to_llir.add_one_shot_bufferize_with_options(
90+
pm, allow_return_allocs_from_loops=True)
91+
triton_shared.to_llir.add_lower_affine(pm)
92+
triton_shared.to_llir.add_convert_linalg_to_loops(pm)
93+
triton_shared.to_llir.add_expand_strided_metadata(pm)
94+
triton_shared.to_llir.add_convert_scf_to_cf(pm)
95+
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
96+
triton_shared.to_llir.add_convert_math_to_llvm(pm)
97+
triton_shared.to_llir.add_convert_complex_to_llvm(pm)
98+
triton_shared.to_llir.add_convert_vector_to_llvm(pm)
99+
triton_shared.to_llir.add_convert_index_to_llvm(pm)
100+
triton_shared.to_llir.add_memref_expand(pm)
101+
triton_shared.to_llir.add_finalize_memref_to_llvm(pm)
102+
triton_shared.to_llir.add_convert_func_to_llvm(pm)
103+
104+
triton_shared.to_llir.add_convert_cf_to_llvm(pm)
105+
triton_shared.to_llir.add_lower_affine(pm)
106+
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
107+
triton_shared.to_llir.add_reconcile_unrealized_casts(pm)
108+
pm.run(mod)
109+
110+
Path(llmlir_path).write_text(str(mod))
116111

117112
# LLVM-MLIR to LLVM-IR
118113
mlir_translate_path = _get_llvm_bin_path("mlir-translate")

triton_shared.cc

Lines changed: 176 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,179 @@
1-
#include <pybind11/pybind11.h>
1+
// pybind11
2+
#include <pybind11/pybind11.h>
3+
#include <pybind11/stl.h>
4+
#include <pybind11/stl_bind.h>
25

36
namespace py = pybind11;
47

5-
// The CPU backend with triton_shared doesn't do compilation from within python
6-
// but rather externally through triton-shared-opt, so we leave this function
7-
// blank.
8-
void init_triton_triton_shared(py::module &&m) {}
8+
// LLVM
9+
#include "llvm/IR/Constants.h"
10+
#include "llvm/Support/TargetSelect.h"
11+
12+
// MLIR: Conversion Passes
13+
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
14+
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
15+
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
16+
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
17+
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
18+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
19+
#include "mlir/Conversion/Passes.h"
20+
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
21+
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
22+
23+
// MLIR: Dialects
24+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
25+
#include "mlir/Dialect/Arith/IR/Arith.h"
26+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
27+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
28+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
29+
#include "mlir/Dialect/Func/IR/FuncOps.h"
30+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
31+
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
32+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
33+
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
34+
// #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
35+
#include "mlir/Dialect/Linalg/Passes.h"
36+
#include "mlir/Dialect/Math/IR/Math.h"
37+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
38+
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
39+
#include "mlir/Dialect/SCF/IR/SCF.h"
40+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
41+
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
42+
#include "mlir/Dialect/UB/IR/UBOps.h"
43+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
44+
#include "mlir/Dialect/Vector/Transforms/Passes.h"
45+
46+
// MLIR: Core IR and Passes
47+
#include "mlir/IR/DialectRegistry.h"
48+
#include "mlir/IR/MLIRContext.h"
49+
#include "mlir/InitAllDialects.h"
50+
#include "mlir/Pass/Pass.h"
51+
#include "mlir/Pass/PassManager.h"
52+
53+
// MLIR: Target and Translation
54+
// #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
55+
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
56+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
57+
58+
// llvm:Debug
59+
#include "llvm/Support/Debug.h" // key header: declares llvm::DebugFlag and setCurrentDebugType
60+
61+
// MLIR: Top-level Transforms
62+
#include "mlir/Transforms/Passes.h"
63+
64+
// Triton and other third-party dialects
65+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
66+
#include "triton/Dialect/Triton/IR/Dialect.h"
67+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
68+
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
69+
70+
// Registers `name` on the pybind11 module `m` (which must be in scope at the
// expansion site) as a function taking an mlir::PassManager and appending the
// pass produced by calling `builder()` with no arguments.
#define ADD_PASS_WRAPPER_0(name, builder)                                      \
  m.def(name, [](mlir::PassManager &passManager) {                             \
    passManager.addPass(builder());                                            \
  })
72+
// Registers `name` on the pybind11 module `m` (which must be in scope at the
// expansion site) as a function taking an mlir::PassManager plus one value of
// type `ty0`; that value is forwarded to `builder` and the resulting pass is
// appended to the pass manager.
#define ADD_PASS_WRAPPER_1(name, builder, ty0)                                 \
  m.def(name, [](mlir::PassManager &passManager, ty0 option0) {                \
    passManager.addPass(builder(option0));                                     \
  })
75+
76+
// Registers `name` on the pybind11 module `m` (which must be in scope at the
// expansion site) as a function taking an mlir::PassManager plus one keyword
// argument.
//
//   ty0   - C++ type of the extra argument
//   arg0  - identifier used both as the lambda parameter and, stringified, as
//           the Python keyword name
//   val0  - default value used when the caller omits the argument
//
// The value actually supplied by the caller (`arg0`) is what gets forwarded to
// `builder`; `val0` is only the default. Forwarding `val0` here would silently
// ignore whatever the caller passed.
//
// The expansion ends with its own `;`, so a call site written as
// `ADD_PASS_WRAPPER_1_ARG(...);` yields a harmless empty statement.
#define ADD_PASS_WRAPPER_1_ARG(name, builder, ty0, arg0, val0)                 \
  m.def(                                                                       \
      name,                                                                    \
      [](mlir::PassManager &pm, ty0 arg0) { pm.addPass(builder(arg0)); },      \
      py::arg("pm"), py::arg(#arg0) = val0);
81+
82+
// 一个函数,用于设置 MLIR/LLVM 的调试类型
83+
void enable_mlir_debug(const std::string &debug_type) {
84+
::llvm::DebugFlag = true;
85+
llvm::setCurrentDebugType(debug_type.c_str());
86+
}
87+
88+
void init_to_llvm(py::module &&m) {
89+
using namespace mlir;
90+
// Note: Linalg conversions may not be available in this MLIR version
91+
ADD_PASS_WRAPPER_0("add_convert_linalg_to_affine_loops",
92+
createConvertLinalgToAffineLoopsPass);
93+
ADD_PASS_WRAPPER_0("add_empty_tensor_to_alloc_tensor",
94+
bufferization::createEmptyTensorToAllocTensorPass);
95+
ADD_PASS_WRAPPER_1_ARG(
96+
"add_one_shot_bufferize_with_options",
97+
[](bool allowReturnAllocsFromLoops) {
98+
mlir::bufferization::OneShotBufferizePassOptions options;
99+
options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
100+
return mlir::bufferization::createOneShotBufferizePass(options);
101+
},
102+
bool, allow_return_allocs_from_loops, true);
103+
ADD_PASS_WRAPPER_0("add_one_shot_bufferize",
104+
bufferization::createOneShotBufferizePass);
105+
ADD_PASS_WRAPPER_0("add_lower_affine", createLowerAffinePass);
106+
ADD_PASS_WRAPPER_0("add_convert_linalg_to_loops",
107+
createConvertLinalgToLoopsPass);
108+
ADD_PASS_WRAPPER_0("add_expand_strided_metadata",
109+
memref::createExpandStridedMetadataPass);
110+
ADD_PASS_WRAPPER_0("add_convert_scf_to_cf", createSCFToControlFlowPass);
111+
ADD_PASS_WRAPPER_0("add_convert_arith_to_llvm",
112+
createArithToLLVMConversionPass);
113+
ADD_PASS_WRAPPER_0("add_convert_math_to_llvm", createConvertMathToLLVMPass);
114+
ADD_PASS_WRAPPER_0("add_convert_complex_to_llvm",
115+
createConvertComplexToLLVMPass);
116+
ADD_PASS_WRAPPER_0("add_convert_vector_to_llvm",
117+
createConvertVectorToLLVMPass);
118+
ADD_PASS_WRAPPER_0("add_convert_index_to_llvm", createConvertIndexToLLVMPass);
119+
ADD_PASS_WRAPPER_0("add_memref_expand", memref::createExpandOpsPass);
120+
ADD_PASS_WRAPPER_0("add_finalize_memref_to_llvm",
121+
createFinalizeMemRefToLLVMConversionPass);
122+
ADD_PASS_WRAPPER_0("add_convert_func_to_llvm", createConvertFuncToLLVMPass);
123+
ADD_PASS_WRAPPER_0("add_convert_cf_to_llvm",
124+
createConvertControlFlowToLLVMPass);
125+
ADD_PASS_WRAPPER_0("add_reconcile_unrealized_casts",
126+
createReconcileUnrealizedCastsPass);
127+
}
128+
129+
// Populates the given Python submodule with `load_dialects(context)`, which
// registers the dialects, LLVM-IR translation interfaces and bufferization
// external models listed below on `context` and then loads every dialect that
// is available in it.
void init_triton_shared_ir(py::module &&m) {
  m.def("load_dialects", [](mlir::MLIRContext &context) {
    mlir::DialectRegistry registry;

    // Dialects registered explicitly. The TritonGPU and TritonInstrument
    // dialects are intentionally left out, and registerAllDialects() is not
    // used, so only this list is registered here.
    // NOTE(review): ::mlir::tptr::TPtrDialect has no matching #include among
    // the headers visible in this file - confirm it is reachable transitively.
    registry.insert<
        ::mlir::triton::TritonDialect, ::mlir::linalg::LinalgDialect,
        ::mlir::bufferization::BufferizationDialect,
        ::mlir::tptr::TPtrDialect, ::mlir::math::MathDialect,
        ::mlir::memref::MemRefDialect, ::mlir::arith::ArithDialect,
        ::mlir::scf::SCFDialect, ::mlir::vector::VectorDialect,
        ::mlir::cf::ControlFlowDialect, ::mlir::triton::proton::ProtonDialect,
        ::mlir::LLVM::LLVMDialect, ::mlir::ub::UBDialect,
        ::mlir::func::FuncDialect>();

    // LLVM dialect inliner interface and LLVM-IR translation interfaces.
    // The inliner interface is registered exactly once; a second registration
    // on the same registry adds nothing.
    ::mlir::LLVM::registerInlinerInterface(registry);
    ::mlir::registerBuiltinDialectTranslation(registry);
    ::mlir::registerLLVMDialectTranslation(registry);

    // External BufferizableOpInterface models registered for the dialects
    // below.
    ::mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
    ::mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
    ::mlir::linalg::registerBufferizableOpInterfaceExternalModels(registry);
    ::mlir::vector::registerBufferizableOpInterfaceExternalModels(registry);
    ::mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
    // Function ops get their models from bufferization::func_ext. No
    // memref::, func:: or cf:: registration call is made here; it is unclear
    // whether such functions exist in this MLIR version.
    ::mlir::bufferization::func_ext::
        registerBufferizableOpInterfaceExternalModels(registry);

    context.appendDialectRegistry(registry);
    context.loadAllAvailableDialects();
  });
}
167+
168+
// Populates the given Python submodule with `enable_mlir_debug(debug_type)`,
// a thin binding over the C++ function of the same name.
void init_triton_shared_debug(py::module &&m) {
  const char *enableMlirDebugDoc =
      "Enables a specific MLIR/LLVM debug type (e.g., 'pattern-rewrite'). "
      "Pass an empty string to disable.";

  m.def("enable_mlir_debug", &enable_mlir_debug, enableMlirDebugDoc,
        py::arg("debug_type"));
}
174+
175+
// Entry point for the triton_shared Python bindings: creates the `to_llir`,
// `ir` and `debug` submodules under `m`, in that order, and fills each one in.
void init_triton_triton_shared(py::module &&m) {
  // Pass-registration helpers (add_* functions).
  py::module toLlirModule = m.def_submodule("to_llir");
  init_to_llvm(std::move(toLlirModule));

  // Dialect loading.
  py::module irModule = m.def_submodule("ir");
  init_triton_shared_ir(std::move(irModule));

  // Debug-output control.
  py::module debugModule = m.def_submodule("debug");
  init_triton_shared_debug(std::move(debugModule));
}

0 commit comments

Comments
 (0)