Skip to content

Commit 279573d

Browse files
authored
Merge branch 'main' into tkuczynski/fix_benchmarks_bmg
2 parents facac0f + b3ce5fb commit 279573d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+2839
-2171
lines changed

.github/workflows/inductor-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ env:
5252
inductor/test_select_algorithm.py
5353
inductor/test_max_autotune.py
5454
inductor/test_compile_subprocess.py
55+
inductor/test_analysis.py
5556
5657
jobs:
5758
compute-params:

.github/workflows/triton-benchmarks.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ env:
6565
VERIFY: ${{ (github.event_name == 'pull_request' || github.event_name == 'schedule' || inputs.verify) && '1' || '0' }}
6666
TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
6767
N_RUNS: ${{ inputs.n_runs || '1' }}
68+
# FIXME: Enable Level Zero v2 loader once it's stable.
69+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5572
70+
UR_LOADER_USE_LEVEL_ZERO_V2: "0"
6871

6972
jobs:
7073
build:

.github/workflows/try-latest-pytorch.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ jobs:
9696
inductor/test_select_algorithm.py
9797
inductor/test_max_autotune.py
9898
inductor/test_compile_subprocess.py
99+
inductor/test_analysis.py
99100
runner_label: ${{ inputs.runner_label }}
100101
python_version: "3.10"
101102

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ pytest.ini
5151
# Instrumentation
5252
python/triton/instrumentation
5353

54+
# MLIR Plugin
55+
python/triton/plugins
56+
5457
# Python caches
5558
__pycache__/
5659
*.py[cod]

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
2020
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
2121
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
2222
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
23+
option(LLVM_BUILD_SHARED_LIBS
24+
"Build all libraries as shared libraries instead of static" OFF)
2325
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
2426

2527
if(TRITON_BUILD_WITH_CCACHE)
@@ -64,6 +66,7 @@ if(WIN32)
6466
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6567
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
6668
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
69+
set(LLVM_BUILD_SHARED_LIBS "0")
6770
else()
6871
message(FATAL_ERROR "Unsupported compiler")
6972
endif()

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ test-unit: all
4343
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
4444
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
4545
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
46+
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
47+
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
4648
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
4749

4850
.PHONY: test-distributed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability):
272272
# inspect or modify add_stages here
273273
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
274274
```
275-
275+
Examples of how to use this for out-of-tree plugin passes are [here](lib/Plugins/README.md).
276276
# Changelog
277277

278278
Version 2.0 is out! New features include:

bin/RegisterTritonDialects.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
5656
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
5757

58+
#include "triton/Tools/PluginUtils.h"
59+
#include "triton/Tools/Sys/GetEnv.hpp"
60+
5861
namespace mlir {
5962
namespace test {
6063
namespace intel {
@@ -165,6 +168,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
165168
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
166169
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
167170

171+
// Plugin passes
172+
if (std::string filename =
173+
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
174+
!filename.empty()) {
175+
176+
TritonPlugin TP(filename);
177+
std::vector<const char *> passNames;
178+
if (auto result = TP.getPassHandles(passNames); !result)
179+
llvm::report_fatal_error(result.takeError());
180+
181+
for (const char *passName : passNames)
182+
if (auto result = TP.registerPass(passName); !result)
183+
llvm::report_fatal_error(result.takeError());
184+
}
185+
168186
registry.insert<
169187
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
170188
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,

include/triton/Tools/PluginUtils.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#ifndef TRITON_PLUGIN_UTILS_H
2+
#define TRITON_PLUGIN_UTILS_H
3+
4+
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
8+
9+
extern "C" {
/// Status code used as the return type of every entry point declared with
/// TRITON_PLUGIN_API. Kept as a plain (unscoped) enum inside `extern "C"`
/// so the declaration stays usable across the C boundary.
enum TritonPluginResult {
  TP_SUCCESS = 0,
  TP_GENERIC_FAILURE = 1,
};
} // extern "C"

// EXPORT_FUNC marks a symbol as visible outside the shared library it is
// compiled into: dllexport on Windows, default ELF/Mach-O visibility elsewhere.
#if defined(_WIN32)
#define EXPORT_FUNC __declspec(dllexport)
#else
#define EXPORT_FUNC __attribute__((visibility("default")))
#endif

// Declaration prefix for plugin entry points: C linkage, exported, and
// returning a TritonPluginResult status code.
#define TRITON_PLUGIN_API extern "C" EXPORT_FUNC TritonPluginResult
23+
24+
struct TritonPlugin {
25+
TritonPlugin() = delete;
26+
TritonPlugin(std::string filename) : filename(filename) {}
27+
28+
private:
29+
using enumeratePyBindHandlesType =
30+
std::function<TritonPluginResult(uint32_t *, const char **)>;
31+
using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *,
32+
const char **);
33+
34+
// Put enumerate API names here, these can be involved with
35+
// enumeratePyBindHandles
36+
const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses";
37+
38+
const std::string ADD_PASS = "tritonAddPluginPass";
39+
using addPassType =
40+
std::function<TritonPluginResult(mlir::PassManager *, const char *)>;
41+
using addPassCType = TritonPluginResult (*)(mlir::PassManager *,
42+
const char *);
43+
44+
const std::string REGISTER_PASS = "tritonRegisterPluginPass";
45+
using registerPassType = std::function<TritonPluginResult(const char *)>;
46+
using registerPassCType = TritonPluginResult (*)(const char *);
47+
48+
llvm::Error checkLibraryValid(const std::string &error) const;
49+
50+
llvm::Expected<intptr_t> getAddressOfSymbol(const std::string &symbol) const;
51+
52+
template <typename T, typename U>
53+
llvm::Expected<T> getAPI(const std::string &symbol) const {
54+
llvm::Expected<intptr_t> getDetailsFn = getAddressOfSymbol(symbol);
55+
if (auto Err = getDetailsFn.takeError()) {
56+
return Err;
57+
}
58+
auto func = reinterpret_cast<U>(*getDetailsFn);
59+
return func;
60+
}
61+
62+
llvm::Expected<TritonPluginResult> checkAPIResult(TritonPluginResult result,
63+
const char *handle) const;
64+
llvm::Expected<TritonPluginResult>
65+
enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles,
66+
std::vector<const char *> &passNames);
67+
68+
public:
69+
std::runtime_error err2exp(llvm::Error Err);
70+
71+
llvm::Error loadPlugin();
72+
73+
llvm::Expected<TritonPluginResult>
74+
getPassHandles(std::vector<const char *> &handles);
75+
76+
llvm::Expected<TritonPluginResult> addPass(mlir::PassManager *pm,
77+
const char *passHandle);
78+
79+
llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);
80+
81+
private:
82+
std::string filename = "";
83+
mutable llvm::sys::DynamicLibrary library;
84+
enumeratePyBindHandlesType enumeratePassesAPI;
85+
addPassType addPassAPI;
86+
registerPassType registerPassAPI;
87+
bool isLoaded = false;
88+
};
89+
90+
#endif // TRITON_PLUGIN_UTILS_H

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
4545
"TRITON_F32_DEFAULT",
4646
"TRITON_PREFER_TMEM_16x256_LAYOUT",
4747
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
48+
"TRITON_PASS_PLUGIN_PATH",
4849
"TRITON_INTEL_2DBLOCK_ASSERT",
4950
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
5051
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",

0 commit comments

Comments
 (0)