Skip to content

Commit 7c28a58

Browse files
committed
Merge commit 'e4e68d1b9c7d307785b230c0521095c753ffe071'
Signed-off-by: Anatoly Myachev <[email protected]>
2 parents d415da6 + e4e68d1 commit 7c28a58

File tree

30 files changed

+903
-170
lines changed

30 files changed

+903
-170
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ pytest.ini
5151
# Instrumentation
5252
python/triton/instrumentation
5353

54+
# MLIR Plugin
55+
python/triton/plugins
56+
5457
# Python caches
5558
__pycache__/
5659
*.py[cod]

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
2020
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2121
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2222
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
23+
option(LLVM_BUILD_SHARED_LIBS
24+
"Build all libraries as shared libraries instead of static" OFF)
2325
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
2426

2527
if(TRITON_BUILD_WITH_CCACHE)
@@ -64,6 +66,7 @@ if(WIN32)
6466
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6567
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6668
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
69+
set(LLVM_BUILD_SHARED_LIBS "0")
6770
else()
6871
message(FATAL_ERROR "Unsupported compiler")
6972
endif()

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ test-unit: all
4343
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
4444
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
4545
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
46+
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
47+
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
4648
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
4749

4850
.PHONY: test-distributed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability):
272272
# inspect or modify add_stages here
273273
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
274274
```
275-
275+
Examples of how to use this for out of tree plugin passes is [here](lib/Plugins/README.md)
276276
# Changelog
277277

278278
Version 2.0 is out! New features include:

bin/RegisterTritonDialects.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
5656
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
5757

58+
#include "triton/Tools/PluginUtils.h"
59+
#include "triton/Tools/Sys/GetEnv.hpp"
60+
5861
namespace mlir {
5962
namespace test {
6063
namespace intel {
@@ -165,6 +168,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
165168
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
166169
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
167170

171+
// Plugin passes
172+
if (std::string filename =
173+
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
174+
!filename.empty()) {
175+
176+
TritonPlugin TP(filename);
177+
std::vector<const char *> passNames;
178+
if (auto result = TP.getPassHandles(passNames); !result)
179+
llvm::report_fatal_error(result.takeError());
180+
181+
for (const char *passName : passNames)
182+
if (auto result = TP.registerPass(passName); !result)
183+
llvm::report_fatal_error(result.takeError());
184+
}
185+
168186
registry.insert<
169187
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
170188
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,

include/triton/Tools/PluginUtils.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#ifndef TRITON_PLUGIN_UTILS_H
2+
#define TRITON_PLUGIN_UTILS_H
3+
4+
#include "mlir/Pass/PassManager.h"
5+
#include "llvm/Support/DynamicLibrary.h"
6+
#include "llvm/Support/Error.h"
7+
#include <cstdint>
8+
9+
extern "C" {
10+
enum TritonPluginResult {
11+
TP_SUCCESS = 0,
12+
TP_GENERIC_FAILURE = 1,
13+
};
14+
};
15+
#define TRITON_PLUGIN_API \
16+
extern "C" __attribute__((visibility("default"))) TritonPluginResult
17+
18+
struct TritonPlugin {
19+
TritonPlugin() = delete;
20+
TritonPlugin(std::string filename) : filename(filename) {}
21+
22+
private:
23+
using enumeratePyBindHandlesType =
24+
std::function<TritonPluginResult(uint32_t *, const char **)>;
25+
using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *,
26+
const char **);
27+
28+
// Put enumerate API names here, these can be involved with
29+
// enumeratePyBindHandles
30+
const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses";
31+
32+
const std::string ADD_PASS = "tritonAddPluginPass";
33+
using addPassType =
34+
std::function<TritonPluginResult(mlir::PassManager *, const char *)>;
35+
using addPassCType = TritonPluginResult (*)(mlir::PassManager *,
36+
const char *);
37+
38+
const std::string REGISTER_PASS = "tritonRegisterPluginPass";
39+
using registerPassType = std::function<TritonPluginResult(const char *)>;
40+
using registerPassCType = TritonPluginResult (*)(const char *);
41+
42+
llvm::Error checkLibraryValid(const std::string &error) const;
43+
44+
llvm::Expected<intptr_t> getAddressOfSymbol(const std::string &symbol) const;
45+
46+
template <typename T, typename U>
47+
llvm::Expected<T> getAPI(const std::string &symbol) const {
48+
llvm::Expected<intptr_t> getDetailsFn = getAddressOfSymbol(symbol);
49+
if (auto Err = getDetailsFn.takeError()) {
50+
return Err;
51+
}
52+
auto func = reinterpret_cast<U>(*getDetailsFn);
53+
return func;
54+
}
55+
56+
llvm::Expected<TritonPluginResult> checkAPIResult(TritonPluginResult result,
57+
const char *handle) const;
58+
llvm::Expected<TritonPluginResult>
59+
enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles,
60+
std::vector<const char *> &passNames);
61+
62+
public:
63+
std::runtime_error err2exp(llvm::Error Err);
64+
65+
llvm::Error loadPlugin();
66+
67+
llvm::Expected<TritonPluginResult>
68+
getPassHandles(std::vector<const char *> &handles);
69+
70+
llvm::Expected<TritonPluginResult> addPass(mlir::PassManager *pm,
71+
const char *passHandle);
72+
73+
llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);
74+
75+
private:
76+
std::string filename = "";
77+
mutable llvm::sys::DynamicLibrary library;
78+
enumeratePyBindHandlesType enumeratePassesAPI;
79+
addPassType addPassAPI;
80+
registerPassType registerPassAPI;
81+
bool isLoaded = false;
82+
};
83+
84+
#endif // TRITON_PLUGIN_UTILS_H

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
4545
"TRITON_F32_DEFAULT",
4646
"TRITON_PREFER_TMEM_16x256_LAYOUT",
4747
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
48+
"TRITON_PASS_PLUGIN_PATH",
4849
"TRITON_INTEL_2DBLOCK_ASSERT",
4950
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
5051
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",

lib/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ add_subdirectory(Conversion)
33
add_subdirectory(Dialect)
44
add_subdirectory(Target)
55
add_subdirectory(Tools)
6-
add_subdirectory(Instrumentation)
6+
add_subdirectory(Plugins)

lib/Instrumentation/CMakeLists.txt

Lines changed: 0 additions & 47 deletions
This file was deleted.

lib/Instrumentation/PrintLoadStoreMemSpaces.cpp

Lines changed: 0 additions & 106 deletions
This file was deleted.

0 commit comments

Comments
 (0)