Skip to content

Commit d907aed

Browse files
CRobeckplotfi
andauthored
[BACKEND] Add support for out of tree TTIR/TTGIR passes (#8401)
Follow on to: triton-lang/triton#8137. This PR coupled with the Triton pass pipeline override adds the ability to call out of tree TTIR and TTGIR MLIR passes. Integration has been done with both the Python and triton-opt interfaces. Co-authored w/ @plotfi --------- Co-authored-by: Puyan Lotfi <[email protected]>
1 parent acd8104 commit d907aed

File tree

24 files changed

+839
-1
lines changed

24 files changed

+839
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ pytest.ini
4343
# Instrumentation
4444
python/triton/instrumentation
4545

46+
# MLIR Plugin
47+
python/triton/plugins
48+
4649
# Python caches
4750
__pycache__/
4851
*.py[cod]

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
2020
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2121
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2222
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
23+
option(LLVM_BUILD_SHARED_LIBS
24+
"Build all libraries as shared libraries instead of static" OFF)
2325
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
2426

2527
if(TRITON_BUILD_WITH_CCACHE)
@@ -60,6 +62,7 @@ else()
6062
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6163
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6264
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
65+
set(LLVM_BUILD_SHARED_LIBS "0")
6366
endif()
6467

6568
# Default build type

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ test-unit: all
4343
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
4444
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
4545
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
46+
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
47+
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
4648
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
4749

4850
.PHONY: test-distributed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability):
272272
# inspect or modify add_stages here
273273
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
274274
```
275-
275+
Examples of how to use this for out of tree plugin passes is [here](lib/Plugins/README.md)
276276
# Changelog
277277

278278
Version 2.0 is out! New features include:

bin/RegisterTritonDialects.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
4646
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
4747

48+
#include "triton/Tools/PluginUtils.h"
49+
#include "triton/Tools/Sys/GetEnv.hpp"
50+
4851
namespace mlir {
4952
namespace test {
5053
void registerTestAliasPass();
@@ -136,6 +139,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
136139
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
137140
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
138141

142+
// Plugin passes
143+
if (std::string filename =
144+
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
145+
!filename.empty()) {
146+
147+
TritonPlugin TP(filename);
148+
std::vector<const char *> passNames;
149+
if (auto result = TP.getPassHandles(passNames); !result)
150+
llvm::report_fatal_error(result.takeError());
151+
152+
for (const char *passName : passNames)
153+
if (auto result = TP.registerPass(passName); !result)
154+
llvm::report_fatal_error(result.takeError());
155+
}
156+
139157
registry.insert<
140158
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
141159
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,

include/triton/Tools/PluginUtils.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#ifndef TRITON_PLUGIN_UTILS_H
2+
#define TRITON_PLUGIN_UTILS_H
3+
4+
#include "mlir/Pass/PassManager.h"
5+
#include "llvm/Support/DynamicLibrary.h"
6+
#include "llvm/Support/Error.h"
7+
#include <cstdint>
8+
9+
extern "C" {
10+
enum TritonPluginResult {
11+
TP_SUCCESS = 0,
12+
TP_GENERIC_FAILURE = 1,
13+
};
14+
};
15+
#define TRITON_PLUGIN_API \
16+
extern "C" __attribute__((visibility("default"))) TritonPluginResult
17+
18+
struct TritonPlugin {
19+
TritonPlugin() = delete;
20+
TritonPlugin(std::string filename) : filename(filename) {}
21+
22+
private:
23+
using enumeratePyBindHandlesType =
24+
std::function<TritonPluginResult(uint32_t *, const char **)>;
25+
using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *,
26+
const char **);
27+
28+
// Put enumerate API names here, these can be involved with
29+
// enumeratePyBindHandles
30+
const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses";
31+
32+
const std::string ADD_PASS = "tritonAddPluginPass";
33+
using addPassType =
34+
std::function<TritonPluginResult(mlir::PassManager *, const char *)>;
35+
using addPassCType = TritonPluginResult (*)(mlir::PassManager *,
36+
const char *);
37+
38+
const std::string REGISTER_PASS = "tritonRegisterPluginPass";
39+
using registerPassType = std::function<TritonPluginResult(const char *)>;
40+
using registerPassCType = TritonPluginResult (*)(const char *);
41+
42+
llvm::Error checkLibraryValid(const std::string &error) const;
43+
44+
llvm::Expected<intptr_t> getAddressOfSymbol(const std::string &symbol) const;
45+
46+
template <typename T, typename U>
47+
llvm::Expected<T> getAPI(const std::string &symbol) const {
48+
llvm::Expected<intptr_t> getDetailsFn = getAddressOfSymbol(symbol);
49+
if (auto Err = getDetailsFn.takeError()) {
50+
return Err;
51+
}
52+
auto func = reinterpret_cast<U>(*getDetailsFn);
53+
return func;
54+
}
55+
56+
llvm::Expected<TritonPluginResult> checkAPIResult(TritonPluginResult result,
57+
const char *handle) const;
58+
llvm::Expected<TritonPluginResult>
59+
enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles,
60+
std::vector<const char *> &passNames);
61+
62+
public:
63+
std::runtime_error err2exp(llvm::Error Err);
64+
65+
llvm::Error loadPlugin();
66+
67+
llvm::Expected<TritonPluginResult>
68+
getPassHandles(std::vector<const char *> &handles);
69+
70+
llvm::Expected<TritonPluginResult> addPass(mlir::PassManager *pm,
71+
const char *passHandle);
72+
73+
llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);
74+
75+
private:
76+
std::string filename = "";
77+
mutable llvm::sys::DynamicLibrary library;
78+
enumeratePyBindHandlesType enumeratePassesAPI;
79+
addPassType addPassAPI;
80+
registerPassType registerPassAPI;
81+
bool isLoaded = false;
82+
};
83+
84+
#endif // TRITON_PLUGIN_UTILS_H

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
4545
"TRITON_F32_DEFAULT",
4646
"TRITON_PREFER_TMEM_16x256_LAYOUT",
4747
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
48+
"TRITON_PASS_PLUGIN_PATH"
4849
// clang-format on
4950
};
5051

lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ add_subdirectory(Dialect)
44
add_subdirectory(Target)
55
add_subdirectory(Tools)
66
add_subdirectory(Instrumentation)
7+
add_subdirectory(Plugins)

lib/Plugins/CMakeLists.txt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Plugins)
3+
add_public_tablegen_target(TritonPluginsIncGen)
4+
5+
llvm_canonicalize_cmake_booleans(
6+
MLIR_ENABLE_BINDINGS_PYTHON
7+
)
8+
9+
set(TRITON_PLUGIN_PASSES
10+
TritonPluginsTestLib
11+
)
12+
13+
set(TritonPluginsTestLib_SOURCES
14+
TritonPlugin.cpp
15+
)
16+
17+
18+
foreach( plugin ${TRITON_PLUGIN_PASSES} )
19+
add_mlir_library(${plugin}
20+
${${plugin}_SOURCES}
21+
SHARED
22+
23+
ADDITIONAL_HEADER_DIRS
24+
${PROJECT_BINARY_DIR}/lib
25+
26+
DEPENDS
27+
TritonTableGen
28+
TritonCanonicalizeIncGen
29+
TritonPluginsIncGen
30+
31+
LINK_LIBS PUBLIC
32+
MLIRPass
33+
LLVMSupport
34+
MLIRSupport
35+
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
36+
)
37+
38+
set_target_properties(${plugin} PROPERTIES
39+
LIBRARY_OUTPUT_DIRECTORY
40+
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins")
41+
42+
target_compile_options(${plugin} PRIVATE -fvisibility=hidden)
43+
endforeach()

lib/Plugins/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#ifndef TRITONGPU_PLUGIN_PASSES
2+
#define TRITONGPU_PLUGIN_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> {
7+
let summary = "Triton MLIR Plugin Pass";
8+
}
9+
#endif

0 commit comments

Comments
 (0)