Skip to content

Commit 2e4c4b7

Browse files
authored
Merge OpenAI Triton commit e4e68d1 (#5597)
This PR changes the Triton base from acd8104 to e4e68d1 (Nov 23). Pass rate: 95.42%.

Signed-off-by: Anatoly Myachev <[email protected]>
2 parents 318c0f9 + ed1efc4 commit 2e4c4b7

File tree

30 files changed

+914
-170
lines changed

30 files changed

+914
-170
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ pytest.ini
5151
# Instrumentation
5252
python/triton/instrumentation
5353

54+
# MLIR Plugin
55+
python/triton/plugins
56+
5457
# Python caches
5558
__pycache__/
5659
*.py[cod]

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
2020
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2121
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2222
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
23+
option(LLVM_BUILD_SHARED_LIBS
24+
"Build all libraries as shared libraries instead of static" OFF)
2325
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
2426

2527
if(TRITON_BUILD_WITH_CCACHE)
@@ -64,6 +66,7 @@ if(WIN32)
6466
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6567
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6668
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
69+
set(LLVM_BUILD_SHARED_LIBS "0")
6770
else()
6871
message(FATAL_ERROR "Unsupported compiler")
6972
endif()

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ test-unit: all
4343
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
4444
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
4545
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
46+
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
47+
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
4648
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
4749

4850
.PHONY: test-distributed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability):
272272
# inspect or modify add_stages here
273273
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
274274
```
275-
275+
Examples of how to use this for out-of-tree plugin passes are [here](lib/Plugins/README.md).
276276
# Changelog
277277

278278
Version 2.0 is out! New features include:

bin/RegisterTritonDialects.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
5656
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
5757

58+
#include "triton/Tools/PluginUtils.h"
59+
#include "triton/Tools/Sys/GetEnv.hpp"
60+
5861
namespace mlir {
5962
namespace test {
6063
namespace intel {
@@ -165,6 +168,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
165168
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
166169
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
167170

171+
// Plugin passes
172+
if (std::string filename =
173+
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
174+
!filename.empty()) {
175+
176+
TritonPlugin TP(filename);
177+
std::vector<const char *> passNames;
178+
if (auto result = TP.getPassHandles(passNames); !result)
179+
llvm::report_fatal_error(result.takeError());
180+
181+
for (const char *passName : passNames)
182+
if (auto result = TP.registerPass(passName); !result)
183+
llvm::report_fatal_error(result.takeError());
184+
}
185+
168186
registry.insert<
169187
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
170188
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,

include/triton/Tools/PluginUtils.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#ifndef TRITON_PLUGIN_UTILS_H
2+
#define TRITON_PLUGIN_UTILS_H
3+
4+
#include "mlir/Pass/PassManager.h"
5+
#include "llvm/Support/DynamicLibrary.h"
6+
#include "llvm/Support/Error.h"
7+
#include <cstdint>
8+
9+
extern "C" {
10+
enum TritonPluginResult {
11+
TP_SUCCESS = 0,
12+
TP_GENERIC_FAILURE = 1,
13+
};
14+
};
15+
16+
#if defined(_WIN32)
17+
#define EXPORT_FUNC __declspec(dllexport)
18+
#else
19+
#define EXPORT_FUNC __attribute__((visibility("default")))
20+
#endif
21+
22+
#define TRITON_PLUGIN_API extern "C" EXPORT_FUNC TritonPluginResult
23+
24+
struct TritonPlugin {
25+
TritonPlugin() = delete;
26+
TritonPlugin(std::string filename) : filename(filename) {}
27+
28+
private:
29+
using enumeratePyBindHandlesType =
30+
std::function<TritonPluginResult(uint32_t *, const char **)>;
31+
using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *,
32+
const char **);
33+
34+
// Put enumerate API names here; these can be invoked with
35+
// enumeratePyBindHandles
36+
const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses";
37+
38+
const std::string ADD_PASS = "tritonAddPluginPass";
39+
using addPassType =
40+
std::function<TritonPluginResult(mlir::PassManager *, const char *)>;
41+
using addPassCType = TritonPluginResult (*)(mlir::PassManager *,
42+
const char *);
43+
44+
const std::string REGISTER_PASS = "tritonRegisterPluginPass";
45+
using registerPassType = std::function<TritonPluginResult(const char *)>;
46+
using registerPassCType = TritonPluginResult (*)(const char *);
47+
48+
llvm::Error checkLibraryValid(const std::string &error) const;
49+
50+
llvm::Expected<intptr_t> getAddressOfSymbol(const std::string &symbol) const;
51+
52+
template <typename T, typename U>
53+
llvm::Expected<T> getAPI(const std::string &symbol) const {
54+
llvm::Expected<intptr_t> getDetailsFn = getAddressOfSymbol(symbol);
55+
if (auto Err = getDetailsFn.takeError()) {
56+
return Err;
57+
}
58+
auto func = reinterpret_cast<U>(*getDetailsFn);
59+
return func;
60+
}
61+
62+
llvm::Expected<TritonPluginResult> checkAPIResult(TritonPluginResult result,
63+
const char *handle) const;
64+
llvm::Expected<TritonPluginResult>
65+
enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles,
66+
std::vector<const char *> &passNames);
67+
68+
public:
69+
std::runtime_error err2exp(llvm::Error Err);
70+
71+
llvm::Error loadPlugin();
72+
73+
llvm::Expected<TritonPluginResult>
74+
getPassHandles(std::vector<const char *> &handles);
75+
76+
llvm::Expected<TritonPluginResult> addPass(mlir::PassManager *pm,
77+
const char *passHandle);
78+
79+
llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);
80+
81+
private:
82+
std::string filename = "";
83+
mutable llvm::sys::DynamicLibrary library;
84+
enumeratePyBindHandlesType enumeratePassesAPI;
85+
addPassType addPassAPI;
86+
registerPassType registerPassAPI;
87+
bool isLoaded = false;
88+
};
89+
90+
#endif // TRITON_PLUGIN_UTILS_H

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
4545
"TRITON_F32_DEFAULT",
4646
"TRITON_PREFER_TMEM_16x256_LAYOUT",
4747
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
48+
"TRITON_PASS_PLUGIN_PATH",
4849
"TRITON_INTEL_2DBLOCK_ASSERT",
4950
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
5051
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",

lib/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ add_subdirectory(Conversion)
33
add_subdirectory(Dialect)
44
add_subdirectory(Target)
55
add_subdirectory(Tools)
6-
add_subdirectory(Instrumentation)
6+
add_subdirectory(Plugins)

lib/Instrumentation/CMakeLists.txt

Lines changed: 0 additions & 47 deletions
This file was deleted.

lib/Instrumentation/PrintLoadStoreMemSpaces.cpp

Lines changed: 0 additions & 106 deletions
This file was deleted.

0 commit comments

Comments
 (0)