diff --git a/byo_compiler.py b/byo_compiler.py
new file mode 100644
index 000000000000..83467a013885
--- /dev/null
+++ b/byo_compiler.py
@@ -0,0 +1,65 @@
+
+@staticmethod
+def byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info, dump_enabled, passes, nvidia):
+
+    passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
+    # optimize TTGIR
+    passes.ttgpuir.add_coalesce(pm)
+    if capability // 10 >= 8:
+        passes.ttgpuir.add_f32_dot_tc(pm)
+    # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
+    nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
+    passes.ttgpuir.add_remove_layout_conversions(pm)
+    passes.ttgpuir.add_optimize_thread_locality(pm)
+    passes.ttgpuir.add_accelerate_matmul(pm)
+    passes.ttgpuir.add_remove_layout_conversions(pm)
+    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+    nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
+    passes.ttir.add_loop_aware_cse(pm)
+    if capability // 10 in [8, 9]:
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_triton_licm(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+        nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
+        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+        passes.ttgpuir.add_schedule_loops(pm)
+        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+    elif capability // 10 >= 10:
+        passes.ttgpuir.add_fuse_nested_loops(pm)
+        passes.common.add_canonicalizer(pm)
+        passes.ttir.add_triton_licm(pm)
+        passes.ttgpuir.add_optimize_accumulator_init(pm)
+        passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
+        nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
+        passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
+        passes.ttgpuir.add_schedule_loops(pm)
+        passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
+        passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
+        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+        # hoist again and allow hoisting out of if statements
+        passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
+        nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
+    else:
+        passes.ttir.add_triton_licm(pm)
+    passes.common.add_canonicalizer(pm)
+    passes.ttir.add_loop_aware_cse(pm)
+    passes.ttgpuir.add_prefetch(pm)
+    passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
+    passes.ttgpuir.add_coalesce_async_copy(pm)
+    nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
+    passes.ttgpuir.add_remove_layout_conversions(pm)
+    nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
+    passes.ttgpuir.add_reduce_data_duplication(pm)
+    passes.ttgpuir.add_reorder_instructions(pm)
+    passes.ttir.add_loop_aware_cse(pm)
+    passes.common.add_symbol_dce(pm)
+    if capability // 10 >= 9:
+        nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
+        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
+    nvidia.passes.ttnvgpuir.add_lower_mma(pm)
+    passes.common.add_sccp(pm)
+    passes.common.add_cse(pm)
+    passes.common.add_canonicalizer(pm)
+
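For context, the hook above is consumed by the CUDA backend through importlib (see the third_party/nvidia/backend/compiler.py hunk later in this diff). A minimal loader sketch, assuming the hard-coded path in that hunk is replaced by a hypothetical BYO_COMPILER_PATH environment variable:

    # Sketch only: mirrors the loader in the compiler.py hunk below.
    # BYO_COMPILER_PATH is a hypothetical variable; the patch itself hard-codes the path.
    import os
    import sys
    from importlib.util import spec_from_file_location, module_from_spec

    def load_byo_hook(path=None):
        path = path or os.environ.get("BYO_COMPILER_PATH")
        if not path:
            return None
        spec = spec_from_file_location("byo_compiler_setup", path)
        if spec is None:
            return None
        module = module_from_spec(spec)
        sys.modules["byo_compiler_setup"] = module
        spec.loader.exec_module(module)
        return module

    # Inside make_ttgir the hook is then invoked as:
    #   module.byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info,
    #                         dump_enabled, passes, nvidia)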
+"ttnvgpuir.add_optimize_descriptor_encoding", +"ttir.add_loop_aware_cse", +"ttgpuir.add_fuse_nested_loops", +"common.add_canonicalizer", +"ttir.add_triton_licm", +"common.add_canonicalizer", +"ttgpuir.add_combine_tensor_select_and_if", +"nvidia.hopper.add_hopper_warpspec", +"ttgpuir.add_assign_latencies", +"ttgpuir.add_schedule_loops", +"ttgpuir.add_pipeline", +"ttgpuir.add_fuse_nested_loops", +"common.add_canonicalizer", +"ttir.add_triton_licm", +"ttgpuir.add_optimize_accumulator_init", +"ttgpuir.add_hoist_tmem_alloc{False}", +"ttnvgpuir.add_promote_lhs_to_tmem", +"ttgpuir.add_assign_latencies", +"ttgpuir.add_schedule_loops", +"ttgpuir.add_warp_specialize", +"ttgpuir.add_pipeline", +"ttgpuir.add_combine_tensor_select_and_if", +"ttgpuir.add_hoist_tmem_alloc{True}", +"ttnvgpuir.add_remove_tmem_tokens", +"ttir.add_triton_licm", +"common.add_canonicalizer", +"ttir.add_loop_aware_cse", +"ttgpuir.add_prefetch", +"ttgpuir.add_optimize_dot_operands", +"ttgpuir.add_coalesce_async_copy", +"ttnvgpuir.add_optimize_tmem_layouts", +"ttgpuir.add_remove_layout_conversions", +"ttnvgpuir.add_interleave_tmem", +"ttgpuir.add_reduce_data_duplication", +"ttgpuir.add_reorder_instructions", +"ttir.add_loop_aware_cse", +"common.add_symbol_dce", +"ttnvgpuir.add_tma_lowering", +"ttnvgpuir.add_fence_insertion", +"ttnvgpuir.add_lower_mma", +"common.add_sccp", +"common.add_canonicalizer" + ] + } +} diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index efe3d930ef3d..ded502772a26 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -23,6 +23,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "LLVM_IR_ENABLE_DUMP", "LLVM_ENABLE_TIMING", "LLVM_PASS_PLUGIN_PATH", + "PASS_MANAGER_CONFIG_PATH", "MLIR_ENABLE_DIAGNOSTICS", "MLIR_ENABLE_DUMP", "MLIR_DUMP_PATH", @@ -44,6 +45,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "ALLOW_LHS_TMEM_LAYOUT_CONVERSION", "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", + "MLIR_PASS_PLUGIN_PATH", // clang-format on }; diff --git a/python/src/ir.cc b/python/src/ir.cc index 4c8a4233bf73..c574a2ebfd31 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1816,7 +1816,7 @@ void init_triton_ir(py::module &&m) { }) .def( "run", - [](PassManager &self, ModuleOp &mod) { + [](PassManager &self, ModuleOp &mod, std::string repro_suffix) { // TODO: maybe dump module to file and print error for better // diagnostics @@ -1900,7 +1900,14 @@ void init_triton_ir(py::module &&m) { } if (failed(self.run(mod.getOperation()))) throw std::runtime_error("PassManager::run failed"); + + if (!repro_suffix.empty() && !reproducerPath.empty() && reproducerPath != "-" && + llvm::sys::fs::copy_file(reproducerPath, reproducerPath + repro_suffix)) { + throw std::runtime_error("PassManager::run failed (repro temp)"); + } }, + py::arg("mod"), + py::arg("repro_suffix") = "", py::call_guard()); } diff --git a/python/src/passes.cc b/python/src/passes.cc index e54da7e73ec6..514229528b56 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -12,6 +12,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonInstrument/Transforms/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "mlir/Tools/Plugins/PassPlugin.h" #include #include @@ -92,6 +94,20 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCoalesceAsyncCopy); ADD_PASS_WRAPPER_0("add_concurrency_sanitizer", createTritonInstrumentConcurrencySanitizer); + + 
diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp
index efe3d930ef3d..ded502772a26 100644
--- a/include/triton/Tools/Sys/GetEnv.hpp
+++ b/include/triton/Tools/Sys/GetEnv.hpp
@@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "LLVM_IR_ENABLE_DUMP",
     "LLVM_ENABLE_TIMING",
     "LLVM_PASS_PLUGIN_PATH",
+    "PASS_MANAGER_CONFIG_PATH",
     "MLIR_ENABLE_DIAGNOSTICS",
     "MLIR_ENABLE_DUMP",
     "MLIR_DUMP_PATH",
@@ -44,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
     "TRITON_F32_DEFAULT",
     "TRITON_PREFER_TMEM_16x256_LAYOUT",
+    "MLIR_PASS_PLUGIN_PATH",
     // clang-format on
 };
diff --git a/python/src/ir.cc b/python/src/ir.cc
index 4c8a4233bf73..c574a2ebfd31 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -1816,7 +1816,7 @@ void init_triton_ir(py::module &&m) {
          })
      .def(
          "run",
-         [](PassManager &self, ModuleOp &mod) {
+         [](PassManager &self, ModuleOp &mod, std::string repro_suffix) {
            // TODO: maybe dump module to file and print error for better
            // diagnostics
@@ -1900,7 +1900,14 @@ void init_triton_ir(py::module &&m) {
            }
            if (failed(self.run(mod.getOperation())))
              throw std::runtime_error("PassManager::run failed");
+
+           if (!repro_suffix.empty() && !reproducerPath.empty() &&
+               reproducerPath != "-" &&
+               llvm::sys::fs::copy_file(reproducerPath,
+                                        reproducerPath + repro_suffix)) {
+             throw std::runtime_error("PassManager::run failed (repro temp)");
+           }
          },
+         py::arg("mod"),
+         py::arg("repro_suffix") = "",
          py::call_guard<py::gil_scoped_release>());
 }
diff --git a/python/src/passes.cc b/python/src/passes.cc
index e54da7e73ec6..514229528b56 100644
--- a/python/src/passes.cc
+++ b/python/src/passes.cc
@@ -12,6 +12,8 @@
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonInstrument/Transforms/Passes.h"
 #include "triton/Target/LLVMIR/Passes.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
+#include "mlir/Tools/Plugins/PassPlugin.h"
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -92,6 +94,20 @@ void init_triton_passes_ttgpuir(py::module &&m) {
                      createTritonGPUCoalesceAsyncCopy);
   ADD_PASS_WRAPPER_0("add_concurrency_sanitizer",
                      createTritonInstrumentConcurrencySanitizer);
+
+  std::string pluginFile =
+      mlir::triton::tools::getStrEnv("MLIR_PASS_PLUGIN_PATH");
+
+  if (!pluginFile.empty()) {
+    auto plugin = mlir::PassPlugin::load(pluginFile);
+    if (!plugin) {
+      llvm::Error Err = plugin.takeError();
+      std::string ErrMsg =
+          "Pass Plugin Error: " + llvm::toString(std::move(Err));
+      throw std::runtime_error(ErrMsg);
+    }
+    plugin.get().registerPassRegistryCallbacks();
+  }
 }
 
 void init_triton_passes_convert(py::module &&m) {
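The plugin is loaded when the ttgpuir pass bindings initialize, so MLIR_PASS_PLUGIN_PATH must point at the shared library before the Triton C extension is first imported. A usage sketch; the library path is an assumption based on the CMake rules below, which place the test plugin under test/lib/Extensions in the build tree:

    # Sketch only: exercise the MLIR_PASS_PLUGIN_PATH hook added in passes.cc above.
    # The .so name and location are assumptions derived from the CMakeLists below.
    import os

    os.environ["MLIR_PASS_PLUGIN_PATH"] = \
        "/path/to/triton/build/test/lib/Extensions/libGPUExtensionTestLib.so"

    import triton  # initializing the pass bindings loads the plugin and registers
                   # tritongpu-HelloExtension in the global pass registry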
diff --git a/python/test/unit/language/test_reproducer.py b/python/test/unit/language/test_reproducer.py
index 4c8f847ac64f..76c153de4e5f 100644
--- a/python/test/unit/language/test_reproducer.py
+++ b/python/test/unit/language/test_reproducer.py
@@ -1,5 +1,7 @@
+from triton._internal_testing import is_cuda, is_hip
 import triton
 import re
+import os
 
 
 def test_triton_reproducer_path(monkeypatch, tmp_path):
@@ -21,6 +23,12 @@ def triton_():
     # matter what the kernel does, just that the PassManager runs its passes.
     triton_[(1, )]()
 
+    if is_cuda() and not is_hip():
+        base_path = str(repro_path)
+        assert os.path.exists(base_path + '.make_ttir.repro.mlir')
+        assert os.path.exists(base_path + '.make_ttgir.repro.mlir')
+        assert os.path.exists(base_path + '.make_llir.repro.mlir')
+
     repro = repro_path.read_text()
     assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. Got:\n{repro}"
     m = re.search(r"pipeline: \"(.*)\"", repro)
diff --git a/test/lib/CMakeLists.txt b/test/lib/CMakeLists.txt
index ae92295191a2..3c10a85b944a 100644
--- a/test/lib/CMakeLists.txt
+++ b/test/lib/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(Analysis)
 add_subdirectory(Dialect)
 add_subdirectory(Instrumentation)
 add_subdirectory(Proton)
+add_subdirectory(Extensions)
diff --git a/test/lib/Extensions/CMakeLists.txt b/test/lib/Extensions/CMakeLists.txt
new file mode 100644
index 000000000000..e120987582b5
--- /dev/null
+++ b/test/lib/Extensions/CMakeLists.txt
@@ -0,0 +1,39 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Extensions)
+add_public_tablegen_target(TritonGPUExtensionIncGen)
+
+set(GPU_EXTENSION_PASSES
+  GPUExtensionTestLib
+  )
+
+set(GPUExtensionTestLib_SOURCES
+  ExtensionHello.cpp
+  )
+
+# MODULE
+
+foreach( plugin ${GPU_EXTENSION_PASSES} )
+  add_mlir_library(${plugin}
+    ${${plugin}_SOURCES}
+    SHARED
+
+    ADDITIONAL_HEADER_DIRS
+    ${PROJECT_BINARY_DIR}/test/lib/Extensions
+
+    DEPENDS
+    TritonGPUExtensionIncGen
+
+    LINK_LIBS PUBLIC
+    MLIRIR
+    MLIRInferTypeOpInterface
+    MLIRFuncDialect
+    )
+  set_target_properties(${plugin} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
+                        ${PROJECT_BINARY_DIR}/test/lib/Extensions)
+  # This is set to -fvisibility=hidden in the top level CMake file
+  # which causes the mlirGetPassPluginInfo symbol to be hidden and
+  # an "entry point not found" error. Reset it just for this target
+  if(NOT MSVC)
+    target_compile_options(${plugin} PRIVATE -fvisibility=default)
+  endif()
+endforeach()
+extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "HelloExtensionPlugin", LLVM_VERSION_STRING, + []() { registerPasses(); }}; +} diff --git a/test/lib/Extensions/Passes.td b/test/lib/Extensions/Passes.td new file mode 100644 index 000000000000..2d838edee9a7 --- /dev/null +++ b/test/lib/Extensions/Passes.td @@ -0,0 +1,10 @@ +#ifndef TRITONGPU_EXTENSION_PASSES +#define TRITONGPU_EXTENSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUHelloExtension : Pass<"tritongpu-HelloExtension", "mlir::ModuleOp"> { + let summary = "Hello World Extension"; + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} +#endif diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index af19a543aa82..72560c8bdf72 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -14,6 +14,10 @@ import os import subprocess from pathlib import Path +import json + +from importlib.util import spec_from_file_location, module_from_spec +import sys def min_dot_size(target: GPUTarget): @@ -239,7 +243,7 @@ def make_ttir(mod, metadata, opt, capability): passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) passes.ttir.add_loop_unroll(pm) - pm.run(mod) + pm.run(mod, '.make_ttir.repro.mlir') return mod @staticmethod @@ -255,6 +259,22 @@ def make_ttgir(mod, metadata, opt, capability): cluster_info.clusterDimZ = opt.cluster_dims[2] pm = ir.pass_manager(mod.context) dump_enabled = pm.enable_debug() + + file_path = '/home/plotfi/opt/dev/Triton-MetaGPU-Clean/triton/byo_compiler.py' + module_name = 'byo_compiler_setup' + + spec = spec_from_file_location(module_name, file_path) + if spec: + module = module_from_spec(spec) + sys.modules[module_name] = module # Add to sys.modules if you want it discoverable + spec.loader.exec_module(module) + module.byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info, dump_enabled, passes, nvidia) + pm.run(mod, '.make_ttgir.repro.mlir') + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + tensordesc_meta = mod.get_tensordesc_metadata() + metadata["tensordesc_meta"] = tensordesc_meta + return mod + passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) # optimize TTGIR passes.ttgpuir.add_coalesce(pm) @@ -316,7 +336,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.common.add_cse(pm) passes.common.add_canonicalizer(pm) - pm.run(mod) + pm.run(mod, '.make_ttgir.repro.mlir') metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) tensordesc_meta = mod.get_tensordesc_metadata() metadata["tensordesc_meta"] = tensordesc_meta @@ -334,7 +354,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability): passes.gluon.add_canonicalizer(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) - pm.run(mod) + pm.run(mod, '.gluon_to_ttgir.repro.mlir') metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() return mod @@ -373,7 +393,7 @@ def make_llir(self, src, metadata, options, capability): if CUDABackend.instrumentation: CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context) - pm.run(mod) + pm.run(mod, '.make_llir.repro.mlir') # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() @@ -509,8 +529,17 @@ def make_cubin(self, src, metadata, opt, capability): return cubin def add_stages(self, stages, options, language): + global_config = None 
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index af19a543aa82..72560c8bdf72 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -14,6 +14,10 @@
 import os
 import subprocess
 from pathlib import Path
+import json
+
+from importlib.util import spec_from_file_location, module_from_spec
+import sys
 
 
 def min_dot_size(target: GPUTarget):
@@ -239,7 +243,7 @@ def make_ttir(mod, metadata, opt, capability):
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
         passes.ttir.add_loop_unroll(pm)
-        pm.run(mod)
+        pm.run(mod, '.make_ttir.repro.mlir')
         return mod
 
     @staticmethod
@@ -255,6 +259,22 @@ def make_ttgir(mod, metadata, opt, capability):
             cluster_info.clusterDimZ = opt.cluster_dims[2]
         pm = ir.pass_manager(mod.context)
         dump_enabled = pm.enable_debug()
+
+        file_path = '/home/plotfi/opt/dev/Triton-MetaGPU-Clean/triton/byo_compiler.py'
+        module_name = 'byo_compiler_setup'
+
+        spec = spec_from_file_location(module_name, file_path)
+        if spec:
+            module = module_from_spec(spec)
+            sys.modules[module_name] = module  # Add to sys.modules if you want it discoverable
+            spec.loader.exec_module(module)
+            module.byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info, dump_enabled, passes, nvidia)
+            pm.run(mod, '.make_ttgir.repro.mlir')
+            metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
+            tensordesc_meta = mod.get_tensordesc_metadata()
+            metadata["tensordesc_meta"] = tensordesc_meta
+            return mod
+
         passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
         # optimize TTGIR
         passes.ttgpuir.add_coalesce(pm)
@@ -316,7 +336,7 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.common.add_cse(pm)
         passes.common.add_canonicalizer(pm)
-        pm.run(mod)
+        pm.run(mod, '.make_ttgir.repro.mlir')
         metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
         tensordesc_meta = mod.get_tensordesc_metadata()
         metadata["tensordesc_meta"] = tensordesc_meta
@@ -334,7 +354,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability):
         passes.gluon.add_canonicalizer(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
 
-        pm.run(mod)
+        pm.run(mod, '.gluon_to_ttgir.repro.mlir')
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
@@ -373,7 +393,7 @@ def make_llir(self, src, metadata, options, capability):
 
         if CUDABackend.instrumentation:
             CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
-        pm.run(mod)
+        pm.run(mod, '.make_llir.repro.mlir')
         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
         llvm.init_targets()
         context = llvm.context()
@@ -509,8 +529,17 @@ def make_cubin(self, src, metadata, opt, capability):
         return cubin
 
     def add_stages(self, stages, options, language):
+        global_config = None
+        global_config_path = None
+        if "PASS_MANAGER_CONFIG_PATH" in os.environ:
+            global_config_path = os.path.realpath(os.environ.get("PASS_MANAGER_CONFIG_PATH"))
+        if(global_config_path != None and os.access(global_config_path, os.R_OK)):
+            print(f"Loading global config from {global_config_path}")
+            with open(global_config_path, "r") as f:
+                global_config = json.load(f)
         capability = self._parse_arch(options.arch)
         if language == Language.TRITON:
+            # ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False)
            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
        elif language == Language.GLUON:
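Putting the pieces together, a run would look roughly like the sketch below. The kernel and all paths are placeholders; TRITON_REPRODUCER_PATH is existing Triton behaviour that this patch extends with per-stage copies, the loaded PASS_MANAGER_CONFIG_PATH JSON is not yet applied in this diff, and both new environment variables are added to CACHE_INVALIDATING_ENV_VARS (see GetEnv.hpp), so changing them forces recompilation.

    # Sketch only: end-to-end use of the new knobs; every path here is a placeholder.
    import os

    os.environ["PASS_MANAGER_CONFIG_PATH"] = "/path/to/default.config"           # JSON pass list above
    os.environ["MLIR_PASS_PLUGIN_PATH"] = "/path/to/libGPUExtensionTestLib.so"   # test plugin above
    os.environ["TRITON_REPRODUCER_PATH"] = "/tmp/repro/triton-repro.mlir"

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def kernel(x_ptr):
        tl.store(x_ptr, 1.0)

    kernel[(1, )](torch.empty(1, device="cuda"))

    # Each compile stage now snapshots its MLIR reproducer next to the base path,
    # matching the assertions in test_reproducer.py:
    #   /tmp/repro/triton-repro.mlir.make_ttir.repro.mlir
    #   /tmp/repro/triton-repro.mlir.make_ttgir.repro.mlir
    #   /tmp/repro/triton-repro.mlir.make_llir.repro.mlir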