Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ pytest.ini
# Instrumentation
python/triton/instrumentation

# MLIR Plugin
python/triton/plugins

# Python caches
__pycache__/
*.py[cod]
Expand Down
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
option(LLVM_BUILD_SHARED_LIBS
"Build all libraries as shared libraries instead of static" OFF)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")

if(TRITON_BUILD_WITH_CCACHE)
Expand Down Expand Up @@ -64,6 +66,7 @@ if(WIN32)
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(LLVM_BUILD_SHARED_LIBS "0")
else()
message(FATAL_ERROR "Unsupported compiler")
endif()
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ test-unit: all
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon

.PHONY: test-distributed
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability):
# inspect or modify add_stages here
triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
```

Examples of how to use this for out of tree plugin passes is [here](lib/Plugins/README.md)
# Changelog

Version 2.0 is out! New features include:
Expand Down
18 changes: 18 additions & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"

#include "triton/Tools/PluginUtils.h"
#include "triton/Tools/Sys/GetEnv.hpp"

namespace mlir {
namespace test {
namespace intel {
Expand Down Expand Up @@ -165,6 +168,21 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
mlir::triton::proton::gpu::registerAddSchedBarriersPass();

// Plugin passes
if (std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
!filename.empty()) {

TritonPlugin TP(filename);
std::vector<const char *> passNames;
if (auto result = TP.getPassHandles(passNames); !result)
llvm::report_fatal_error(result.takeError());

for (const char *passName : passNames)
if (auto result = TP.registerPass(passName); !result)
llvm::report_fatal_error(result.takeError());
}

registry.insert<
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
Expand Down
90 changes: 90 additions & 0 deletions include/triton/Tools/PluginUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#ifndef TRITON_PLUGIN_UTILS_H
#define TRITON_PLUGIN_UTILS_H

#include "mlir/Pass/PassManager.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include <cstdint>

extern "C" {
enum TritonPluginResult {
TP_SUCCESS = 0,
TP_GENERIC_FAILURE = 1,
};
};

#if defined(_WIN32)
#define EXPORT_FUNC __declspec(dllexport)
#else
#define EXPORT_FUNC __attribute__((visibility("default")))
#endif

#define TRITON_PLUGIN_API extern "C" EXPORT_FUNC TritonPluginResult

struct TritonPlugin {
TritonPlugin() = delete;
TritonPlugin(std::string filename) : filename(filename) {}

private:
using enumeratePyBindHandlesType =
std::function<TritonPluginResult(uint32_t *, const char **)>;
using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *,
const char **);

// Put enumerate API names here, these can be involved with
// enumeratePyBindHandles
const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses";

const std::string ADD_PASS = "tritonAddPluginPass";
using addPassType =
std::function<TritonPluginResult(mlir::PassManager *, const char *)>;
using addPassCType = TritonPluginResult (*)(mlir::PassManager *,
const char *);

const std::string REGISTER_PASS = "tritonRegisterPluginPass";
using registerPassType = std::function<TritonPluginResult(const char *)>;
using registerPassCType = TritonPluginResult (*)(const char *);

llvm::Error checkLibraryValid(const std::string &error) const;

llvm::Expected<intptr_t> getAddressOfSymbol(const std::string &symbol) const;

template <typename T, typename U>
llvm::Expected<T> getAPI(const std::string &symbol) const {
llvm::Expected<intptr_t> getDetailsFn = getAddressOfSymbol(symbol);
if (auto Err = getDetailsFn.takeError()) {
return Err;
}
auto func = reinterpret_cast<U>(*getDetailsFn);
return func;
}

llvm::Expected<TritonPluginResult> checkAPIResult(TritonPluginResult result,
const char *handle) const;
llvm::Expected<TritonPluginResult>
enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles,
std::vector<const char *> &passNames);

public:
std::runtime_error err2exp(llvm::Error Err);

llvm::Error loadPlugin();

llvm::Expected<TritonPluginResult>
getPassHandles(std::vector<const char *> &handles);

llvm::Expected<TritonPluginResult> addPass(mlir::PassManager *pm,
const char *passHandle);

llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);

private:
std::string filename = "";
mutable llvm::sys::DynamicLibrary library;
enumeratePyBindHandlesType enumeratePassesAPI;
addPassType addPassAPI;
registerPassType registerPassAPI;
bool isLoaded = false;
};

#endif // TRITON_PLUGIN_UTILS_H
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_F32_DEFAULT",
"TRITON_PREFER_TMEM_16x256_LAYOUT",
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
"TRITON_PASS_PLUGIN_PATH",
"TRITON_INTEL_2DBLOCK_ASSERT",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
Expand Down
2 changes: 1 addition & 1 deletion lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Target)
add_subdirectory(Tools)
add_subdirectory(Instrumentation)
add_subdirectory(Plugins)
47 changes: 0 additions & 47 deletions lib/Instrumentation/CMakeLists.txt

This file was deleted.

106 changes: 0 additions & 106 deletions lib/Instrumentation/PrintLoadStoreMemSpaces.cpp

This file was deleted.

43 changes: 43 additions & 0 deletions lib/Plugins/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Plugins)
add_public_tablegen_target(TritonPluginsIncGen)

llvm_canonicalize_cmake_booleans(
MLIR_ENABLE_BINDINGS_PYTHON
)

set(TRITON_PLUGIN_PASSES
TritonPluginsTestLib
)

set(TritonPluginsTestLib_SOURCES
TritonPlugin.cpp
)


foreach( plugin ${TRITON_PLUGIN_PASSES} )
add_mlir_library(${plugin}
${${plugin}_SOURCES}
SHARED

ADDITIONAL_HEADER_DIRS
${PROJECT_BINARY_DIR}/lib

DEPENDS
TritonTableGen
TritonCanonicalizeIncGen
TritonPluginsIncGen

LINK_LIBS PUBLIC
MLIRPass
LLVMSupport
MLIRSupport
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
)

set_target_properties(${plugin} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins")

target_compile_options(${plugin} PRIVATE -fvisibility=hidden)
endforeach()
9 changes: 9 additions & 0 deletions lib/Plugins/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef TRITONGPU_PLUGIN_PASSES
#define TRITONGPU_PLUGIN_PASSES

include "mlir/Pass/PassBase.td"

def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> {
let summary = "Triton MLIR Plugin Pass";
}
#endif
Loading