diff --git a/.gitignore b/.gitignore index 6ac292eb1e..08b393568a 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,9 @@ pytest.ini # Instrumentation python/triton/instrumentation +# MLIR Plugin +python/triton/plugins + # Python caches __pycache__/ *.py[cod] diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e55f9dfac..4a3efd485b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,8 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) +option(LLVM_BUILD_SHARED_LIBS + "Build all libraries as shared libraries instead of static" OFF) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") if(TRITON_BUILD_WITH_CCACHE) @@ -64,6 +66,7 @@ if(WIN32) set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") + set(LLVM_BUILD_SHARED_LIBS "0") else() message(FATAL_ERROR "Unsupported compiler") endif() diff --git a/Makefile b/Makefile index 53960a6a04..43ed4adcb5 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,8 @@ test-unit: all $(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \ $(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py + TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \ + $(PYTEST) -vvv python/test/unit/plugins/test_plugin.py $(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon .PHONY: test-distributed diff --git a/README.md b/README.md index d235dab4fd..b069daa989 100644 --- a/README.md +++ 
b/README.md @@ -272,7 +272,7 @@ def inspect_stages(_self, stages, options, language, capability): # inspect or modify add_stages here triton.knobs.runtime.add_stages_inspection_hook = inspect_stages ``` - +Examples of how to use this for out of tree plugin passes is [here](lib/Plugins/README.md) # Changelog Version 2.0 is out! New features include: diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 4c901f66ae..6273dbe949 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -55,6 +55,9 @@ #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" + namespace mlir { namespace test { namespace intel { @@ -165,6 +168,21 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::proton::gpu::registerScheduleBufferStorePass(); mlir::triton::proton::gpu::registerAddSchedBarriersPass(); + // Plugin passes + if (std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + !filename.empty()) { + + TritonPlugin TP(filename); + std::vector passNames; + if (auto result = TP.getPassHandles(passNames); !result) + llvm::report_fatal_error(result.takeError()); + + for (const char *passName : passNames) + if (auto result = TP.registerPass(passName); !result) + llvm::report_fatal_error(result.takeError()); + } + registry.insert< mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h new file mode 100644 index 0000000000..14761f6ef6 --- /dev/null +++ b/include/triton/Tools/PluginUtils.h @@ -0,0 +1,90 @@ +#ifndef TRITON_PLUGIN_UTILS_H +#define TRITON_PLUGIN_UTILS_H + +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/Error.h" +#include + +extern "C" { +enum TritonPluginResult { + 
TP_SUCCESS = 0, + TP_GENERIC_FAILURE = 1, +}; +}; + +#if defined(_WIN32) +#define EXPORT_FUNC __declspec(dllexport) +#else +#define EXPORT_FUNC __attribute__((visibility("default"))) +#endif + +#define TRITON_PLUGIN_API extern "C" EXPORT_FUNC TritonPluginResult + +struct TritonPlugin { + TritonPlugin() = delete; + TritonPlugin(std::string filename) : filename(filename) {} + +private: + using enumeratePyBindHandlesType = + std::function; + using enumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, + const char **); + + // Put enumerate API names here, these can be involved with + // enumeratePyBindHandles + const std::string ENUMERATE_PASSES = "tritonEnumeratePluginPasses"; + + const std::string ADD_PASS = "tritonAddPluginPass"; + using addPassType = + std::function; + using addPassCType = TritonPluginResult (*)(mlir::PassManager *, + const char *); + + const std::string REGISTER_PASS = "tritonRegisterPluginPass"; + using registerPassType = std::function; + using registerPassCType = TritonPluginResult (*)(const char *); + + llvm::Error checkLibraryValid(const std::string &error) const; + + llvm::Expected getAddressOfSymbol(const std::string &symbol) const; + + template + llvm::Expected getAPI(const std::string &symbol) const { + llvm::Expected getDetailsFn = getAddressOfSymbol(symbol); + if (auto Err = getDetailsFn.takeError()) { + return Err; + } + auto func = reinterpret_cast(*getDetailsFn); + return func; + } + + llvm::Expected checkAPIResult(TritonPluginResult result, + const char *handle) const; + llvm::Expected + enumeratePyBindHandles(enumeratePyBindHandlesType &enumeratePyBindHandles, + std::vector &passNames); + +public: + std::runtime_error err2exp(llvm::Error Err); + + llvm::Error loadPlugin(); + + llvm::Expected + getPassHandles(std::vector &handles); + + llvm::Expected addPass(mlir::PassManager *pm, + const char *passHandle); + + llvm::Expected registerPass(const char *passHandle); + +private: + std::string filename = ""; + mutable 
llvm::sys::DynamicLibrary library; + enumeratePyBindHandlesType enumeratePassesAPI; + addPassType addPassAPI; + registerPassType registerPassAPI; + bool isLoaded = false; +}; + +#endif // TRITON_PLUGIN_UTILS_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 0028293e3d..4223638eb0 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -45,6 +45,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", + "TRITON_PASS_PLUGIN_PATH", "TRITON_INTEL_2DBLOCK_ASSERT", "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE", "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS", diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e8ae340f2d..e480ffa52a 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -3,4 +3,4 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(Target) add_subdirectory(Tools) -add_subdirectory(Instrumentation) +add_subdirectory(Plugins) diff --git a/lib/Instrumentation/CMakeLists.txt b/lib/Instrumentation/CMakeLists.txt deleted file mode 100644 index 1872288c55..0000000000 --- a/lib/Instrumentation/CMakeLists.txt +++ /dev/null @@ -1,47 +0,0 @@ -set(GPU_INSTRUMENTATION_PASSES - PrintLoadStoreMemSpaces - ) - -set(PrintLoadStoreMemSpaces_SOURCES - PrintLoadStoreMemSpaces.cpp - ) - -if(WIN32) - set(TYPE_OUTPUT_DIRECTORY RUNTIME_OUTPUT_DIRECTORY) -else() - set(TYPE_OUTPUT_DIRECTORY LIBRARY_OUTPUT_DIRECTORY) -endif() - -foreach( plugin ${GPU_INSTRUMENTATION_PASSES} ) - add_library( - ${plugin} - SHARED - ${${plugin}_SOURCES} - ) - - target_link_libraries( - ${plugin} - PRIVATE - LLVMCore - LLVMSupport - LLVMTransformUtils - "$<$:-undefined dynamic_lookup>" - ) - # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python - # build. It is empty if building directly from the root - # CMakeLists.txt file. 
Therefore if not building from Python just - # use the default CMake shared lib path otherwise this causes a hard - # build error - if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) - set_target_properties(${plugin} PROPERTIES - ${TYPE_OUTPUT_DIRECTORY} - "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation") - endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY) - - # This is set to -fvisibility=hidden in the top level CMake file - # which causes the llvmGetPassPluginInfo symbol to be hidden and - # an "entry point not found" error. Reset it just for this target - if(NOT MSVC) - target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti) - endif() -endforeach() diff --git a/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp b/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp deleted file mode 100644 index bec2f8a3fc..0000000000 --- a/lib/Instrumentation/PrintLoadStoreMemSpaces.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "llvm/IR/Module.h" -#include "llvm/IR/PassManager.h" -#include "llvm/Passes/PassBuilder.h" -#include "llvm/Passes/PassPlugin.h" -#include - -using namespace llvm; - -namespace { - -struct LoadStoreMemSpace : public PassInfoMixin { - PreservedAnalyses run(llvm::Module &module, ModuleAnalysisManager &) { - bool modifiedCodeGen = runOnModule(module); - - return (modifiedCodeGen ? 
llvm::PreservedAnalyses::none() - : llvm::PreservedAnalyses::all()); - } - bool runOnModule(llvm::Module &module); - // isRequired being set to true keeps this pass from being skipped - // if it has the optnone LLVM attribute - static bool isRequired() { return true; } -}; - -} // end anonymous namespace - -static std::map AddrSpaceMap = { - {0, "FLAT"}, {1, "GLOBAL"}, {3, "SHARED"}, {4, "CONSTANT"}, {5, "SCRATCH"}}; - -static std::map LocationCounterSourceMap; - -static std::string LoadOrStoreMap(const BasicBlock::iterator &I) { - if (LoadInst *LI = dyn_cast(I)) - return "LOAD"; - else if (StoreInst *SI = dyn_cast(I)) - return "STORE"; - else - throw std::runtime_error("Error: unknown operation type"); -} -template -static void InstrumentationFunction(const BasicBlock::iterator &I, - const Function &F, const llvm::Module &M, - uint32_t &LocationCounter) { - auto LSI = dyn_cast(I); - if (not LSI) - return; - Value *Op = LSI->getPointerOperand()->stripPointerCasts(); - uint32_t AddrSpace = cast(Op->getType())->getAddressSpace(); - DILocation *DL = dyn_cast(I)->getDebugLoc(); - - std::string SourceAndAddrSpaceInfo = - (F.getName() + " " + DL->getFilename() + ":" + Twine(DL->getLine()) + - ":" + Twine(DL->getColumn())) - .str() + - " " + AddrSpaceMap[AddrSpace] + " " + LoadOrStoreMap(I); - - if (LocationCounterSourceMap.find(SourceAndAddrSpaceInfo) == - LocationCounterSourceMap.end()) { - errs() << LocationCounter << " " << SourceAndAddrSpaceInfo << "\n"; - LocationCounterSourceMap[SourceAndAddrSpaceInfo] = LocationCounter; - LocationCounter++; - } -} - -bool LoadStoreMemSpace::runOnModule(Module &M) { - bool ModifiedCodeGen = false; - uint32_t LocationCounter = 0; - for (auto &F : M) { - if (F.isIntrinsic()) - continue; - StringRef functionName = F.getName(); - if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL || - F.getCallingConv() == CallingConv::PTX_Kernel || - functionName.contains("kernel")) { - for (Function::iterator BB = F.begin(); BB != F.end(); BB++) { 
- for (BasicBlock::iterator I = BB->begin(); I != BB->end(); I++) { - if (LoadInst *LI = dyn_cast(I)) { - InstrumentationFunction(I, F, M, LocationCounter); - } else if (StoreInst *SI = dyn_cast(I)) { - InstrumentationFunction(I, F, M, LocationCounter); - } - } - } - } - } - return ModifiedCodeGen; -} - -static PassPluginLibraryInfo getPassPluginInfo() { - const auto callback = [](PassBuilder &PB) { - PB.registerOptimizerLastEPCallback([&](ModulePassManager &MPM, auto, auto) { - MPM.addPass(LoadStoreMemSpace()); - return true; - }); - }; - - return {LLVM_PLUGIN_API_VERSION, "print-mem-space", LLVM_VERSION_STRING, - callback}; -}; - -extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() { - return getPassPluginInfo(); -} - -#if defined(_WIN32) -#pragma comment(linker, "/export:llvmGetPassPluginInfo") -#endif diff --git a/lib/Plugins/CMakeLists.txt b/lib/Plugins/CMakeLists.txt new file mode 100644 index 0000000000..593335d222 --- /dev/null +++ b/lib/Plugins/CMakeLists.txt @@ -0,0 +1,43 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Plugins) +add_public_tablegen_target(TritonPluginsIncGen) + +llvm_canonicalize_cmake_booleans( + MLIR_ENABLE_BINDINGS_PYTHON +) + +set(TRITON_PLUGIN_PASSES + TritonPluginsTestLib + ) + +set(TritonPluginsTestLib_SOURCES + TritonPlugin.cpp + ) + + +foreach( plugin ${TRITON_PLUGIN_PASSES} ) + add_mlir_library(${plugin} + ${${plugin}_SOURCES} + SHARED + + ADDITIONAL_HEADER_DIRS + ${PROJECT_BINARY_DIR}/lib + + DEPENDS + TritonTableGen + TritonCanonicalizeIncGen + TritonPluginsIncGen + + LINK_LIBS PUBLIC + MLIRPass + LLVMSupport + MLIRSupport + "$<$:-undefined dynamic_lookup>" + ) + + set_target_properties(${plugin} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins") + + target_compile_options(${plugin} PRIVATE -fvisibility=hidden) +endforeach() diff --git a/lib/Plugins/Passes.td b/lib/Plugins/Passes.td new file mode 100644 index 
0000000000..a8007a09e8 --- /dev/null +++ b/lib/Plugins/Passes.td @@ -0,0 +1,9 @@ +#ifndef TRITONGPU_PLUGIN_PASSES +#define TRITONGPU_PLUGIN_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> { + let summary = "Triton MLIR Plugin Pass"; +} +#endif diff --git a/lib/Plugins/README.md b/lib/Plugins/README.md new file mode 100644 index 0000000000..a12b0c79ae --- /dev/null +++ b/lib/Plugins/README.md @@ -0,0 +1,311 @@ +# Triton TTIR and TTGIR Out of Tree Plugin Passes + +## Overview +Triton’s existing pass pipelines are assembled in the various extended compiler.py files that live in Triton’s backends. Currently when we want to insert +passes either for downstream optimizations, custom ops, or instrumentation it is required for the compiler.py file itself to be modified and all of Triton to be +recompiled. + +In order to allow for more downstream configurability we have implemented a custom MLIR level (TTIR and TTGIR) pass plugin and configuration system that allows for either +overriding the compiler.py pipeline entirely or inserting passes and custom ops through a compiler pipeline hook. Example use cases include: +- Custom ops and lowering passes +- Custom optimization passes +- Instrumentation and analysis passes +- Specialized per kernel passes (e.g. kernel/model specific warp specialization) + +Custom passes/ops are implemented as a shared library that is loaded by Triton at JIT compile/runtime. The plugins can be implemented entirely out of tree or in the Triton source tree as +long as the libtriton.so is linked to the plugin and the Triton include passes are used to build the plugin.
+ +## Example 1: Developing a custom pass and running triton-opt to inspect the modified IR +``` bash +export LLVM_BUILD_SHARED_LIBS=1; make dev-install-llvm +TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir +``` +``` MLIR +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { + tt.func @foo() { + tt.return + } +} +``` + +After the out of tree pass runs, becomes: +``` MLIR +module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { + tt.func @bar() { + tt.return + } +} +``` +Function "foo" is renamed to "bar" by the out of tree pass. + +## Example 2: Inserting a new pass into the compiler pipeline +Let's take the following toy kernel example: +``` python +import torch +import os + +import triton +import triton.language as tl +from triton._C.libtriton import ir, passes +from triton import knobs + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def kernel(BLOCK_SIZE: tl.constexpr): + return + +if __name__ == '__main__': + + size = 98432 + x = torch.rand(size, device=DEVICE) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) +``` + +Running as is will produce the expected output of printing the TTGIR of the kernel: +``` bash +python test.py +``` +``` MLIR +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) +``` + +Running same code but loading the plugin library also produces the same results since, while the plugin pass has been loaded and registered with the +pass manager it is not inserted into the 
compiler pass pipeline: + +``` bash +TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +``` + +``` MLIR +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) +``` + +Finally, if we both load the plugin at runtime and insert the pass pipeline hook into the kernel code: + +``` python +import torch +import os + +import triton +import triton.language as tl +from triton._C.libtriton import ir, passes +from triton import knobs + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def kernel(BLOCK_SIZE: tl.constexpr): + return + +#These two methods must be implemented by the plugin +def get_key(): + return pathlib.Path(__file__).read_text() +def get_hash(): + return hashlib.sha256(get_key().encode('utf-8')).hexdigest() + +def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None): + # If the hook is called with no arguments we assume were just after the key and hash and don't want to + # actually execute the pipeline yet. + # This no argument early return must be implemented. 
+ if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + + def make_ttir_wrapper(mod, metadata, opt, capability): + mod = self.make_ttir(mod, metadata, opt, capability) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.plugin.add_plugin(pm) + pm.run(mod, 'make_ttir_plugin') + return mod + + stages["ttir"] = lambda src, metadata: make_ttir_wrapper(src, metadata, options, capability) + + return get_key(), get_hash() + +if __name__ == '__main__': + + size = 98432 + x = torch.rand(size, device=DEVICE) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) + + if "TRITON_PASS_PLUGIN_PATH" in os.environ: + knobs.runtime.add_stages_inspection_hook = inspect_stages_hook + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) + + # Unset the hook to go back to the standard pipeline + knobs.runtime.add_stages_inspection_hook = None + h = kernel[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) +``` + +``` bash +TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +``` + +Shows the pass ran and modified the kernel name but only after the hook is set. Any kernels before the hook or after the hook is unset are left unchanged. 
+ +``` MLIR +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @foo() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + tt.return loc(#loc1) + } loc(#loc) +} loc(#loc) +#loc = loc("/home/triton/test.py":13:0) +#loc1 = loc("/home/triton/test.py":14:4) +``` + +The hook, as defined in the example, will insert the pass at the end of the make_ttir pipeline but its placement in the Triton pipeline is arbitrary. +This functionality can be toggled on and off by just commenting out this line in kernel code (or setting to None): +knobs.runtime.add_stages_inspection_hook = inspect_stages_hook +without needing any core compiler changes or rebuilding Triton. + +## Example 3: Fully customizing the compiler pipeline with pass and op insertions at arbitrary locations + +Here we now run two kernels, one with the full standard Triton pipeline and one with a fully customized pipeline, entirely from within +kernel code without modifying any core Triton compiler code or recompiling. We run the kernel with a hook to output the standard pipeline, modify +the compiler.py file to insert our out of tree pass before add_loop_unroll pass (although there is no restriction of where it can be inserted), +then run the second kernel with a different pipeline. 
This modification can, as before, be seen in the kernel function name modification by the +inserted pass. + +``` python +import torch +import os +import sys + +import triton +import triton.language as tl +from triton._C.libtriton import ir, passes +from triton import knobs +import inspect +from importlib.util import module_from_spec, spec_from_file_location + +from triton.backends.compiler import Language + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def kernel1(BLOCK_SIZE: tl.constexpr): + return +@triton.jit +def kernel2(BLOCK_SIZE: tl.constexpr): + return + +def get_key(): + return pathlib.Path(__file__).read_text() +def get_hash(): + return hashlib.sha256(get_key().encode('utf-8')).hexdigest() + +def dump_stages_hook(self=None, stages=None, options=None, language=None, capability=None): + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + source_code = "# This is generated from Triton compiler.py" + source_code = ( + source_code + + "\n" + + "from triton._C.libtriton import ir, passes, llvm, amd, nvidia" + ) + source_code = source_code + "\n" + "class GPUOverrideBackend:" + source_code = source_code + "\n" + inspect.getsource(self.make_ttir) + source_code = source_code + "\n" + inspect.getsource(self.make_ttgir) + + with open("compiler_override.py", "w") as file: + file.write(source_code) + return get_key(), get_hash() +def override_stages(self=None, stages=None, options=None, language=None, capability=None): + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + if language != Language.TRITON: + return + full_name = "compiler_override.py" + + print(f"\nOverriding compile pass stages with file {full_name}") + module_name = "triton_override_compiler_stages" + spec = ( + spec_from_file_location(module_name, full_name) + if os.path.isfile(full_name) + else None + ) + if not spec: + return + + module = 
module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + if not hasattr(module, "GPUOverrideBackend"): + return + module = getattr(module, "GPUOverrideBackend") + + has_func = lambda mod, name: hasattr(mod, name) and callable(getattr(mod, name)) + make_lambda = lambda f: lambda src, metadata: f(src, metadata, options, capability) + if has_func(module, "make_ttir"): + stages["ttir"] = make_lambda(module.make_ttir) + if has_func(module, "make_ttgir"): + stages["ttgir"] = make_lambda(module.make_ttgir) + return get_key(), get_hash() + +if __name__ == '__main__': + + size = 98432 + x = torch.rand(size, device=DEVICE) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + knobs.runtime.add_stages_inspection_hook = dump_stages_hook + h = kernel1[grid](BLOCK_SIZE=1024) + filename = "compiler_override.py" + + with open(filename, "r") as infile: + file_str = infile.readlines() + + with open(filename, "w") as outfile: + for line in file_str: + if "add_loop_unroll" in line: + outfile.write("\n passes.plugin.add_plugin(pm)\n") + outfile.write(line) + if "TRITON_PASS_PLUGIN_PATH" in os.environ: + knobs.runtime.add_stages_inspection_hook = override_stages + h = kernel2[grid](BLOCK_SIZE=1024) + print(h.asm["ttgir"]) +``` diff --git a/lib/Plugins/TritonPlugin.cpp b/lib/Plugins/TritonPlugin.cpp new file mode 100644 index 0000000000..26e18c0ed7 --- /dev/null +++ b/lib/Plugins/TritonPlugin.cpp @@ -0,0 +1,92 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/PluginUtils.h" +#include + +#if defined(_WIN32) +#define EXPORT_FUNC __declspec(dllexport) +#else +#define EXPORT_FUNC __attribute__((visibility("default"))) +#endif + +#define TRITON_PLUGIN_API extern "C" 
EXPORT_FUNC TritonPluginResult + +namespace mlir { +namespace triton { +namespace plugin { + +#define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN +#include "Passes.h.inc" + +struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { + void runOnOperation() override { + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mod.walk([&](FunctionOpInterface funcOp) { + StringAttr funcNameAttr = funcOp.getNameAttr(); + funcOp.setName("foo"); + }); + } +}; + +} // namespace plugin +} // namespace triton +} // namespace mlir + +static void addTritonPluginPass(mlir::PassManager *pm) { + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); +} + +static void registerTritonPluginPass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::triton::plugin::createTritonGPUMLIRPlugin(); + }); +} + +static const char *ADD_PLUGIN_PASS_NAME = "add_plugin"; +static std::unordered_map passMap = + {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; +static std::unordered_map registryMap = { + {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; +static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; + +// Key APIs: + +TRITON_PLUGIN_API +tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { + std::string passNameStr(passName); + if (passMap.find(passNameStr) == passMap.end()) + return TP_GENERIC_FAILURE; + passMap[passNameStr](pm); + return TP_SUCCESS; +} + +TRITON_PLUGIN_API +tritonRegisterPluginPass(const char *passName) { + std::string passNameStr(passName); + if (registryMap.find(passNameStr) == registryMap.end()) + return TP_GENERIC_FAILURE; + registryMap[passNameStr](); + return TP_SUCCESS; +} + +TRITON_PLUGIN_API +tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { + if (!passCount) + return TP_GENERIC_FAILURE; + auto count = passMap.size(); + assert(count == registryMap.size() && + "Expected register and add passes map size to match"); + *passCount = count; + if (!passNames) + return 
TP_SUCCESS; + unsigned i = 0; + for (auto passName : passNamesTable) { + // Post-increment so each name lands in its own slot; without it every + // entry would overwrite passNames[0]. + passNames[i++] = passName; + } + return TP_SUCCESS; +} diff --git a/lib/Tools/CMakeLists.txt b/lib/Tools/CMakeLists.txt index a2f9f8aea5..611468b9a2 100644 --- a/lib/Tools/CMakeLists.txt +++ b/lib/Tools/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonTools GenericSwizzling.cpp LayoutUtils.cpp LinearLayout.cpp + PluginUtils.cpp DEPENDS diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp new file mode 100644 index 0000000000..97e97bde83 --- /dev/null +++ b/lib/Tools/PluginUtils.cpp @@ -0,0 +1,115 @@ +#include "triton/Tools/PluginUtils.h" + +llvm::Error TritonPlugin::checkLibraryValid(const std::string &error) const { + if (!library.isValid()) { + auto msg = llvm::Twine("Failed to load plugin library: " + error + "\n"); + return llvm::createStringError(msg); + } + return llvm::Error::success(); +} + +llvm::Expected<intptr_t> +TritonPlugin::getAddressOfSymbol(const std::string &symbol) const { + if (auto isValid = checkLibraryValid("not loaded")) + return isValid; + intptr_t getDetailsFn = (intptr_t)library.getAddressOfSymbol(symbol.c_str()); + if (!getDetailsFn) { + auto msg = llvm::Twine("Failed to get symbol: " + symbol + "\n"); + return llvm::createStringError(msg); + } + return getDetailsFn; +} + +llvm::Expected<TritonPluginResult> +TritonPlugin::checkAPIResult(TritonPluginResult result, + const char *handle) const { + if (result == TP_SUCCESS) + return TP_SUCCESS; + std::string msg; + llvm::raw_string_ostream os(msg); + os << "Failed to add/register plugin pass (" << handle + << ") to pass manager, error code: " << result; + return llvm::createStringError(msg); +} + +std::runtime_error TritonPlugin::err2exp(llvm::Error Err) { + // llvm::toString() consumes the Error; merely streaming it would leave the + // Error unchecked and abort in assertion-enabled builds. + return std::runtime_error(llvm::toString(std::move(Err))); +} + +llvm::Error TritonPlugin::loadPlugin() { + if (isLoaded) + return llvm::Error::success(); + + std::string error; + library = 
llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); + if (auto isValid = checkLibraryValid(error)) + return isValid; + + auto enumeratePassesAPIOrErr = + getAPI( + ENUMERATE_PASSES); + auto addPassAPIOrErr = getAPI(ADD_PASS); + auto registerPassAPIOrErr = + getAPI(REGISTER_PASS); + + if (auto Err = enumeratePassesAPIOrErr.takeError()) + return Err; + if (auto Err = addPassAPIOrErr.takeError()) + return Err; + if (auto Err = registerPassAPIOrErr.takeError()) + return Err; + + enumeratePassesAPI = *enumeratePassesAPIOrErr; + addPassAPI = *addPassAPIOrErr; + registerPassAPI = *registerPassAPIOrErr; + isLoaded = true; + return llvm::Error::success(); +} + +llvm::Expected TritonPlugin::enumeratePyBindHandles( + enumeratePyBindHandlesType &enumeratePyBindHandles, + std::vector &handles) { + if (auto Err = loadPlugin()) + return Err; + + uint32_t passCount = 0; + handles.clear(); + auto result = enumeratePyBindHandles(&passCount, nullptr); + if (result == TP_SUCCESS) { + if (passCount == 0) + return TP_SUCCESS; + + handles.resize(passCount); + result = enumeratePyBindHandles(&passCount, handles.data()); + } + + if (result == TP_SUCCESS) + return TP_SUCCESS; + std::string msg; + llvm::raw_string_ostream os(msg); + os << "Failed to retrive plugin pass handles, error code: " << result; + return llvm::createStringError(msg); +} + +llvm::Expected +TritonPlugin::getPassHandles(std::vector &passNames) { + return enumeratePyBindHandles(enumeratePassesAPI, passNames); +} + +llvm::Expected +TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) { + if (auto Err = loadPlugin()) + return Err; + return checkAPIResult(addPassAPI(pm, passHandle), passHandle); +} + +llvm::Expected +TritonPlugin::registerPass(const char *passHandle) { + if (auto Err = loadPlugin()) + return Err; + return checkAPIResult(registerPassAPI(passHandle), passHandle); +} diff --git a/python/src/passes.cc b/python/src/passes.cc index f2d93fab96..8977b59913 100644 --- 
a/python/src/passes.cc +++ b/python/src/passes.cc @@ -12,8 +12,11 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonInstrument/Transforms/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include #include +#include namespace py = pybind11; @@ -96,6 +99,30 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUOptimizePartitionWarps); } +void init_plugin_passes(py::module &&m) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + if (filename.empty()) + return; + + TritonPlugin TP(filename); + std::vector passNames; + if (auto result = TP.getPassHandles(passNames); !result) + throw TP.err2exp(result.takeError()); + + for (unsigned i = 0; i < passNames.size(); ++i) { + const char *passName = passNames.data()[i]; + + m.def(passName, [passName](mlir ::PassManager &pm) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + TritonPlugin TP(filename); + if (auto result = TP.addPass(&pm, passName); !result) + throw TP.err2exp(result.takeError()); + }); + } +} + void init_triton_passes_convert(py::module &&m) { using namespace mlir; ADD_PASS_WRAPPER_0("add_scf_to_cf", createSCFToControlFlowPass); @@ -130,4 +157,5 @@ void init_triton_passes(py::module &&m) { init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); init_triton_passes_llvmir(m.def_submodule("llvmir")); init_gluon_passes(m.def_submodule("gluon")); + init_plugin_passes(m.def_submodule("plugin")); } diff --git a/python/test/unit/plugins/custom_stages.py b/python/test/unit/plugins/custom_stages.py new file mode 100644 index 0000000000..f98cab4b0a --- /dev/null +++ b/python/test/unit/plugins/custom_stages.py @@ -0,0 +1,38 @@ +from triton._C.libtriton import ir, passes +import hashlib +import pathlib + + +# These two methods must be implemented and returned by the plugin hook. 
+# any changes in this entire file and the the plugin pipeline +# will trigger a recompile since the hash will change. To be +# less conservative, we could use a hash of the inspect_stages_hook +# function but then changes outside of the function won't be considered +# potentially causing a stale kernel hash +def get_key(): + return pathlib.Path(__file__).read_text() + + +def get_hash(): + return hashlib.sha256(get_key().encode('utf-8')).hexdigest() + + +# Keep custom pipeline stages in a seperate file from kernels as any change to the file +# will trigger a recompile. +def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None): + # If the hook is called with no arguments we assume were just after the key and hash and don't want to + # actually execute the pipeline yet + if all(arg is None for arg in (stages, options, language, capability)): + return get_key(), get_hash() + + def make_ttir_wrapper(mod, metadata, opt, capability): + mod = self.make_ttir(mod, metadata, opt, capability) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.plugin.add_plugin(pm) + pm.run(mod, 'make_ttir_plugin') + return mod + + stages["ttir"] = lambda src, metadata: make_ttir_wrapper(src, metadata, options, capability) + + return get_key(), get_hash() diff --git a/python/test/unit/plugins/test_plugin.py b/python/test/unit/plugins/test_plugin.py new file mode 100644 index 0000000000..9a895174b1 --- /dev/null +++ b/python/test/unit/plugins/test_plugin.py @@ -0,0 +1,43 @@ +import torch + +import pytest +import os + +import triton +import triton.language as tl +from triton import knobs +import custom_stages + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel1(BLOCK_SIZE: tl.constexpr): + return + + +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel2(BLOCK_SIZE: tl.constexpr): + return + + +def test_op(capfd, device: str): + if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': + return + + size = 98432 + x = 
torch.rand(size, device=device) + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + h = kernel1[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo" not in h.asm["ttgir"] + + knobs.runtime.add_stages_inspection_hook = custom_stages.inspect_stages_hook + h = kernel2[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo" in h.asm["ttgir"] + + knobs.runtime.add_stages_inspection_hook = None + h = kernel2[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo" not in h.asm["ttgir"] diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 1cbac97747..1fd15908b0 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -253,6 +253,9 @@ def compile(src, target=None, options=None, _env_vars=None): # create cache manager env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars key = get_cache_key(src, backend, options, env_vars=env_vars) + if knobs.runtime.add_stages_inspection_hook is not None: + inspect_stages_key, inspect_stages_hash = knobs.runtime.add_stages_inspection_hook() + key += inspect_stages_key hash = hashlib.sha256(key.encode("utf-8")).hexdigest() fn_cache_manager = get_cache_manager(hash) # For dumping/overriding only hash the source as we want it to be independent of triton diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 0f506818f9..7c73a84f36 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -709,6 +709,12 @@ def run(self, *args, grid, warmup, **kwargs): # the type and the second parameter is the 'specialization' value. 
bound_args, specialization, options = binder(*args, **kwargs) + # add a cache field to the kernel specializations for kernel specific + # pass pipelines + if knobs.runtime.add_stages_inspection_hook is not None: + inspect_stages_key, inspect_stages_hash = knobs.runtime.add_stages_inspection_hook() + specialization.append(f'("custom_pipeline", {inspect_stages_hash})') + key = compute_cache_key(kernel_key_cache, specialization, options) kernel = kernel_cache.get(key, None) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index bf75fb015a..f74a2af059 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -369,8 +369,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gamm pytest.skip("Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet.") if "mx" not in weight_dtype_str: pytest.skip("Non-scale swizzling not supported on CDNA4 yet") - if n % 32 != 0 or k % (32 * 8) != 0: - pytest.skip(f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU") if is_cuda(): if torch.cuda.get_device_capability()[0] < 9: pytest.skip("NYI. 
Ampere swizzling.") diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py b/python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py new file mode 100644 index 0000000000..0ffa8b6c58 --- /dev/null +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_cdna4.py @@ -0,0 +1,24 @@ +import pytest +import torch +from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout + +# ------------------------------------------------------------ +# Torch tests +# ------------------------------------------------------------ + + +@pytest.mark.parametrize( + "shape", + [ + (3, 4096, 1024), + (10, 254, 60), + (1, 320, 160), + (2, 16, 512), + (3, 2, 36), + ], +) +def test_mxfp4_scale_roundtrip(shape): + x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") + layout = CDNA4MXScaleLayout(x.shape) + res = layout.unswizzle_data(layout.swizzle_data(x)) + assert (res == x).all() diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py index 870d739096..ca23f090db 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py @@ -1,3 +1,5 @@ +import math +import torch from dataclasses import dataclass import triton import triton.language as tl @@ -12,24 +14,31 @@ class CDNA4MXScaleLayout(Layout): def __init__(self, shape) -> None: super().__init__(shape) + ( + *self.leading_shape, + self.K_SCALE, + self.N, + ) = shape + self.B = math.prod(self.leading_shape) + self.ALIGN_K_SCALE = 8 + self.ALIGN_N = 32 + self.K_SCALE_pad = math.ceil(self.K_SCALE / self.ALIGN_K_SCALE) * self.ALIGN_K_SCALE + self.N_pad = math.ceil(self.N / self.ALIGN_N) * self.ALIGN_N def swizzle_data(self, data): - block_shape = data.shape - SCALE_K = block_shape[-2] - N = block_shape[-1] + data = 
torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_SCALE_pad - self.K_SCALE)) data = data.transpose(-1, -2) - data = data.view(-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1) + data = data.view(-1, self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, self.K_SCALE_pad // 8, 2, 4, 1) data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous() - if len(block_shape) == 3: - E = block_shape[0] - data = data.reshape(E, N // 32, SCALE_K * 32) - else: - assert len(block_shape) == 2 - data = data.reshape(N // 32, SCALE_K * 32) + data = data.reshape(self.B, self.N_pad // 32, self.K_SCALE_pad * 32) return data.transpose(-1, -2) def unswizzle_data(self, data): - raise NotImplementedError() + data = data.transpose(-1, -2) + data = data.view(-1, self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE, self.K_SCALE_pad // 8, 4, 16, 2, 2, 1) + data = data.permute(0, 1, 6, 4, 2, 5, 3, 7) + data = data.reshape(*self.leading_shape, self.N_pad, self.K_SCALE_pad) + return data.transpose(-1, -2)[..., :self.K_SCALE, :self.N] def swizzle_block_shape(self, block_shape): SCALE_K = block_shape[-2] diff --git a/scripts/build-llvm-project.sh b/scripts/build-llvm-project.sh index 3651f4ee67..d8ce9aebd7 100755 --- a/scripts/build-llvm-project.sh +++ b/scripts/build-llvm-project.sh @@ -5,6 +5,7 @@ REPO_ROOT="$(git rev-parse --show-toplevel)" LLVM_TARGETS=${LLVM_TARGETS:-Native;NVPTX;AMDGPU} LLVM_PROJECTS=${LLVM_PROJECTS:-mlir;llvm;lld} LLVM_BUILD_TYPE=${LLVM_BUILD_TYPE:-RelWithDebInfo} +LLVM_BUILD_SHARED_LIBS=${LLVM_BUILD_SHARED_LIBS:-OFF} LLVM_COMMIT_HASH=${LLVM_COMMIT_HASH:-$(cat "$REPO_ROOT/cmake/llvm-hash.txt")} LLVM_PROJECT_PATH=${LLVM_PROJECT_PATH:-"$REPO_ROOT/llvm-project"} LLVM_BUILD_PATH=${LLVM_BUILD_PATH:-"$LLVM_PROJECT_PATH/build"} @@ -21,6 +22,7 @@ if [ -z "$CMAKE_ARGS" ]; then -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_LLD=ON + -DBUILD_SHARED_LIBS="$LLVM_BUILD_SHARED_LIBS" -DLLVM_OPTIMIZED_TABLEGEN=ON -DMLIR_ENABLE_BINDINGS_PYTHON=OFF 
-DLLVM_TARGETS_TO_BUILD="$LLVM_TARGETS" diff --git a/setup.py b/setup.py index 4e500ff126..9083718fc5 100644 --- a/setup.py +++ b/setup.py @@ -519,6 +519,11 @@ def build_extension(self, ext): "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", ] + if check_env_flag("LLVM_BUILD_SHARED_LIBS"): + cmake_args += ["-DLLVM_BUILD_SHARED_LIBS=1"] + else: + cmake_args += ["-DLLVM_BUILD_SHARED_LIBS=0"] + # Note that asan doesn't work with binaries that use the GPU, so this is # only useful for tools like triton-opt that don't run code on the GPU. # diff --git a/test/Plugins/test-plugin.mlir b/test/Plugins/test-plugin.mlir new file mode 100644 index 0000000000..f16ef07882 --- /dev/null +++ b/test/Plugins/test-plugin.mlir @@ -0,0 +1,30 @@ +// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so triton-opt -split-input-file -tritongpu-plugin %s | FileCheck %s --check-prefix=CHECK-PLUGIN +// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so triton-opt -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-NOFLAG +// RUN: triton-opt -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-BASE + +// REQUIRES: shared-libs + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-PLUGIN: func @foo() + tt.func @bar() { + tt.return + } +} // module + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-NOFLAG: func @bar() + tt.func @bar() { + tt.return + } +} // module + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { + // CHECK-BASE: func @bar() + tt.func @bar() { + tt.return + } +} // module diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 0890e609b7..09044bebe4 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -63,6 +63,10 @@ ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), ] +# Static libraries are not built if LLVM_BUILD_SHARED_LIBS is ON. 
+if config.build_shared_libs: + config.available_features.add("shared-libs") + llvm_config.add_tool_substitutions(tools, tool_dirs) # TODO: what's this? diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index fd1fb486dd..59b212a4d2 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -14,6 +14,7 @@ config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.mlir_binary_dir = "@MLIR_BINARY_DIR@" config.python_executable = "@Python3_EXECUTABLE@" config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ +config.build_shared_libs = @LLVM_BUILD_SHARED_LIBS@ import lit.llvm diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index e78db6422a..6c3773afee 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -1185,8 +1185,19 @@ void Pingponger::getDotPingponged() { // times for issuing the memory operations and issuing dot operations, // smaller tile sizes are not likely to get any advantage from current dot // centric pingpong scheduling. - if (tileSize <= smallTile && tileSize >= minTile) + if (tileSize <= smallTile && tileSize >= minTile) { transformOnePPClusters(builder, loc); + LDBG("Pingpong scheduling applied for numWarps=4 with tileSize=" + + std::to_string(tileSize) + " (in range [" + std::to_string(minTile) + + ", " + std::to_string(smallTile) + + "]), One Dot-Memory (ping-pong) cluster used."); + } else { + std::stringstream message; + message << "Skipping pingpong for numWarps=4: tileSize=" << tileSize + << " is outside the range [" << minTile << ", " << smallTile + << "]"; + LDBG(message.str()); + } // numWarps=4 doesn't need asymmetric sync, return. 
return; } else if (numWarps == 8 && numStages == 2) { @@ -1216,8 +1227,15 @@ void Pingponger::getDotPingponged() { "cluster transformation"); return; } - } else + } else { + std::stringstream message; + message << "Skipping pingpong for numWarps=8, numStages=2: tileSize=" + << tileSize + << " does not match supported tile sizes (medium=" << mediumTile + << " or large=" << largeTile << ")"; + LDBG(message.str()); return; + } // Let half of the warps start the loop first and the others follow later // but in the synchronized way. This can be accomplished by calling