Skip to content

Commit 37fc68a

Browse files
committed
Add tptr-to-llvm pass and use pass_manager in compiler.py to lower to llir
1 parent 2b728ad commit 37fc68a

File tree

18 files changed

+2032
-51
lines changed

18 files changed

+2032
-51
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ build/*
55
.vscode/*
66
/python/examples/test_core.py
77
/python/examples/test_annotations.py
8+
9+
.clangd/*

CMakeLists.txt

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,28 @@ add_subdirectory(lib)
2222
add_subdirectory(test)
2323
add_subdirectory(tools/triton-shared-opt)
2424

25-
if (TRITON_SHARED_BUILD_CPU_BACKEND)
25+
if(TRITON_SHARED_BUILD_CPU_BACKEND)
2626
add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonToLinalg TritonTilingExtIR)
27-
target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers)
27+
target_link_libraries(TritonShared
28+
PRIVATE
29+
30+
MLIRRegisterAllDialects
31+
TPtrToLLVM
32+
33+
Python3::Module
34+
pybind11::headers
35+
)
2836
endif()
2937

30-
# Add symlinks to selected pytest files in triton. The tests are imported into triton-shared’s test folder to
31-
# run under triton-shared's conftest configuration.
38+
# Add symlinks to selected pytest files and the .clang-format settings file in triton. The tests are imported into triton-shared's test folder to
39+
# run under triton-shared's conftest configuration, and the clang-format link ensures consistent code style enforcement across both repositories.
3240
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_core.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_CORE)
3341
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_annotations.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_ANNOTATIONS)
42+
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_SHARED_CLANG_FORMAT_SETTING)
3443
cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_core.py" OUTPUT_VARIABLE TRITON_TEST_CORE)
3544
cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_annotations.py" OUTPUT_VARIABLE TRITON_TEST_ANNOTATIONS)
45+
cmake_path(APPEND CMAKE_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_CLANG_FORMAT_SETTING)
3646

3747
add_symlink(${TRITON_TEST_CORE} ${TRITON_SHARED_TEST_CORE})
3848
add_symlink(${TRITON_TEST_ANNOTATIONS} ${TRITON_SHARED_TEST_ANNOTATIONS})
49+
add_symlink(${TRITON_CLANG_FORMAT_SETTING} ${TRITON_SHARED_CLANG_FORMAT_SETTING})

backend/compiler.py

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from triton.backends.compiler import BaseBackend, GPUTarget
2-
from triton._C.libtriton import ir, passes
2+
from triton._C.libtriton import ir, passes, triton_shared
33
from dataclasses import dataclass
44
from typing import Any, Dict, Tuple
55
from types import ModuleType
@@ -42,7 +42,7 @@ def _get_sanitizer_type():
4242
if sanitizer_type != "" and sanitizer_type != "asan" and sanitizer_type != "tsan":
4343
# throw error
4444
raise Exception(f"TRITON_SHARED_SANITIZER_TYPE {sanitizer_type} is invalid.")
45-
45+
4646
return sanitizer_type
4747

4848
def _ttir_to_ttsharedir(mod):
@@ -78,41 +78,47 @@ def _ttsharedir_to_llir(ttsharedir: str):
7878
llmlir_path = os.path.join(tmpdir, "ll.mlir")
7979
llir_path = os.path.join(tmpdir, "ll.ir")
8080
Path(ttshared_path).write_text(ttsharedir)
81-
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
81+
context = ir.context()
82+
triton_shared.ir.load_dialects(context)
83+
mod = ir.parse_mlir_module(ttshared_path, context)
8284
# TritonShared-MLIR to LLVM-MLIR
83-
subprocess.check_call([mlir_opt_path, ttshared_path,
84-
"--convert-linalg-to-affine-loops",
85-
# Note: eliminate-empty-tensors fails when there are multiple func.return ops
86-
# in a single kernel which are the results of early returns.
87-
# See python/examples/test_early_return.py for examples.
88-
# We disable this pass for now since performance on CPU isn't the main
89-
# focus at the moment.
90-
# "--eliminate-empty-tensors",
91-
"--empty-tensor-to-alloc-tensor",
92-
"--one-shot-bufferize=allow-return-allocs-from-loops=true",
93-
"--lower-affine",
94-
"--convert-linalg-to-loops",
95-
"--expand-strided-metadata",
96-
"--convert-scf-to-cf",
97-
"--convert-arith-to-llvm",
98-
"--convert-math-to-llvm",
99-
"--convert-complex-to-llvm",
100-
"--convert-vector-to-llvm",
101-
"--convert-index-to-llvm",
102-
"--memref-expand",
103-
"--finalize-memref-to-llvm",
104-
"--convert-func-to-llvm",
105-
"--convert-cf-to-llvm",
106-
# Lowering memrefs creates more affine.apply ops.
107-
# Lowering these affine ops again creates further arith ops,
108-
# so we have to run these two passes again here.
109-
"--lower-affine",
110-
"--convert-arith-to-llvm",
111-
# Remove all unrealized casts created
112-
"--reconcile-unrealized-casts",
113-
"--mlir-print-debuginfo",
114-
"-o",
115-
llmlir_path])
85+
86+
pm = ir.pass_manager(context)
87+
pm.enable_debug()
88+
triton_shared.to_llir.add_convert_linalg_to_affine_loops(pm)
89+
# Note: eliminate-empty-tensors fails when there are multiple func.return ops
90+
# in a single kernel which are the results of early returns.
91+
# See python/examples/test_early_return.py for examples.
92+
# We disable this pass for now since performance on CPU isn't the main
93+
# focus at the moment.
94+
# triton_shared.to_llir.add_eliminate_empty_tensors(pm)
95+
triton_shared.to_llir.add_empty_tensor_to_alloc_tensor(pm)
96+
triton_shared.to_llir.add_one_shot_bufferize_with_options(
97+
pm, allow_return_allocs_from_loops=True)
98+
triton_shared.to_llir.add_lower_affine(pm)
99+
triton_shared.to_llir.add_convert_linalg_to_loops(pm)
100+
triton_shared.to_llir.add_expand_strided_metadata(pm)
101+
triton_shared.to_llir.add_convert_scf_to_cf(pm)
102+
triton_shared.to_llir.add_convert_tptr_to_llvm(pm)
103+
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
104+
triton_shared.to_llir.add_convert_math_to_llvm(pm)
105+
triton_shared.to_llir.add_convert_complex_to_llvm(pm)
106+
triton_shared.to_llir.add_convert_vector_to_llvm(pm)
107+
triton_shared.to_llir.add_convert_index_to_llvm(pm)
108+
triton_shared.to_llir.add_memref_expand(pm)
109+
triton_shared.to_llir.add_finalize_memref_to_llvm(pm)
110+
triton_shared.to_llir.add_convert_func_to_llvm(pm)
111+
triton_shared.to_llir.add_convert_cf_to_llvm(pm)
112+
# Lowering memrefs creates more affine.apply ops.
113+
# Lowering these affine ops again creates further arith ops,
114+
# so we have to run these two passes again here.
115+
triton_shared.to_llir.add_lower_affine(pm)
116+
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
117+
# Remove all unrealized casts created
118+
triton_shared.to_llir.add_reconcile_unrealized_casts(pm)
119+
pm.run(mod)
120+
121+
Path(llmlir_path).write_text(str(mod))
116122

117123
# LLVM-MLIR to LLVM-IR
118124
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
@@ -145,16 +151,16 @@ def _llir_to_bin(llir: str, metadata):
145151
# using a sanitizer
146152
# invoke pass to append sanitizer attributes
147153
instrumented_src_path = os.path.join(tmpdir, "kernel-instrumented.ll")
148-
154+
149155
opt_path = _get_llvm_bin_path("opt")
150156
top_level_triton_path = os.path.dirname(triton.__file__)
151157
sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None))
152158

153159
if not sanitizer_attributes_pass_path:
154160
raise Exception(f"libSanitizerAttributes.so does not exist.")
155161

156-
subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
157-
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
162+
subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
163+
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
158164
"-o", instrumented_src_path])
159165

160166
# compile to object file
@@ -166,12 +172,12 @@ def _llir_to_bin(llir: str, metadata):
166172
subprocess_args.extend(["-g", "-fsanitize=address", "-mllvm", "-asan-stack=0"])
167173
elif sanitizer_type == "tsan":
168174
subprocess_args.extend(["-g", "-fsanitize=thread"])
169-
175+
170176
subprocess.check_call(subprocess_args)
171177
else:
172178
llc_path = _get_llvm_bin_path("llc")
173179
subprocess.check_call([llc_path, src_path, "-filetype=obj", "-relocation-model=pic", "-o", dst_path])
174-
180+
175181
return Path(dst_path).read_bytes()
176182

177183

@@ -265,11 +271,11 @@ def add_stages(self, stages, options, language):
265271
stages["llir"] = lambda src, metadata: _optimize_llir(_ttsharedir_to_llir(src))
266272
stages["obj"] = lambda src, metadata: _llir_to_bin(src, metadata)
267273

268-
269274
@functools.lru_cache()
270275
def hash(self):
271276
return self.target
272277

273278
# The CPU backend does not use any extra Python modules, so return an empty dictionary.
274279
def get_module_map(self) -> Dict[str, ModuleType]:
275280
return {}
281+

include/triton-shared/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ add_subdirectory(TritonPtrToMemref)
66
add_subdirectory(TritonToUnstructured)
77
add_subdirectory(StructuredToMemref)
88
add_subdirectory(UnstructuredToMemref)
9+
add_subdirectory(TPtrToLLVM)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TPtrToLLVM)
3+
add_public_tablegen_target(TPtrToLLVMConversionPassIncGen)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TPTR_TO_LLVM_CONVERSION_PASSES_H
2+
#define TPTR_TO_LLVM_CONVERSION_PASSES_H
3+
4+
#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"
5+
6+
namespace mlir {
7+
namespace tptr {
8+
9+
#define GEN_PASS_REGISTRATION
10+
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"
11+
12+
} // namespace tptr
13+
} // namespace mlir
14+
15+
#endif
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#ifndef TPTR_TO_LLVM_CONVERSION_PASSES
2+
#define TPTR_TO_LLVM_CONVERSION_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def TPtrToLLVM : Pass<"tptr-to-llvm", "mlir::ModuleOp"> {
7+
let summary = "Convert Tptr operations into LLVM";
8+
let dependentDialects = ["mlir::tptr::TPtrDialect", "mlir::LLVM::LLVMDialect", "mlir::ptr::PtrDialect"];
9+
}
10+
#endif
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H
2+
#define TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H
3+
4+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "mlir/Transforms/DialectConversion.h"
7+
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
8+
9+
namespace mlir {
10+
namespace tptr {
11+
12+
#define GEN_PASS_DECL
13+
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"
14+
15+
void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
16+
TypeConverter &typeconverter);
17+
18+
std::unique_ptr<OperationPass<ModuleOp>> createTPtrToLLVMPass();
19+
20+
} // namespace tptr
21+
} // namespace mlir
22+
23+
#endif

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ add_subdirectory(TritonArithToLinalg)
66
add_subdirectory(StructuredToMemref)
77
add_subdirectory(TritonPtrToMemref)
88
add_subdirectory(UnstructuredToMemref)
9+
add_subdirectory(TPtrToLLVM)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_triton_library(TPtrToLLVM
2+
TPtrToLLVMPass.cpp
3+
TPtrToLLVM.cpp
4+
5+
DEPENDS
6+
TPtrToLLVMConversionPassIncGen
7+
TPtrTableGen
8+
9+
LINK_LIBS PUBLIC
10+
MLIRIR
11+
MLIRPass
12+
MLIRTransforms
13+
MLIRSupport
14+
MLIRDialectUtils
15+
TPtrIR
16+
)

0 commit comments

Comments
 (0)