From d907aedf9b30be497d09ffa34bb9fb57acb78c3a Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Fri, 21 Nov 2025 23:27:02 -0500 Subject: [PATCH 1/5] [BACKEND] Add support for out of tree TTIR/TTGIR passes (#8401) Follow on to: https://github.com/triton-lang/triton/pull/8137. This PR coupled with the Triton pass pipeline override adds the ability to call out of tree TTIR and TTGIR MLIR passes. Integration has been done with both the Python and triton-opt interfaces. Co-authored w/ @plotfi --------- Co-authored-by: Puyan Lotfi --- .gitignore | 3 + CMakeLists.txt | 3 + Makefile | 2 + README.md | 2 +- bin/RegisterTritonDialects.h | 18 ++ include/triton/Tools/PluginUtils.h | 84 ++++++ include/triton/Tools/Sys/GetEnv.hpp | 1 + lib/CMakeLists.txt | 1 + lib/Plugins/CMakeLists.txt | 43 +++ lib/Plugins/Passes.td | 9 + lib/Plugins/README.md | 311 ++++++++++++++++++++++ lib/Plugins/TritonPlugin.cpp | 87 ++++++ lib/Tools/CMakeLists.txt | 1 + lib/Tools/PluginUtils.cpp | 115 ++++++++ python/src/passes.cc | 28 ++ python/test/unit/plugins/custom_stages.py | 38 +++ python/test/unit/plugins/test_plugin.py | 43 +++ python/triton/compiler/compiler.py | 3 + python/triton/runtime/jit.py | 6 + scripts/build-llvm-project.sh | 2 + setup.py | 5 + test/Plugins/test-plugin.mlir | 30 +++ test/lit.cfg.py | 4 + test/lit.site.cfg.py.in | 1 + 24 files changed, 839 insertions(+), 1 deletion(-) create mode 100644 include/triton/Tools/PluginUtils.h create mode 100644 lib/Plugins/CMakeLists.txt create mode 100644 lib/Plugins/Passes.td create mode 100644 lib/Plugins/README.md create mode 100644 lib/Plugins/TritonPlugin.cpp create mode 100644 lib/Tools/PluginUtils.cpp create mode 100644 python/test/unit/plugins/custom_stages.py create mode 100644 python/test/unit/plugins/test_plugin.py create mode 100644 test/Plugins/test-plugin.mlir diff --git a/.gitignore b/.gitignore index 705de071cf..da3ac1c864 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,9 @@ pytest.ini # Instrumentation python/triton/instrumentation +# MLIR Plugin +python/triton/plugins + # Python caches __pycache__/ *.py[cod] diff --git a/CMakeLists.txt b/CMakeLists.txt index c9620e3f4c..a534cf4307 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) +option(LLVM_BUILD_SHARED_LIBS + "Build all libraries as shared libraries instead of static" OFF) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") if(TRITON_BUILD_WITH_CCACHE) @@ -60,6 +62,7 @@ else() set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(LLVM_BUILD_SHARED_LIBS "0") endif() # Default build type diff --git a/Makefile b/Makefile index 53960a6a04..43ed4adcb5 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,8 @@ test-unit: all $(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \ $(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py + TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \ + $(PYTEST) -vvv python/test/unit/plugins/test_plugin.py $(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon .PHONY: test-distributed diff --git a/README.md b/README.md index d235dab4fd..b069daa989 100644 --- a/README.md +++ b/README.md @@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability): # inspect or modify add_stages here triton.knobs.runtime.add_stages_inspection_hook = inspect_stages ``` - +Examples of how to use this for out of tree plugin passes is [here](lib/Plugins/README.md) # Changelog Version 2.0 is out! New features include: diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 48049f0268..0a11484b43 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -45,6 +45,9 @@ #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" + namespace mlir { namespace test { void registerTestAliasPass(); @@ -136,6 +139,21 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::proton::gpu::registerScheduleBufferStorePass(); mlir::triton::proton::gpu::registerAddSchedBarriersPass(); + // Plugin passes + if (std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + !filename.empty()) { + + TritonPlugin TP(filename); + std::vector passNames; + if (auto result = TP.getPassHandles(passNames); !result) + llvm::report_fatal_error(result.takeError()); + + for (const char *passName : passNames) + if (auto result = TP.registerPass(passName); !result) + llvm::report_fatal_error(result.takeError()); + } + registry.insert< mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h new file mode 100644 index 0000000000..d0b4de6adb --- /dev/null +++ b/include/triton/Tools/PluginUtils.h @@ -0,0 +1,84 @@ +#ifndef TRITON_PLUGIN_UTILS_H +#define TRITON_PLUGIN_UTILS_H + +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Error.h" +#include + +extern "C" { +enum TritonPluginResult { + TP_SUCCESS = 0, + TP_GENERIC_FAILURE = 1, +}; +}; +#define TRITON_PLUGIN_API \ + extern "C" __attribute__((visibility("default"))) TritonPluginResult + +struct TritonPlugin { + TritonPlugin() = delete; + TritonPlugin(std::string filename) : filename(filename) {} + +private: + using enumeratePyBindHandlesType = + std::function; + using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, + const char **); + + // Put enumerate API names here, these can be involved with + // enumeratePyBindHandles + const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses"; + + const std::string ADD_PASS = "tritonAddPluginPass"; + using addPassType = + std::function; + using addPassCType = TritonPluginResult (*)(mlir::PassManager *, + const char *); + + const std::string REGISTER_PASS = "tritonRegisterPluginPass"; + using registerPassType = std::function; + using registerPassCType = TritonPluginResult (*)(const char *); + + llvm::Error checkLibraryValid(const std::string &error) const; + + llvm::Expected getAddressOfSymbol(const std::string &symbol) const; + + template + llvm::Expected getAPI(const std::string &symbol) const { + llvm::Expected getDetailsFn = getAddressOfSymbol(symbol); + if (auto Err = getDetailsFn.takeError()) { + return Err; + } + auto func = reinterpret_cast(*getDetailsFn); + return func; + } + + llvm::Expected checkAPIResult(TritonPluginResult result, + const char *handle) const; + llvm::Expected + enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles, + std::vector &passNames); + +public: + std::runtime_error err2exp(llvm::Error Err); + + llvm::Error loadPlugin(); + + llvm::Expected + getPassHandles(std::vector &handles); + + llvm::Expected addPass(mlir::PassManager *pm, + const char *passHandle); + + llvm::Expected registerPass(const char *passHandle); + +private: + std::string filename = ""; + mutable llvm::sys::DynamicLibrary library; + enumeratePyBindHandlesType enumeratePassesAPI; + addPassType addPassAPI; + registerPassType registerPassAPI; + bool isLoaded = false; +}; + +#endif // TRITON_PLUGIN_UTILS_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index ed31dac41d..21cdf0de0e 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -45,6 +45,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", + "TRITON_PASS_PLUGIN_PATH" // clang-format on }; diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e8ae340f2d..9e14a86250 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(Dialect) add_subdirectory(Target) add_subdirectory(Tools) add_subdirectory(Instrumentation) +add_subdirectory(Plugins) diff --git a/lib/Plugins/CMakeLists.txt b/lib/Plugins/CMakeLists.txt new file mode 100644 index 0000000000..593335d222 --- /dev/null +++ b/lib/Plugins/CMakeLists.txt @@ -0,0 +1,43 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Plugins) +add_public_tablegen_target(TritonPluginsIncGen) + +llvm_canonicalize_cmake_booleans( + MLIR_ENABLE_BINDINGS_PYTHON +) + +set(TRITON_PLUGIN_PASSES + TritonPluginsTestLib + ) + +set(TritonPluginsTestLib_SOURCES + TritonPlugin.cpp + ) + + +foreach( plugin ${TRITON_PLUGIN_PASSES} ) + add_mlir_library(${plugin} + ${${plugin}_SOURCES} + SHARED + + ADDITIONAL_HEADER_DIRS + ${PROJECT_BINARY_DIR}/lib + + DEPENDS + TritonTableGen + TritonCanonicalizeIncGen + TritonPluginsIncGen + + LINK_LIBS PUBLIC + MLIRPass + LLVMSupport + MLIRSupport + "$<$:-undefined dynamic_lookup>" + ) + + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins") + + target_compile_options(${plugin} PRIVATE -fvisibility=hidden) +endforeach() diff --git a/lib/Plugins/Passes.td b/lib/Plugins/Passes.td new file mode 100644 index 0000000000..a8007a09e8 --- /dev/null +++ b/lib/Plugins/Passes.td @@ -0,0 +1,9 @@ +#ifndef TRITONGPU_PLUGIN_PASSES +#define TRITONGPU_PLUGIN_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> { + let summary = "Triton MLIR Plugin Pass"; +} +#endif diff --git a/lib/Plugins/README.md b/lib/Plugins/README.md new file mode 100644 index 0000000000..a12b0c79ae --- /dev/null +++ b/lib/Plugins/README.md @@ -0,0 +1,311 @@ +# Triton TTIR and TTGIR Out of Tree Plugin Passes + +## Overview +Triton’s existing pass pipelines are assembled in the various extended compiler.py files that live in Triton’s backends. Currently when we want to insert +passes either for downstream optimizations, custom ops, or instrumentation it is required for the compiler.py file itself to be modified and all of Triton to be +recompiled. + +In order to allow for more downstream configurability we have implemented a custom MLIR level (TTIR and TTGIR) pass plugin and configuration system that allows for either +overriding the compiler.py pipeline entirely or inserting passes and custom ops through a compiler pipeline hook. Example use cases include: +- Custom ops and lowering passes +- Custom optimization passes +- Instrumentation and analysis passes +- Specialized per kernel passes (e.g. kernel/model specific warp specialization) + +Custom passes/ops are implemented as a shared library that is loaded by Triton at JIT compile/runtime. The plugins can be implement entirely out of tree or in the Triton source tree as +long as the libtriton.so is linked to the plugin and the Triton include passes are used to build the plugin. + +## Example 1: Developing a custom pass and running triton-opt to inspect the modified IR +``` bash +export LLVM_BUILD_SHARED_LIBS=1; make dev-install-llvm +TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir +``` +``` MLIR +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { + tt.func @foo() { + tt.return + } +} +``` + +After the out of tree pass runs, becomes: +``` MLIR +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { + tt.func @bar() { + tt.return + } +} +``` +Function "foo" is renamed to "bar" by the out of tree pass. + +## Example 2: Inserting a new pass into the compiler pipeline +Let's take the following toy kernel example: +``` python +import torch +import os + +import triton +import triton.language as tl +from triton._C.libtriton import ir, passes +from triton import knobs + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def kernel(BLOCK_SIZE: tl.constexpr): + return + +if __name__ == '__main__': + + size = 98432 + x = torch.rand(size, device=DEVICE) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) +``` + +Running as is will produce the expected output of printing the TTGIR of the kernel: +``` bash +python test.py +``` +``` MLIR +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) +``` + +Running same code but loading the plugin library also produces the same results since, while the plugin pass has been loaded and registered with the +pass manager it is not inserted into the compiler pass pipeline: + +``` bash +TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +``` + +``` MLIR +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) +``` + +Finally, if we both load the plugin at runtime and insert the pass pipeline hook into the kernel code: + +``` python +import torch +import os + +import triton +import triton.language as tl +from triton._C.libtriton import ir, passes +from triton import knobs + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def kernel(BLOCK_SIZE: tl.constexpr): + return + +#These two methods must be implemented by the plugin +def get_key(): + return pathlib.Path(__file__).read_text() +def get_hash(): + return hashlib.sha256(get_key().encode('utf-8')).hexdigest() + +def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None): + # If the hook is called with no arguments we assume were just after the key and hash and don't want to + # actually execute the pipeline yet. + # This no argument early return must be implemented. + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + + def make_ttir_wrapper(mod, metadata, opt, capability): + mod = self.make_ttir(mod, metadata, opt, capability) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.plugin.add_plugin(pm) + pm.run(mod, 'make_ttir_plugin') + return mod + + stages["ttir"] = lambda src, metadata: make_ttir_wrapper(src, metadata, options, capability) + + return get_key(), get_hash() + +if __name__ == '__main__': + + size = 98432 + x = torch.rand(size, device=DEVICE) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) + + if "TRITON_PASS_PLUGIN_PATH" in os.environ: + knobs.runtime.add_stages_inspection_hook = inspect_stages_hook + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) + + # Unset the hook to go back to the standard pipeline + knobs.runtime.add_stages_inspection_hook = None + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) +``` + +``` bash +TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +``` + +Shows the pass ran and modified the kernel name but only after the hook is set. Any kernels before the hook or after the hook is unset are left unchanged. + +``` MLIR +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @foo() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) +``` + +The hook, as defined, in the example will insert the pass at the end of the make_ttir pipeline but it's placement in the Triton pipeline is abritary. +This functionality can be toggled on and off by just commenting out this line in kernel code (or setting to None): +knobs.runtime.add_stages_inspection_hook = inspect_stages_hook +without needing any core compiler changes or rebuilding Triton. + +## Example 3: Fully customizing the compiler pipeline with pass and op insertions at abitrary locations + +Here we now run two kernels one with the full standard Triton pipeline and one with fully customized pipeline entirely from within +kernel code with modifying any core Triton compiler code or recompiling. We run the kernel with a hook to output the standard pipeline, modify +the compiler.py file to insert our out of tree pass before add_loop_unroll pass (although there is no restriction of where it can be inserted), +then run the second kernel with a different pipeline. This modification can, as before, be seen in the kernel function name modification by the +inserted pass. + +``` python +import torch +import os +import sys + +import triton +import triton.language as tl +from triton._C.libtriton import ir, passes +from triton import knobs +import inspect +from importlib.util import module_from_spec, spec_from_file_location + +from triton.backends.compiler import Language + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def kernel1(BLOCK_SIZE: tl.constexpr): + return +@triton.jit +def kernel2(BLOCK_SIZE: tl.constexpr): + return + +def get_key(): + return pathlib.Path(__file__).read_text() +def get_hash(): + return hashlib.sha256(get_key().encode('utf-8')).hexdigest() + +def dump_stages_hook(self=None, stages=None, options=None, language=None, capability=None): + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + source_code = "# This is generated from Triton compiler.py" + source_code = ( + source_code + + "\n" + + "from triton._C.libtriton import ir, passes, llvm, amd, nvidia" + ) + source_code = source_code + "\n" + "class GPUOverrideBackend:" + source_code = source_code + "\n" + inspect.getsource(self.make_ttir) + source_code = source_code + "\n" + inspect.getsource(self.make_ttgir) + + with open("compiler_override.py", "w") as file: + file.write(source_code) + return get_key(), get_hash() +def override_stages(self=None, stages=None, options=None, language=None, capability=None): + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + if language != Language.TRITON: + return + full_name = "compiler_override.py" + + print(f"\nOverriding compile pass stages with file {full_name}") + module_name = "triton_override_compiler_stages" + spec = ( + spec_from_file_location(module_name, full_name) + if os.path.isfile(full_name) + else None + ) + if not spec: + return + + module = module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + if not hasattr(module, "GPUOverrideBackend"): + return + module = getattr(module, "GPUOverrideBackend") + + has_func = lambda mod, name: hasattr(mod, name) and callable(getattr(mod, name)) + make_lambda = lambda f: lambda src, metadata: f(src, metadata, options, capability) + if has_func(module, "make_ttir"): + stages["ttir"] = make_lambda(module.make_ttir) + if has_func(module, "make_ttgir"): + stages["ttgir"] = make_lambda(module.make_ttgir) + return get_key(), get_hash() + +if __name__ == '__main__': + + size = 98432 + x = torch.rand(size, device=DEVICE) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + knobs.runtime.add_stages_inspection_hook = dump_stages_hook + h = kernel1[grid](BLOCK_SIZE=1024) + filename = "compiler_override.py" + + with open(filename, "r") as infile: + file_str = infile.readlines() + + with open(filename, "w") as outfile: + for line in file_str: + if "add_loop_unroll" in line: + outfile.write("\n passes.plugin.add_plugin(pm)\n") + outfile.write(line) + if "TRITON_PASS_PLUGIN_PATH" in os.environ: + knobs.runtime.add_stages_inspection_hook = override_stages + h = kernel2[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) +``` diff --git a/lib/Plugins/TritonPlugin.cpp b/lib/Plugins/TritonPlugin.cpp new file mode 100644 index 0000000000..9807857d77 --- /dev/null +++ b/lib/Plugins/TritonPlugin.cpp @@ -0,0 +1,87 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/PluginUtils.h" +#include + +#define TRITON_PLUGIN_API \ + extern "C" __attribute__((visibility("default"))) TritonPluginResult + +namespace mlir { +namespace triton { +namespace plugin { + +#define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN +#include "Passes.h.inc" + +struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { + void runOnOperation() override { + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mod.walk([&](FunctionOpInterface funcOp) { + StringAttr funcNameAttr = funcOp.getNameAttr(); + funcOp.setName("foo"); + }); + } +}; + +} // namespace plugin +} // namespace triton +} // namespace mlir + +static void addTritonPluginPass(mlir::PassManager *pm) { + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); +} + +static void registerTritonPluginPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::triton::plugin::createTritonGPUMLIRPlugin(); + }); +} + +static const char *ADD_PLUGIN_PASS_NAME = "add_plugin"; +static std::unordered_map passMap = + {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; +static std::unordered_map registryMap = { + {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; +static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; + +// Key APIs: + +TRITON_PLUGIN_API +tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { + std::string passNameStr(passName); + if (passMap.find(passNameStr) == passMap.end()) + return TP_GENERIC_FAILURE; + passMap[passNameStr](pm); + return TP_SUCCESS; +} + +TRITON_PLUGIN_API +tritonRegisterPluginPass(const char *passName) { + std::string passNameStr(passName); + if (registryMap.find(passNameStr) == registryMap.end()) + return TP_GENERIC_FAILURE; + registryMap[passNameStr](); + return TP_SUCCESS; +} + +TRITON_PLUGIN_API +tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { + if (!passCount) + return TP_GENERIC_FAILURE; + auto count = passMap.size(); + assert(count == registryMap.size() && + "Expected register and add passes map size to match"); + *passCount = count; + if (!passNames) + return TP_SUCCESS; + unsigned i = 0; + for (auto passName : passNamesTable) { + passNames[i] = passName; + } + return TP_SUCCESS; +} diff --git a/lib/Tools/CMakeLists.txt b/lib/Tools/CMakeLists.txt index a2f9f8aea5..611468b9a2 100644 --- a/lib/Tools/CMakeLists.txt +++ b/lib/Tools/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonTools GenericSwizzling.cpp LayoutUtils.cpp LinearLayout.cpp + PluginUtils.cpp DEPENDS diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp new file mode 100644 index 0000000000..97e97bde83 --- /dev/null +++ b/lib/Tools/PluginUtils.cpp @@ -0,0 +1,115 @@ +#include "triton/Tools/PluginUtils.h" + +llvm::Error TritonPlugin::checkLibraryValid(const std::string &error) const { + if (!library.isValid()) { + auto msg = llvm::Twine("Failed to load plugin library: " + error + "\n"); + return llvm::createStringError(msg); + } + return llvm::Error::success(); +} + +llvm::Expected +TritonPlugin::getAddressOfSymbol(const std::string &symbol) const { + if (auto isValid = checkLibraryValid("not loaded")) + return isValid; + intptr_t getDetailsFn = (intptr_t)library.getAddressOfSymbol(symbol.c_str()); + if (!getDetailsFn) { + auto msg = llvm::Twine("Failed to get symbol: " + symbol + "\n"); + return llvm::createStringError(msg); + } + return getDetailsFn; +} + +llvm::Expected +TritonPlugin::checkAPIResult(TritonPluginResult result, + const char *handle) const { + if (result == TP_SUCCESS) + return TP_SUCCESS; + std::string msg; + llvm::raw_string_ostream os(msg); + os << "Failed to add/register plugin pass (" << handle + << ") to pass manager, error code: " << result; + return llvm::createStringError(msg); +} + +std::runtime_error TritonPlugin::err2exp(llvm::Error Err) { + std::string msg; + llvm::raw_string_ostream os(msg); + os << Err; + return std::runtime_error(msg); +} + +llvm::Error TritonPlugin::loadPlugin() { + if (isLoaded) + return llvm::Error::success(); + + std::string error; + library = + llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); + if (auto isValid = checkLibraryValid(error)) + return isValid; + + auto enumeratePassesAPIOrErr = + getAPI( + ENUMERATE_PASSES); + auto addPassAPIOrErr = getAPI(ADD_PASS); + auto registerPassAPIOrErr = + getAPI(REGISTER_PASS); + + if (auto Err = enumeratePassesAPIOrErr.takeError()) + return Err; + if (auto Err = addPassAPIOrErr.takeError()) + return Err; + if (auto Err = registerPassAPIOrErr.takeError()) + return Err; + + enumeratePassesAPI = *enumeratePassesAPIOrErr; + addPassAPI = *addPassAPIOrErr; + registerPassAPI = *registerPassAPIOrErr; + isLoaded = true; + return llvm::Error::success(); +} + +llvm::Expected TritonPlugin::enumeratePyBindHandles( + enumeratePyBindHandlesType &enumeratePyBindHandles, + std::vector &handles) { + if (auto Err = loadPlugin()) + return Err; + + uint32_t passCount = 0; + handles.clear(); + auto result = enumeratePyBindHandles(&passCount, nullptr); + if (result == TP_SUCCESS) { + if (passCount == 0) + return TP_SUCCESS; + + handles.resize(passCount); + result = enumeratePyBindHandles(&passCount, handles.data()); + } + + if (result == TP_SUCCESS) + return TP_SUCCESS; + std::string msg; + llvm::raw_string_ostream os(msg); + os << "Failed to retrive plugin pass handles, error code: " << result; + return llvm::createStringError(msg); +} + +llvm::Expected +TritonPlugin::getPassHandles(std::vector &passNames) { + return enumeratePyBindHandles(enumeratePassesAPI, passNames); +} + +llvm::Expected +TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) { + if (auto Err = loadPlugin()) + return Err; + return checkAPIResult(addPassAPI(pm, passHandle), passHandle); +} + +llvm::Expected +TritonPlugin::registerPass(const char *passHandle) { + if (auto Err = loadPlugin()) + return Err; + return checkAPIResult(registerPassAPI(passHandle), passHandle); +} diff --git a/python/src/passes.cc b/python/src/passes.cc index f2d93fab96..8977b59913 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -12,8 +12,11 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonInstrument/Transforms/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include #include +#include namespace py = pybind11; @@ -96,6 +99,30 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUOptimizePartitionWarps); } +void init_plugin_passes(py::module &&m) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + if (filename.empty()) + return; + + TritonPlugin TP(filename); + std::vector passNames; + if (auto result = TP.getPassHandles(passNames); !result) + throw TP.err2exp(result.takeError()); + + for (unsigned i = 0; i < passNames.size(); ++i) { + const char *passName = passNames.data()[i]; + + m.def(passName, [passName](mlir ::PassManager &pm) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + TritonPlugin TP(filename); + if (auto result = TP.addPass(&pm, passName); !result) + throw TP.err2exp(result.takeError()); + }); + } +} + void init_triton_passes_convert(py::module &&m) { using namespace mlir; ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); @@ -130,4 +157,5 @@ void init_triton_passes(py::module &&m) { init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); init_triton_passes_llvmir(m.def_submodule("llvmir")); init_gluon_passes(m.def_submodule("gluon")); + init_plugin_passes(m.def_submodule("plugin")); } diff --git a/python/test/unit/plugins/custom_stages.py b/python/test/unit/plugins/custom_stages.py new file mode 100644 index 0000000000..f98cab4b0a --- /dev/null +++ b/python/test/unit/plugins/custom_stages.py @@ -0,0 +1,38 @@ +from triton._C.libtriton import ir, passes +import hashlib +import pathlib + + +# These two methods must be implemented and returned by the plugin hook. +# any changes in this entire file and the the plugin pipeline +# will trigger a recompile since the hash will change. To be +# less conservative, we could use a hash of the inspect_stages_hook +# function but then changes outside of the function won't be considered +# potentially causing a stale kernel hash +def get_key(): + return pathlib.Path(__file__).read_text() + + +def get_hash(): + return hashlib.sha256(get_key().encode('utf-8')).hexdigest() + + +# Keep custom pipeline stages in a seperate file from kernels as any change to the file +# will trigger a recompile. +def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None): + # If the hook is called with no arguments we assume were just after the key and hash and don't want to + # actually execute the pipeline yet + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + + def make_ttir_wrapper(mod, metadata, opt, capability): + mod = self.make_ttir(mod, metadata, opt, capability) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.plugin.add_plugin(pm) + pm.run(mod, 'make_ttir_plugin') + return mod + + stages["ttir"] = lambda src, metadata: make_ttir_wrapper(src, metadata, options, capability) + + return get_key(), get_hash() diff --git a/python/test/unit/plugins/test_plugin.py b/python/test/unit/plugins/test_plugin.py new file mode 100644 index 0000000000..9a895174b1 --- /dev/null +++ b/python/test/unit/plugins/test_plugin.py @@ -0,0 +1,43 @@ +import torch + +import pytest +import os + +import triton +import triton.language as tl +from triton import knobs +import custom_stages + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel1(BLOCK_SIZE: tl.constexpr): + return + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel2(BLOCK_SIZE: tl.constexpr): + return + + +def test_op(capfd, device: str): + if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': + return + + size = 98432 + x = torch.rand(size, device=device) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + h = kernel1[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo" not in h.asm["ttgir"] + + knobs.runtime.add_stages_inspection_hook = custom_stages.inspect_stages_hook + h = kernel2[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo" in h.asm["ttgir"] + + knobs.runtime.add_stages_inspection_hook = None + h = kernel2[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo" not in h.asm["ttgir"] diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index d4872a4635..d6f8dd26fe 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -244,6 +244,9 @@ def compile(src, target=None, options=None, _env_vars=None): # create cache manager env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars key = get_cache_key(src, backend, options, env_vars=env_vars) + if knobs.runtime.add_stages_inspection_hook is not None: + inspect_stages_key, inspect_stages_hash = knobs.runtime.add_stages_inspection_hook() + key += inspect_stages_key hash = hashlib.sha256(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) # For dumping/overriding only hash the source as we want it to be independent of triton diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 0f506818f9..7c73a84f36 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -709,6 +709,12 @@ def run(self, *args, grid, warmup, **kwargs): # the type and the second parameter is the 'specialization' value. bound_args, specialization, options = binder(*args, **kwargs) + # add a cache field to the kernel specializations for kernel specific + # pass pipelines + if knobs.runtime.add_stages_inspection_hook is not None: + inspect_stages_key, inspect_stages_hash = knobs.runtime.add_stages_inspection_hook() + specialization.append(f'("custom_pipeline", {inspect_stages_hash})') + key = compute_cache_key(kernel_key_cache, specialization, options) kernel = kernel_cache.get(key, None) diff --git a/scripts/build-llvm-project.sh b/scripts/build-llvm-project.sh index 3651f4ee67..d8ce9aebd7 100755 --- a/scripts/build-llvm-project.sh +++ b/scripts/build-llvm-project.sh @@ -5,6 +5,7 @@ REPO_ROOT="$(git rev-parse --show-toplevel)" LLVM_TARGETS=${LLVM_TARGETS:-Native;NVPTX;AMDGPU} LLVM_PROJECTS=${LLVM_PROJECTS:-mlir;llvm;lld} LLVM_BUILD_TYPE=${LLVM_BUILD_TYPE:-RelWithDebInfo} +LLVM_BUILD_SHARED_LIBS=${LLVM_BUILD_SHARED_LIBS:-OFF} LLVM_COMMIT_HASH=${LLVM_COMMIT_HASH:-$(cat "$REPO_ROOT/cmake/llvm-hash.txt")} LLVM_PROJECT_PATH=${LLVM_PROJECT_PATH:-"$REPO_ROOT/llvm-project"} LLVM_BUILD_PATH=${LLVM_BUILD_PATH:-"$LLVM_PROJECT_PATH/build"} @@ -21,6 +22,7 @@ if [ -z "$CMAKE_ARGS" ]; then -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_LLD=ON + -DBUILD_SHARED_LIBS="$LLVM_BUILD_SHARED_LIBS" -DLLVM_OPTIMIZED_TABLEGEN=ON -DMLIR_ENABLE_BINDINGS_PYTHON=OFF -DLLVM_TARGETS_TO_BUILD="$LLVM_TARGETS" diff --git a/setup.py b/setup.py index 795eae685b..017f1112f8 100644 --- a/setup.py +++ b/setup.py @@ -492,6 +492,11 @@ def build_extension(self, ext): "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", ] + if check_env_flag("LLVM_BUILD_SHARED_LIBS"): + cmake_args += ["-DLLVM_BUILD_SHARED_LIBS=1"] + else: + cmake_args += ["-DLLVM_BUILD_SHARED_LIBS=0"] + # Note that asan doesn't work with binaries that use the GPU, so this is # only useful for tools like triton-opt that don't run code on the GPU. # diff --git a/test/Plugins/test-plugin.mlir b/test/Plugins/test-plugin.mlir new file mode 100644 index 0000000000..f16ef07882 --- /dev/null +++ b/test/Plugins/test-plugin.mlir @@ -0,0 +1,30 @@ +// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so triton-opt -split-input-file -tritongpu-plugin %s | FileCheck %s --check-prefix=CHECK-PLUGIN +// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so triton-opt -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-NOFLAG +// RUN: triton-opt -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-BASE + +// REQUIRES: shared-libs + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-PLUGIN: func @foo() + tt.func @bar() { + tt.return + } +} // module + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-NOFLAG: func @bar() + tt.func @bar() { + tt.return + } +} // module + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-BASE: func @bar() + tt.func @bar() { + tt.return + } +} // module diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 0890e609b7..09044bebe4 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -63,6 +63,10 @@ ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), ] +# Static libraries are not built if LLVM_BUILD_SHARED_LIBS is ON. +if config.build_shared_libs: + config.available_features.add("shared-libs") + llvm_config.add_tool_substitutions(tools, tool_dirs) # TODO: what's this? diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index fd1fb486dd..59b212a4d2 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -14,6 +14,7 @@ config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.mlir_binary_dir = "@MLIR_BINARY_DIR@" config.python_executable = "@Python3_EXECUTABLE@" config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ +config.build_shared_libs = @LLVM_BUILD_SHARED_LIBS@ import lit.llvm From 2760a7e0ef6e202c9b5e526a2ca6d0959a22c5d4 Mon Sep 17 00:00:00 2001 From: Kyle Wang Date: Fri, 21 Nov 2025 23:18:16 -0800 Subject: [PATCH 2/5] [Bench][AMD]Support Padding and Unswizzling Scale on CDNA4 (#8803) This PR supports - the `CDNA4MXScaleLayout.unswizzle_data` method used in GPT-OSS model - padding tensors with 0 when doing scale preshuffling --- python/triton_kernels/tests/test_matmul.py | 2 -- .../test_tensor_details/test_layout_cdna4.py | 24 ++++++++++++++ .../layout_details/cdna4_scale.py | 31 ++++++++++++------- 3 files changed, 44 insertions(+), 13 deletions(-) create mode 100644 python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index c7405cd944..664de12bdd 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -369,8 +369,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm pytest.skip("Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet.") if "mx" not in weight_dtype_str: pytest.skip("Non-scale swizzling not supported on CDNA4 yet") - if n % 32 != 0 or k % (32 * 8) != 0: - pytest.skip(f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU") if torch.cuda.get_device_capability()[0] < 9: pytest.skip("NYI. Ampere swizzling.") if torch.cuda.get_device_capability()[0] < 10: diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py b/python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py new file mode 100644 index 0000000000..0ffa8b6c58 --- /dev/null +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py @@ -0,0 +1,24 @@ +import pytest +import torch +from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout + +# ------------------------------------------------------------ +# Torch tests +# ------------------------------------------------------------ + + +@pytest.mark.parametrize( + "shape", + [ + (3, 4096, 1024), + (10, 254, 60), + (1, 320, 160), + (2, 16, 512), + (3, 2, 36), + ], +) +def test_mxfp4_scale_roundtrip(shape): + x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") + layout = CDNA4MXScaleLayout(x.shape) + res = layout.unswizzle_data(layout.swizzle_data(x)) + assert (res == x).all() diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py index 870d739096..ca23f090db 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py @@ -1,3 +1,5 @@ +import math +import torch from dataclasses import dataclass import triton import triton.language as tl @@ -12,24 +14,31 @@ class CDNA4MXScaleLayout(Layout): def __init__(self, shape) -> None: super().__init__(shape) + ( + *self.leading_shape, + self.K_SCALE, + self.N, + ) = shape + self.B = math.prod(self.leading_shape) + self.ALIGN_K_SCALE = 8 + self.ALIGN_N = 32 + self.K_SCALE_pad = math.ceil(self.K_SCALE / self.ALIGN_K_SCALE) * self.ALIGN_K_SCALE + self.N_pad = math.ceil(self.N / self.ALIGN_N) * self.ALIGN_N def swizzle_data(self, data): - block_shape = data.shape - SCALE_K = block_shape[-2] - N = block_shape[-1] + data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_SCALE_pad - self.K_SCALE)) data = data.transpose(-1, -2) - data = data.view(-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1) + data = data.view(-1, self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, self.K_SCALE_pad // 8, 2, 4, 1) data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous() - if len(block_shape) == 3: - E = block_shape[0] - data = data.reshape(E, N // 32, SCALE_K * 32) - else: - assert len(block_shape) == 2 - data = data.reshape(N // 32, SCALE_K * 32) + data = data.reshape(self.B, self.N_pad // 32, self.K_SCALE_pad * 32) return data.transpose(-1, -2) def unswizzle_data(self, data): - raise NotImplementedError() + data = data.transpose(-1, -2) + data = data.view(-1, self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, self.K_SCALE_pad // 8, 4, 16, 2, 2, 1) + data = data.permute(0, 1, 6, 4, 2, 5, 3, 7) + data = data.reshape(*self.leading_shape, self.N_pad, self.K_SCALE_pad) + return data.transpose(-1, -2)[..., :self.K_SCALE, :self.N] def swizzle_block_shape(self, block_shape): SCALE_K = block_shape[-2] From 9fbf44ffea9fdb9012fbd4229d08a66632f6d89d Mon Sep 17 00:00:00 2001 From: peiying779 Date: Sat, 22 Nov 2025 09:18:31 -0800 Subject: [PATCH 3/5] [AMD] Add messages for debugging BlockPingpong (#8804) For some num_warps, num_stages and tile_size, BlockPingpong will exit early without leaving a message. Added message for these cases so people can be aware that pingpong wasn't really involved. Also add some transparency for debugging. This shouldn't change the way how BlockPingong or anything else in Triton is used. # New contributor declaration - [X] I am not making a trivial change, such as fixing a typo in a comment. - [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [X] This PR does not need a test because `Just added a few message to make sure BlockPingpong won't silently exit early, making it easier to debug`. - Select one of the following. - [X] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) Co-authored-by: Peiying Hua --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index e78db6422a..6c3773afee 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -1185,8 +1185,19 @@ void Pingponger::getDotPingponged() { // times for issuing the memory operations and issuing dot operations, // smaller tile sizes are not likely to get any advantage from current dot // centric pingpong scheduling. - if (tileSize <= smallTile && tileSize >= minTile) + if (tileSize <= smallTile && tileSize >= minTile) { transformOnePPClusters(builder, loc); + LDBG("Pingpong scheduling applied for numWarps=4 with tileSize=" + + std::to_string(tileSize) + " (in range [" + std::to_string(minTile) + + ", " + std::to_string(smallTile) + + "]), One Dot-Memory (ping-pong) cluster used."); + } else { + std::stringstream message; + message << "Skipping pingpong for numWarps=4: tileSize=" << tileSize + << " is outside the range [" << minTile << ", " << smallTile + << "]"; + LDBG(message.str()); + } // numWarps=4 doesn't need asymmetric sync, return. return; } else if (numWarps == 8 && numStages == 2) { @@ -1216,8 +1227,15 @@ void Pingponger::getDotPingponged() { "cluster transformation"); return; } - } else + } else { + std::stringstream message; + message << "Skipping pingpong for numWarps=8, numStages=2: tileSize=" + << tileSize + << " does not match supported tile sizes (medium=" << mediumTile + << " or large=" << largeTile << ")"; + LDBG(message.str()); return; + } // Let half of the warps start the loop first and the others follow later // but in the synchronized way. This can be accomplished by calling From e4e68d1b9c7d307785b230c0521095c753ffe071 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Sun, 23 Nov 2025 09:34:21 -0500 Subject: [PATCH 4/5] [BACKEND][PROTON] Remove deprecated instrumentation (#8809) --- lib/CMakeLists.txt | 1 - lib/Instrumentation/CMakeLists.txt | 42 -------- .../PrintLoadStoreMemSpaces.cpp | 102 ------------------ 3 files changed, 145 deletions(-) delete mode 100644 lib/Instrumentation/CMakeLists.txt delete mode 100644 lib/Instrumentation/PrintLoadStoreMemSpaces.cpp diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 9e14a86250..e480ffa52a 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -3,5 +3,4 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(Target) add_subdirectory(Tools) -add_subdirectory(Instrumentation) add_subdirectory(Plugins) diff --git a/lib/Instrumentation/CMakeLists.txt b/lib/Instrumentation/CMakeLists.txt deleted file mode 100644 index 6e6da2351e..0000000000 --- a/lib/Instrumentation/CMakeLists.txt +++ /dev/null @@ -1,42 +0,0 @@ -set(GPU_INSTRUMENTATION_PASSES - PrintLoadStoreMemSpaces - ) - -set(PrintLoadStoreMemSpaces_SOURCES - PrintLoadStoreMemSpaces.cpp - ) - - -foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) - add_library( - ${plugin} - SHARED - ${${plugin}_SOURCES} - ) - - target_link_libraries( - ${plugin} - PRIVATE - LLVMCore - LLVMSupport - LLVMTransformUtils - "$<$:-undefined dynamic_lookup>" - ) - # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python - # build. It is empty if building directly from the root - # CMakeLists.txt file. Therefore if not building from Python just - # use the default CMake shared lib path otherwise this causes a hard - # build error - if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) - set_target_properties(${plugin} PROPERTIES - LIBRARY_OUTPUT_DIRECTORY - "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") - endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) - - # This is set to -fvisibility=hidden in the top level CMake file - # which causes the llvmGetPassPluginInfo symbol to be hidden and - # an "entry point not found" error. Reset it just for this target - if(NOT MSVC) - target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti) - endif() -endforeach() diff --git a/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp b/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp deleted file mode 100644 index 7e2945d3d2..0000000000 --- a/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "llvm/IR/Module.h" -#include "llvm/IR/PassManager.h" -#include "llvm/Passes/PassBuilder.h" -#include "llvm/Passes/PassPlugin.h" -#include - -using namespace llvm; - -namespace { - -struct LoadStoreMemSpace : public PassInfoMixin { - PreservedAnalyses run(llvm::Module &module, ModuleAnalysisManager &) { - bool modifiedCodeGen = runOnModule(module); - - return (modifiedCodeGen ? llvm::PreservedAnalyses::none() - : llvm::PreservedAnalyses::all()); - } - bool runOnModule(llvm::Module &module); - // isRequired being set to true keeps this pass from being skipped - // if it has the optnone LLVM attribute - static bool isRequired() { return true; } -}; - -} // end anonymous namespace - -static std::map AddrSpaceMap = { - {0, "FLAT"}, {1, "GLOBAL"}, {3, "SHARED"}, {4, "CONSTANT"}, {5, "SCRATCH"}}; - -static std::map LocationCounterSourceMap; - -static std::string LoadOrStoreMap(const BasicBlock::iterator &I) { - if (LoadInst *LI = dyn_cast(I)) - return "LOAD"; - else if (StoreInst *SI = dyn_cast(I)) - return "STORE"; - else - throw std::runtime_error("Error: unknown operation type"); -} -template -static void InstrumentationFunction(const BasicBlock::iterator &I, - const Function &F, const llvm::Module &M, - uint32_t &LocationCounter) { - auto LSI = dyn_cast(I); - if (not LSI) - return; - Value *Op = LSI->getPointerOperand()->stripPointerCasts(); - uint32_t AddrSpace = cast(Op->getType())->getAddressSpace(); - DILocation *DL = dyn_cast(I)->getDebugLoc(); - - std::string SourceAndAddrSpaceInfo = - (F.getName() + " " + DL->getFilename() + ":" + Twine(DL->getLine()) + - ":" + Twine(DL->getColumn())) - .str() + - " " + AddrSpaceMap[AddrSpace] + " " + LoadOrStoreMap(I); - - if (LocationCounterSourceMap.find(SourceAndAddrSpaceInfo) == - LocationCounterSourceMap.end()) { - errs() << LocationCounter << " " << SourceAndAddrSpaceInfo << "\n"; - LocationCounterSourceMap[SourceAndAddrSpaceInfo] = LocationCounter; - LocationCounter++; - } -} - -bool LoadStoreMemSpace::runOnModule(Module &M) { - bool ModifiedCodeGen = false; - uint32_t LocationCounter = 0; - for (auto &F : M) { - if (F.isIntrinsic()) - continue; - StringRef functionName = F.getName(); - if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL || - F.getCallingConv() == CallingConv::PTX_Kernel || - functionName.contains("kernel")) { - for (Function::iterator BB = F.begin(); BB != F.end(); BB++) { - for (BasicBlock::iterator I = BB->begin(); I != BB->end(); I++) { - if (LoadInst *LI = dyn_cast(I)) { - InstrumentationFunction(I, F, M, LocationCounter); - } else if (StoreInst *SI = dyn_cast(I)) { - InstrumentationFunction(I, F, M, LocationCounter); - } - } - } - } - } - return ModifiedCodeGen; -} - -static PassPluginLibraryInfo getPassPluginInfo() { - const auto callback = [](PassBuilder &PB) { - PB.registerOptimizerLastEPCallback([&](ModulePassManager &MPM, auto, auto) { - MPM.addPass(LoadStoreMemSpace()); - return true; - }); - }; - - return {LLVM_PLUGIN_API_VERSION, "print-mem-space", LLVM_VERSION_STRING, - callback}; -}; - -extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() { - return getPassPluginInfo(); -} From ed1efc497da6f2ed33ff321d82c1945371e9070a Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Tue, 2 Dec 2025 13:17:42 +0000 Subject: [PATCH 5/5] [Intel] Fix Windows build after 'd907aed' Signed-off-by: Anatoly Myachev --- include/triton/Tools/PluginUtils.h | 10 ++++++++-- lib/Plugins/TritonPlugin.cpp | 9 +++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index d0b4de6adb..14761f6ef6 100644 --- a/include/triton/Tools/PluginUtils.h +++ b/include/triton/Tools/PluginUtils.h @@ -12,8 +12,14 @@ enum TritonPluginResult { TP_GENERIC_FAILURE = 1, }; }; -#define TRITON_PLUGIN_API \ - extern "C" __attribute__((visibility("default"))) TritonPluginResult + +#if defined(_WIN32) +#define EXPORT_FUNC __declspec(dllexport) +#else +#define EXPORT_FUNC __attribute__((visibility("default"))) +#endif + +#define TRITON_PLUGIN_API extern "C" EXPORT_FUNC TritonPluginResult struct TritonPlugin { TritonPlugin() = delete; diff --git a/lib/Plugins/TritonPlugin.cpp b/lib/Plugins/TritonPlugin.cpp index 9807857d77..26e18c0ed7 100644 --- a/lib/Plugins/TritonPlugin.cpp +++ b/lib/Plugins/TritonPlugin.cpp @@ -6,8 +6,13 @@ #include "triton/Tools/PluginUtils.h" #include -#define TRITON_PLUGIN_API \ - extern "C" __attribute__((visibility("default"))) TritonPluginResult +#if defined(_WIN32) +#define EXPORT_FUNC __declspec(dllexport) +#else +#define EXPORT_FUNC __attribute__((visibility("default"))) +#endif + +#define TRITON_PLUGIN_API extern "C" EXPORT_FUNC TritonPluginResult namespace mlir { namespace triton {