From 909a62b45a168979969ca0d638e998678e4d7ab1 Mon Sep 17 00:00:00 2001
From: Puyan Lotfi
Date: Fri, 22 Aug 2025 00:43:56 -0700
Subject: [PATCH 01/11] [BACKEND] Retain mlir reproducer temporaries from prior run pass pipelines

Currently, each pass pipeline run overwrites the MLIR reproducer at the path named by the TRITON_REPRODUCER_PATH env var, so only the most recently run pipeline's reproducer survives. This change lets callers pass a reproducer suffix to pm.run(), so the reproducer for each pipeline is copied to its own file and retained after later pipelines run.
---
 python/src/ir.cc                             | 9 ++++++++-
 python/test/unit/language/test_reproducer.py | 8 ++++++++
 third_party/nvidia/backend/compiler.py       | 8 ++++----
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/python/src/ir.cc b/python/src/ir.cc
index 4c8a4233bf73..c574a2ebfd31 100644
--- a/python/src/ir.cc
+++ b/python/src/ir.cc
@@ -1816,7 +1816,7 @@ void init_triton_ir(py::module &&m) {
           })
       .def(
           "run",
-          [](PassManager &self, ModuleOp &mod) {
+          [](PassManager &self, ModuleOp &mod, std::string repro_suffix) {
            // TODO: maybe dump module to file and print error for better
            // diagnostics
@@ -1900,7 +1900,14 @@ void init_triton_ir(py::module &&m) {
            }
            if (failed(self.run(mod.getOperation())))
              throw std::runtime_error("PassManager::run failed");
+
+           if (!repro_suffix.empty() && !reproducerPath.empty() && reproducerPath != "-" &&
+               llvm::sys::fs::copy_file(reproducerPath, reproducerPath + repro_suffix)) {
+             throw std::runtime_error("PassManager::run failed (repro temp)");
+           }
          },
+         py::arg("mod"),
+         py::arg("repro_suffix") = "",
          py::call_guard<py::gil_scoped_release>());
 }

diff --git a/python/test/unit/language/test_reproducer.py b/python/test/unit/language/test_reproducer.py
index 4c8f847ac64f..76c153de4e5f 100644
--- a/python/test/unit/language/test_reproducer.py
+++ b/python/test/unit/language/test_reproducer.py
@@ -1,5 +1,7 @@
+from triton._internal_testing import is_cuda, is_hip
 import triton
 import re
+import os


 def test_triton_reproducer_path(monkeypatch, tmp_path):
@@ -21,6 +23,12 @@ def triton_():
     # matter what the kernel does, just that the PassManager runs its passes.
     triton_[(1, )]()

+    if is_cuda() and not is_hip():
+        base_path = str(repro_path)
+        assert os.path.exists(base_path + '.make_ttir.repro.mlir')
+        assert os.path.exists(base_path + '.make_ttgir.repro.mlir')
+        assert os.path.exists(base_path + '.make_llir.repro.mlir')
+
     repro = repro_path.read_text()
     assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. 
Got:\n{repro}" m = re.search(r"pipeline: \"(.*)\"", repro) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index af19a543aa82..e2ff91805733 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -239,7 +239,7 @@ def make_ttir(mod, metadata, opt, capability): passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) passes.ttir.add_loop_unroll(pm) - pm.run(mod) + pm.run(mod, '.make_ttir.repro.mlir') return mod @staticmethod @@ -316,7 +316,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.common.add_cse(pm) passes.common.add_canonicalizer(pm) - pm.run(mod) + pm.run(mod, '.make_ttgir.repro.mlir') metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) tensordesc_meta = mod.get_tensordesc_metadata() metadata["tensordesc_meta"] = tensordesc_meta @@ -334,7 +334,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability): passes.gluon.add_canonicalizer(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) - pm.run(mod) + pm.run(mod, '.gluon_to_ttgir.repro.mlir') metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() return mod @@ -373,7 +373,7 @@ def make_llir(self, src, metadata, options, capability): if CUDABackend.instrumentation: CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context) - pm.run(mod) + pm.run(mod, '.make_llir.repro.mlir') # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() From 298cf0630f775b1d61aa65b39a3f72e242526132 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Thu, 31 Jul 2025 01:45:42 -0700 Subject: [PATCH 02/11] First attempt at a pass config --- third_party/nvidia/backend/compiler.py | 159 +++++++++++++++---------- 1 file changed, 95 insertions(+), 64 deletions(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index e2ff91805733..7b003cabc863 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -243,78 +243,108 @@ def make_ttir(mod, metadata, opt, capability): return mod @staticmethod - def make_ttgir(mod, metadata, opt, capability): - # Set maxnreg on all kernels, if it was provided. 
- if opt.maxnreg is not None: - mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg)) - + def make_ttgir_pass_config(num_warps, num_ctas, num_stages, capability, cluster_dims, dump_enabled): cluster_info = nvidia.ClusterInfo() - if opt.cluster_dims is not None: - cluster_info.clusterDimX = opt.cluster_dims[0] - cluster_info.clusterDimY = opt.cluster_dims[1] - cluster_info.clusterDimZ = opt.cluster_dims[2] - pm = ir.pass_manager(mod.context) - dump_enabled = pm.enable_debug() - passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) + if cluster_dims is not None: + cluster_info.clusterDimX = cluster_dims[0] + cluster_info.clusterDimY = cluster_dims[1] + cluster_info.clusterDimZ = cluster_dims[2] + + pass_config = [ + [passes.ttir.add_convert_to_ttgpuir, [f"cuda:{capability}", num_warps, 32, num_ctas]], + ] + # optimize TTGIR - passes.ttgpuir.add_coalesce(pm) + pass_config.append([passes.ttgpuir.add_coalesce, []]) + if capability // 10 >= 8: - passes.ttgpuir.add_f32_dot_tc(pm) + pass_config.append([passes.ttgpuir.add_f32_dot_tc, []]) + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) - passes.ttgpuir.add_remove_layout_conversions(pm) - passes.ttgpuir.add_optimize_thread_locality(pm) - passes.ttgpuir.add_accelerate_matmul(pm) - passes.ttgpuir.add_remove_layout_conversions(pm) - passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) - nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) - passes.ttir.add_loop_aware_cse(pm) + pass_config.append([nvidia.passes.ttnvgpuir.add_plan_cta, [cluster_info]]) + pass_config.append([passes.ttgpuir.add_remove_layout_conversions, []]) + pass_config.append([passes.ttgpuir.add_optimize_thread_locality, []]) + pass_config.append([passes.ttgpuir.add_accelerate_matmul, []]) + pass_config.append([passes.ttgpuir.add_remove_layout_conversions, []]) + pass_config.append([passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]]) + pass_config.append([nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding, []]) + pass_config.append([passes.ttir.add_loop_aware_cse, []]) + if capability // 10 in [8, 9]: - passes.ttgpuir.add_fuse_nested_loops(pm) - passes.common.add_canonicalizer(pm) - passes.ttir.add_triton_licm(pm) - passes.common.add_canonicalizer(pm) - passes.ttgpuir.add_combine_tensor_select_and_if(pm) - nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled) - passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) - passes.ttgpuir.add_schedule_loops(pm) - passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + pass_config.append([passes.ttgpuir.add_fuse_nested_loops, []]) + pass_config.append([passes.common.add_canonicalizer, []]) + pass_config.append([passes.ttir.add_triton_licm, []]) + pass_config.append([passes.common.add_canonicalizer, []]) + pass_config.append([passes.ttgpuir.add_combine_tensor_select_and_if, []]) + pass_config.append([nvidia.passes.hopper.add_hopper_warpspec, [num_stages, dump_enabled]]) + pass_config.append([passes.ttgpuir.add_assign_latencies, [num_stages]]) + pass_config.append([passes.ttgpuir.add_schedule_loops, []]) + pass_config.append([passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]]) elif capability // 10 >= 10: - passes.ttgpuir.add_fuse_nested_loops(pm) - passes.common.add_canonicalizer(pm) - passes.ttir.add_triton_licm(pm) - passes.ttgpuir.add_optimize_accumulator_init(pm) - passes.ttgpuir.add_hoist_tmem_alloc(pm, False) - 
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm) - passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) - passes.ttgpuir.add_schedule_loops(pm) - passes.ttgpuir.add_warp_specialize(pm, opt.num_stages) - passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) - passes.ttgpuir.add_combine_tensor_select_and_if(pm) + pass_config.append([passes.ttgpuir.add_fuse_nested_loops, []]) + pass_config.append([passes.common.add_canonicalizer, []]) + pass_config.append([passes.ttir.add_triton_licm, []]) + pass_config.append([passes.ttgpuir.add_optimize_accumulator_init, []]) + pass_config.append([passes.ttgpuir.add_hoist_tmem_alloc, [False]]) + pass_config.append([nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem, []]) + pass_config.append([passes.ttgpuir.add_assign_latencies, [num_stages]]) + pass_config.append([passes.ttgpuir.add_schedule_loops, []]) + pass_config.append([passes.ttgpuir.add_warp_specialize, [num_stages]]) + pass_config.append([passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]]) + pass_config.append([passes.ttgpuir.add_combine_tensor_select_and_if, []]) # hoist again and allow hoisting out of if statements - passes.ttgpuir.add_hoist_tmem_alloc(pm, True) - nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm) + pass_config.append([passes.ttgpuir.add_hoist_tmem_alloc, [True]]) + pass_config.append([nvidia.passes.ttnvgpuir.add_remove_tmem_tokens, []]) else: - passes.ttir.add_triton_licm(pm) - passes.common.add_canonicalizer(pm) - passes.ttir.add_loop_aware_cse(pm) - passes.ttgpuir.add_prefetch(pm) - passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) - passes.ttgpuir.add_coalesce_async_copy(pm) - nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm) - passes.ttgpuir.add_remove_layout_conversions(pm) - nvidia.passes.ttnvgpuir.add_interleave_tmem(pm) - passes.ttgpuir.add_reduce_data_duplication(pm) - passes.ttgpuir.add_reorder_instructions(pm) - passes.ttir.add_loop_aware_cse(pm) - passes.common.add_symbol_dce(pm) + pass_config.append([passes.ttir.add_triton_licm, []]) + pass_config.append([passes.common.add_canonicalizer, []]) + pass_config.append([passes.ttir.add_loop_aware_cse, []]) + pass_config.append([passes.ttgpuir.add_prefetch, []]) + pass_config.append([passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]]) + pass_config.append([passes.ttgpuir.add_coalesce_async_copy, []]) + pass_config.append([nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts, []]) + pass_config.append([passes.ttgpuir.add_remove_layout_conversions, []]) + pass_config.append([nvidia.passes.ttnvgpuir.add_interleave_tmem, []]) + pass_config.append([passes.ttgpuir.add_reduce_data_duplication, []]) + pass_config.append([passes.ttgpuir.add_reorder_instructions, []]) + pass_config.append([passes.ttir.add_loop_aware_cse, []]) + pass_config.append([passes.common.add_symbol_dce, []]) if capability // 10 >= 9: - nvidia.passes.ttnvgpuir.add_tma_lowering(pm) - nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability) - nvidia.passes.ttnvgpuir.add_lower_mma(pm) - passes.common.add_sccp(pm) - passes.common.add_cse(pm) - passes.common.add_canonicalizer(pm) + pass_config.append([nvidia.passes.ttnvgpuir.add_tma_lowering, []]) + pass_config.append([nvidia.passes.ttnvgpuir.add_fence_insertion, [capability]]) + pass_config.append([nvidia.passes.ttnvgpuir.add_lower_mma, []]) + pass_config.append([passes.common.add_sccp, []]) + pass_config.append([passes.common.add_cse, []]) + pass_config.append([passes.common.add_canonicalizer, []]) + + return [pass_config, cluster_info] + + @staticmethod + def 
make_ttgir(mod, metadata, opt, capability, pass_config): + # Set maxnreg on all kernels, if it was provided. + if opt.maxnreg is not None: + mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg)) + pm = ir.pass_manager(mod.context) + dump_enabled = pm.enable_debug() + + pass_entries = pass_config[0] + cluster_info = pass_config[1] + + for e in pass_entries: + p = e[0] + args = e[1] + if len(args) == 0: + p(pm) + elif len(args) == 1: + p(pm, args[0]) + elif len(args) == 2: + p(pm, args[0], args[1]) + elif len(args) == 3: + p(pm, args[0], args[1], args[2]) + elif len(args) == 4: + p(pm, args[0], args[1], args[2], args[3]) + else: + raise Exception("Bad Arg Count") pm.run(mod, '.make_ttgir.repro.mlir') metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) @@ -511,8 +541,9 @@ def make_cubin(self, src, metadata, opt, capability): def add_stages(self, stages, options, language): capability = self._parse_arch(options.arch) if language == Language.TRITON: + ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False) stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability, ttgir_pass_config) elif language == Language.GLUON: stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) From 19c4a5115adf773cefb3b6c20dec2a4cabe034ac Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Thu, 31 Jul 2025 12:45:27 -0700 Subject: [PATCH 03/11] make pass entires dictionary based (cherry picked from commit d4cfc372cd93160191e51f0881a730a9abc8f947) --- third_party/nvidia/backend/compiler.py | 111 ++++++++++++------------- 1 file changed, 54 insertions(+), 57 deletions(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 7b003cabc863..97f9e26fe87f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -249,73 +249,70 @@ def make_ttgir_pass_config(num_warps, num_ctas, num_stages, capability, cluster_ cluster_info.clusterDimX = cluster_dims[0] cluster_info.clusterDimY = cluster_dims[1] cluster_info.clusterDimZ = cluster_dims[2] - - pass_config = [ - [passes.ttir.add_convert_to_ttgpuir, [f"cuda:{capability}", num_warps, 32, num_ctas]], - ] - + pass_config = dict() + pass_config["ttir.add_convert_to_ttgpuir"] = [passes.ttir.add_convert_to_ttgpuir, [f"cuda:{capability}", num_warps, 32, num_ctas]] # optimize TTGIR - pass_config.append([passes.ttgpuir.add_coalesce, []]) + pass_config["ttgpuir.add_coalesce"] = [passes.ttgpuir.add_coalesce, []] if capability // 10 >= 8: - pass_config.append([passes.ttgpuir.add_f32_dot_tc, []]) + pass_config["ttgpuir.add_f32_dot_tc"] = [passes.ttgpuir.add_f32_dot_tc, []] # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - pass_config.append([nvidia.passes.ttnvgpuir.add_plan_cta, [cluster_info]]) - pass_config.append([passes.ttgpuir.add_remove_layout_conversions, []]) - pass_config.append([passes.ttgpuir.add_optimize_thread_locality, []]) - pass_config.append([passes.ttgpuir.add_accelerate_matmul, []]) - pass_config.append([passes.ttgpuir.add_remove_layout_conversions, []]) 
- pass_config.append([passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]]) - pass_config.append([nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding, []]) - pass_config.append([passes.ttir.add_loop_aware_cse, []]) + pass_config["ttnvgpuir.add_plan_cta"] = [nvidia.passes.ttnvgpuir.add_plan_cta, [cluster_info]] + pass_config["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] + pass_config["ttgpuir.add_optimize_thread_locality"] = [passes.ttgpuir.add_optimize_thread_locality, []] + pass_config["ttgpuir.add_accelerate_matmul"] = [passes.ttgpuir.add_accelerate_matmul, []] + pass_config["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] + pass_config["ttgpuir.add_optimize_dot_operands"] = [passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]] + pass_config["ttnvgpuir.add_optimize_descriptor_encoding"] = [nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding, []] + pass_config["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] if capability // 10 in [8, 9]: - pass_config.append([passes.ttgpuir.add_fuse_nested_loops, []]) - pass_config.append([passes.common.add_canonicalizer, []]) - pass_config.append([passes.ttir.add_triton_licm, []]) - pass_config.append([passes.common.add_canonicalizer, []]) - pass_config.append([passes.ttgpuir.add_combine_tensor_select_and_if, []]) - pass_config.append([nvidia.passes.hopper.add_hopper_warpspec, [num_stages, dump_enabled]]) - pass_config.append([passes.ttgpuir.add_assign_latencies, [num_stages]]) - pass_config.append([passes.ttgpuir.add_schedule_loops, []]) - pass_config.append([passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]]) + pass_config["ttgpuir.add_fuse_nested_loops"] = [passes.ttgpuir.add_fuse_nested_loops, []] + pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] + pass_config["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] + pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] + pass_config["ttgpuir.add_combine_tensor_select_and_if"] = [passes.ttgpuir.add_combine_tensor_select_and_if, []] + pass_config["nvidia.hopper.add_hopper_warpspec"] = [nvidia.passes.hopper.add_hopper_warpspec, [num_stages, dump_enabled]] + pass_config["ttgpuir.add_assign_latencies"] = [passes.ttgpuir.add_assign_latencies, [num_stages]] + pass_config["ttgpuir.add_schedule_loops"] = [passes.ttgpuir.add_schedule_loops, []] + pass_config["ttgpuir.add_pipeline"] = [passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]] elif capability // 10 >= 10: - pass_config.append([passes.ttgpuir.add_fuse_nested_loops, []]) - pass_config.append([passes.common.add_canonicalizer, []]) - pass_config.append([passes.ttir.add_triton_licm, []]) - pass_config.append([passes.ttgpuir.add_optimize_accumulator_init, []]) - pass_config.append([passes.ttgpuir.add_hoist_tmem_alloc, [False]]) - pass_config.append([nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem, []]) - pass_config.append([passes.ttgpuir.add_assign_latencies, [num_stages]]) - pass_config.append([passes.ttgpuir.add_schedule_loops, []]) - pass_config.append([passes.ttgpuir.add_warp_specialize, [num_stages]]) - pass_config.append([passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]]) - pass_config.append([passes.ttgpuir.add_combine_tensor_select_and_if, []]) + pass_config["ttgpuir.add_fuse_nested_loops"] = [passes.ttgpuir.add_fuse_nested_loops, []] + pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] + 
pass_config["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] + pass_config["ttgpuir.add_optimize_accumulator_init"] = [passes.ttgpuir.add_optimize_accumulator_init, []] + pass_config["ttgpuir.add_hoist_tmem_alloc"] = [passes.ttgpuir.add_hoist_tmem_alloc, [False]] + pass_config["ttnvgpuir.add_promote_lhs_to_tmem"] = [nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem, []] + pass_config["ttgpuir.add_assign_latencies"] = [passes.ttgpuir.add_assign_latencies, [num_stages]] + pass_config["ttgpuir.add_schedule_loops"] = [passes.ttgpuir.add_schedule_loops, []] + pass_config["ttgpuir.add_warp_specialize"] = [passes.ttgpuir.add_warp_specialize, [num_stages]] + pass_config["ttgpuir.add_pipeline"] = [passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]] + pass_config["ttgpuir.add_combine_tensor_select_and_if"] = [passes.ttgpuir.add_combine_tensor_select_and_if, []] # hoist again and allow hoisting out of if statements - pass_config.append([passes.ttgpuir.add_hoist_tmem_alloc, [True]]) - pass_config.append([nvidia.passes.ttnvgpuir.add_remove_tmem_tokens, []]) + pass_config["ttgpuir.add_hoist_tmem_alloc"] = [passes.ttgpuir.add_hoist_tmem_alloc, [True]] + pass_config["ttnvgpuir.add_remove_tmem_tokens"] = [nvidia.passes.ttnvgpuir.add_remove_tmem_tokens, []] else: - pass_config.append([passes.ttir.add_triton_licm, []]) - pass_config.append([passes.common.add_canonicalizer, []]) - pass_config.append([passes.ttir.add_loop_aware_cse, []]) - pass_config.append([passes.ttgpuir.add_prefetch, []]) - pass_config.append([passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]]) - pass_config.append([passes.ttgpuir.add_coalesce_async_copy, []]) - pass_config.append([nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts, []]) - pass_config.append([passes.ttgpuir.add_remove_layout_conversions, []]) - pass_config.append([nvidia.passes.ttnvgpuir.add_interleave_tmem, []]) - pass_config.append([passes.ttgpuir.add_reduce_data_duplication, []]) - pass_config.append([passes.ttgpuir.add_reorder_instructions, []]) - pass_config.append([passes.ttir.add_loop_aware_cse, []]) - pass_config.append([passes.common.add_symbol_dce, []]) + pass_config["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] + pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] + pass_config["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] + pass_config["ttgpuir.add_prefetch"] = [passes.ttgpuir.add_prefetch, []] + pass_config["ttgpuir.add_optimize_dot_operands"] = [passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]] + pass_config["ttgpuir.add_coalesce_async_copy"] = [passes.ttgpuir.add_coalesce_async_copy, []] + pass_config["ttnvgpuir.add_optimize_tmem_layouts"] = [nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts, []] + pass_config["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] + pass_config["ttnvgpuir.add_interleave_tmem"] = [nvidia.passes.ttnvgpuir.add_interleave_tmem, []] + pass_config["ttgpuir.add_reduce_data_duplication"] = [passes.ttgpuir.add_reduce_data_duplication, []] + pass_config["ttgpuir.add_reorder_instructions"] = [passes.ttgpuir.add_reorder_instructions, []] + pass_config["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] + pass_config["common.add_symbol_dce"] = [passes.common.add_symbol_dce, []] if capability // 10 >= 9: - pass_config.append([nvidia.passes.ttnvgpuir.add_tma_lowering, []]) - pass_config.append([nvidia.passes.ttnvgpuir.add_fence_insertion, [capability]]) - 
pass_config.append([nvidia.passes.ttnvgpuir.add_lower_mma, []]) - pass_config.append([passes.common.add_sccp, []]) - pass_config.append([passes.common.add_cse, []]) - pass_config.append([passes.common.add_canonicalizer, []]) + pass_config["ttnvgpuir.add_tma_lowering"] = [nvidia.passes.ttnvgpuir.add_tma_lowering, []] + pass_config["ttnvgpuir.add_fence_insertion"] = [nvidia.passes.ttnvgpuir.add_fence_insertion, [capability]] + pass_config["ttnvgpuir.add_lower_mma"] = [nvidia.passes.ttnvgpuir.add_lower_mma, []] + pass_config["common.add_sccp"] = [passes.common.add_sccp, []] + pass_config["common.add_cse"] = [passes.common.add_cse, []] + pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] return [pass_config, cluster_info] @@ -330,7 +327,7 @@ def make_ttgir(mod, metadata, opt, capability, pass_config): pass_entries = pass_config[0] cluster_info = pass_config[1] - for e in pass_entries: + for e in pass_entries.values(): p = e[0] args = e[1] if len(args) == 0: From d797c5063aabe2acbd98d3b3d102ff13bcc68308 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Fri, 1 Aug 2025 10:37:33 -0700 Subject: [PATCH 04/11] add config file (cherry picked from commit e756711ab66decf8328ee2014509885ce58d4e10) --- default.config | 57 ++++++++++++++++++++++++++ third_party/nvidia/backend/compiler.py | 19 +++++++-- 2 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 default.config diff --git a/default.config b/default.config new file mode 100644 index 000000000000..e0a98dbc682c --- /dev/null +++ b/default.config @@ -0,0 +1,57 @@ +{ + "ttgir" : { + "passes" : [ +"ttir.add_convert_to_ttgpuir", +"ttgpuir.add_coalesce", +"ttgpuir.add_f32_dot_tc", +"ttnvgpuir.add_plan_cta", +"ttgpuir.add_remove_layout_conversions", +"ttgpuir.add_optimize_thread_locality", +"ttgpuir.add_accelerate_matmul", +"ttgpuir.add_remove_layout_conversions", +"ttgpuir.add_optimize_dot_operands", +"ttnvgpuir.add_optimize_descriptor_encoding", +"ttir.add_loop_aware_cse", +"ttgpuir.add_fuse_nested_loops", +"common.add_canonicalizer", +"ttir.add_triton_licm", +"common.add_canonicalizer", +"ttgpuir.add_combine_tensor_select_and_if", +"nvidia.hopper.add_hopper_warpspec", +"ttgpuir.add_assign_latencies", +"ttgpuir.add_schedule_loops", +"ttgpuir.add_pipeline", +"ttgpuir.add_fuse_nested_loops", +"common.add_canonicalizer", +"ttir.add_triton_licm", +"ttgpuir.add_optimize_accumulator_init", +"ttgpuir.add_hoist_tmem_alloc", +"ttnvgpuir.add_promote_lhs_to_tmem", +"ttgpuir.add_assign_latencies", +"ttgpuir.add_schedule_loops", +"ttgpuir.add_warp_specialize", +"ttgpuir.add_pipeline", +"ttgpuir.add_combine_tensor_select_and_if", +"ttgpuir.add_hoist_tmem_alloc", +"ttnvgpuir.add_remove_tmem_tokens", +"ttir.add_triton_licm", +"common.add_canonicalizer", +"ttir.add_loop_aware_cse", +"ttgpuir.add_prefetch", +"ttgpuir.add_optimize_dot_operands", +"ttgpuir.add_coalesce_async_copy", +"ttnvgpuir.add_optimize_tmem_layouts", +"ttgpuir.add_remove_layout_conversions", +"ttnvgpuir.add_interleave_tmem", +"ttgpuir.add_reduce_data_duplication", +"ttgpuir.add_reorder_instructions", +"ttir.add_loop_aware_cse", +"common.add_symbol_dce", +"ttnvgpuir.add_tma_lowering", +"ttnvgpuir.add_fence_insertion", +"ttnvgpuir.add_lower_mma", +"common.add_sccp", +"common.add_canonicalizer" + ] + } +} diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 97f9e26fe87f..fd3d91534190 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -14,6 +14,7 @@ import os import 
subprocess from pathlib import Path +import json def min_dot_size(target: GPUTarget): @@ -317,7 +318,7 @@ def make_ttgir_pass_config(num_warps, num_ctas, num_stages, capability, cluster_ return [pass_config, cluster_info] @staticmethod - def make_ttgir(mod, metadata, opt, capability, pass_config): + def make_ttgir(mod, metadata, opt, capability, pass_config, global_config): # Set maxnreg on all kernels, if it was provided. if opt.maxnreg is not None: mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg)) @@ -325,9 +326,15 @@ def make_ttgir(mod, metadata, opt, capability, pass_config): dump_enabled = pm.enable_debug() pass_entries = pass_config[0] + config_pass_entries = global_config["ttgir"]["passes"] cluster_info = pass_config[1] - for e in pass_entries.values(): + new_pass_entries = dict() + + for e in config_pass_entries: + if(e in list(pass_entries.keys())): + new_pass_entries[e] = [pass_entries[e][0], pass_entries[e][1]] + for e in new_pass_entries.values(): p = e[0] args = e[1] if len(args) == 0: @@ -536,11 +543,17 @@ def make_cubin(self, src, metadata, opt, capability): return cubin def add_stages(self, stages, options, language): + global_config = None + if os.path.isfile("default.config") and os.access("default.config", os.R_OK): + with open("default.config", "r") as f: + global_config = json.load(f) + # if global_config is not None: + # print(global_config["ttgir"]["passes"]) capability = self._parse_arch(options.arch) if language == Language.TRITON: ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False) stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability, ttgir_pass_config) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability, ttgir_pass_config, global_config) elif language == Language.GLUON: stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) From 15e166488d554bad4c0b04dcb2d5b1172209ec18 Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Fri, 1 Aug 2025 10:58:56 -0700 Subject: [PATCH 05/11] add default pass entry for cases when no config file (cherry picked from commit bb231e7d1801dc62217532239a6a7bb9b37ef56d) --- third_party/nvidia/backend/compiler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index fd3d91534190..71ddfc9b2f32 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -330,10 +330,13 @@ def make_ttgir(mod, metadata, opt, capability, pass_config, global_config): cluster_info = pass_config[1] new_pass_entries = dict() + if(global_config == None): + new_pass_entries = pass_entries + else: + for e in config_pass_entries: + if(e in list(pass_entries.keys())): + new_pass_entries[e] = [pass_entries[e][0], pass_entries[e][1]] - for e in config_pass_entries: - if(e in list(pass_entries.keys())): - new_pass_entries[e] = [pass_entries[e][0], pass_entries[e][1]] for e in new_pass_entries.values(): p = e[0] args = e[1] From ced89c3e9a92400e5da0edec732b1ebe835e6010 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 20 Aug 2025 00:18:08 -0700 Subject: [PATCH 06/11] Fix issues with dicts and 
repeating passes with different arguments Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D80597444 (cherry picked from commit 2acdc2a654d210c3c1d419a9dd0c8e63ae65df78) (cherry picked from commit 758b2732fd12b78ee254c87cd7cb9f9014558171) --- default.config | 4 +- third_party/nvidia/backend/compiler.py | 158 ++++++++++++++++--------- 2 files changed, 101 insertions(+), 61 deletions(-) diff --git a/default.config b/default.config index e0a98dbc682c..b4f6bd54ec61 100644 --- a/default.config +++ b/default.config @@ -25,14 +25,14 @@ "common.add_canonicalizer", "ttir.add_triton_licm", "ttgpuir.add_optimize_accumulator_init", -"ttgpuir.add_hoist_tmem_alloc", +"ttgpuir.add_hoist_tmem_alloc{False}", "ttnvgpuir.add_promote_lhs_to_tmem", "ttgpuir.add_assign_latencies", "ttgpuir.add_schedule_loops", "ttgpuir.add_warp_specialize", "ttgpuir.add_pipeline", "ttgpuir.add_combine_tensor_select_and_if", -"ttgpuir.add_hoist_tmem_alloc", +"ttgpuir.add_hoist_tmem_alloc{True}", "ttnvgpuir.add_remove_tmem_tokens", "ttir.add_triton_licm", "common.add_canonicalizer", diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 71ddfc9b2f32..1ff47fb93026 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -250,72 +250,111 @@ def make_ttgir_pass_config(num_warps, num_ctas, num_stages, capability, cluster_ cluster_info.clusterDimX = cluster_dims[0] cluster_info.clusterDimY = cluster_dims[1] cluster_info.clusterDimZ = cluster_dims[2] - pass_config = dict() - pass_config["ttir.add_convert_to_ttgpuir"] = [passes.ttir.add_convert_to_ttgpuir, [f"cuda:{capability}", num_warps, 32, num_ctas]] + + pass_config_table = dict() + pass_config_table["ttir.add_convert_to_ttgpuir"] = [passes.ttir.add_convert_to_ttgpuir, [f"cuda:{capability}", num_warps, 32, num_ctas]] + pass_config_table["ttgpuir.add_coalesce"] = [passes.ttgpuir.add_coalesce, []] + pass_config_table["ttgpuir.add_f32_dot_tc"] = [passes.ttgpuir.add_f32_dot_tc, []] + pass_config_table["ttnvgpuir.add_plan_cta"] = [nvidia.passes.ttnvgpuir.add_plan_cta, [cluster_info]] + pass_config_table["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] + pass_config_table["ttgpuir.add_optimize_thread_locality"] = [passes.ttgpuir.add_optimize_thread_locality, []] + pass_config_table["ttgpuir.add_accelerate_matmul"] = [passes.ttgpuir.add_accelerate_matmul, []] + pass_config_table["ttgpuir.add_optimize_dot_operands"] = [passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]] + pass_config_table["ttnvgpuir.add_optimize_descriptor_encoding"] = [nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding, []] + pass_config_table["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] + pass_config_table["ttgpuir.add_fuse_nested_loops"] = [passes.ttgpuir.add_fuse_nested_loops, []] + pass_config_table["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] + pass_config_table["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] + pass_config_table["ttgpuir.add_combine_tensor_select_and_if"] = [passes.ttgpuir.add_combine_tensor_select_and_if, []] + pass_config_table["nvidia.hopper.add_hopper_warpspec"] = [nvidia.passes.hopper.add_hopper_warpspec, [num_stages, dump_enabled]] + pass_config_table["ttgpuir.add_assign_latencies"] = [passes.ttgpuir.add_assign_latencies, [num_stages]] + pass_config_table["ttgpuir.add_schedule_loops"] = [passes.ttgpuir.add_schedule_loops, 
[]] + pass_config_table["ttgpuir.add_pipeline"] = [passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]] + pass_config_table["ttgpuir.add_optimize_accumulator_init"] = [passes.ttgpuir.add_optimize_accumulator_init, []] + pass_config_table["ttgpuir.add_hoist_tmem_alloc{False}"] = [passes.ttgpuir.add_hoist_tmem_alloc, [False]] # TODO: This is the one case when the params for a pass differ in two different instances + pass_config_table["ttnvgpuir.add_promote_lhs_to_tmem"] = [nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem, []] + pass_config_table["ttgpuir.add_warp_specialize"] = [passes.ttgpuir.add_warp_specialize, [num_stages]] + pass_config_table["ttgpuir.add_hoist_tmem_alloc{True}"] = [passes.ttgpuir.add_hoist_tmem_alloc, [True]] # TODO: This is the one case when the params for a pass differ in two different instances + pass_config_table["ttnvgpuir.add_remove_tmem_tokens"] = [nvidia.passes.ttnvgpuir.add_remove_tmem_tokens, []] + pass_config_table["ttgpuir.add_prefetch"] = [passes.ttgpuir.add_prefetch, []] + pass_config_table["ttgpuir.add_coalesce_async_copy"] = [passes.ttgpuir.add_coalesce_async_copy, []] + pass_config_table["ttnvgpuir.add_optimize_tmem_layouts"] = [nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts, []] + pass_config_table["ttnvgpuir.add_interleave_tmem"] = [nvidia.passes.ttnvgpuir.add_interleave_tmem, []] + pass_config_table["ttgpuir.add_reduce_data_duplication"] = [passes.ttgpuir.add_reduce_data_duplication, []] + pass_config_table["ttgpuir.add_reorder_instructions"] = [passes.ttgpuir.add_reorder_instructions, []] + pass_config_table["common.add_symbol_dce"] = [passes.common.add_symbol_dce, []] + pass_config_table["ttnvgpuir.add_tma_lowering"] = [nvidia.passes.ttnvgpuir.add_tma_lowering, []] + pass_config_table["ttnvgpuir.add_fence_insertion"] = [nvidia.passes.ttnvgpuir.add_fence_insertion, [capability]] + pass_config_table["ttnvgpuir.add_lower_mma"] = [nvidia.passes.ttnvgpuir.add_lower_mma, []] + pass_config_table["common.add_cse"] = [passes.common.add_cse, []] + pass_config_table["common.add_sccp"] = [passes.common.add_sccp, []] + + pass_config = list() + pass_config.append(pass_config_table["ttir.add_convert_to_ttgpuir"]) # optimize TTGIR - pass_config["ttgpuir.add_coalesce"] = [passes.ttgpuir.add_coalesce, []] + pass_config.append(pass_config_table["ttgpuir.add_coalesce"]) if capability // 10 >= 8: - pass_config["ttgpuir.add_f32_dot_tc"] = [passes.ttgpuir.add_f32_dot_tc, []] + pass_config.append(pass_config_table["ttgpuir.add_f32_dot_tc"]) # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - pass_config["ttnvgpuir.add_plan_cta"] = [nvidia.passes.ttnvgpuir.add_plan_cta, [cluster_info]] - pass_config["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] - pass_config["ttgpuir.add_optimize_thread_locality"] = [passes.ttgpuir.add_optimize_thread_locality, []] - pass_config["ttgpuir.add_accelerate_matmul"] = [passes.ttgpuir.add_accelerate_matmul, []] - pass_config["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] - pass_config["ttgpuir.add_optimize_dot_operands"] = [passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]] - pass_config["ttnvgpuir.add_optimize_descriptor_encoding"] = [nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding, []] - pass_config["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] + pass_config.append(pass_config_table["ttnvgpuir.add_plan_cta"]) + pass_config.append(pass_config_table["ttgpuir.add_remove_layout_conversions"]) + 
pass_config.append(pass_config_table["ttgpuir.add_optimize_thread_locality"]) + pass_config.append(pass_config_table["ttgpuir.add_accelerate_matmul"]) + pass_config.append(pass_config_table["ttgpuir.add_remove_layout_conversions"]) + pass_config.append(pass_config_table["ttgpuir.add_optimize_dot_operands"]) + pass_config.append(pass_config_table["ttnvgpuir.add_optimize_descriptor_encoding"]) + pass_config.append(pass_config_table["ttir.add_loop_aware_cse"]) if capability // 10 in [8, 9]: - pass_config["ttgpuir.add_fuse_nested_loops"] = [passes.ttgpuir.add_fuse_nested_loops, []] - pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] - pass_config["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] - pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] - pass_config["ttgpuir.add_combine_tensor_select_and_if"] = [passes.ttgpuir.add_combine_tensor_select_and_if, []] - pass_config["nvidia.hopper.add_hopper_warpspec"] = [nvidia.passes.hopper.add_hopper_warpspec, [num_stages, dump_enabled]] - pass_config["ttgpuir.add_assign_latencies"] = [passes.ttgpuir.add_assign_latencies, [num_stages]] - pass_config["ttgpuir.add_schedule_loops"] = [passes.ttgpuir.add_schedule_loops, []] - pass_config["ttgpuir.add_pipeline"] = [passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]] + pass_config.append(pass_config_table["ttgpuir.add_fuse_nested_loops"]) + pass_config.append(pass_config_table["common.add_canonicalizer"]) + pass_config.append(pass_config_table["ttir.add_triton_licm"]) + pass_config.append(pass_config_table["common.add_canonicalizer"]) + pass_config.append(pass_config_table["ttgpuir.add_combine_tensor_select_and_if"]) + pass_config.append(pass_config_table["nvidia.hopper.add_hopper_warpspec"]) + pass_config.append(pass_config_table["ttgpuir.add_assign_latencies"]) + pass_config.append(pass_config_table["ttgpuir.add_schedule_loops"]) + pass_config.append(pass_config_table["ttgpuir.add_pipeline"]) elif capability // 10 >= 10: - pass_config["ttgpuir.add_fuse_nested_loops"] = [passes.ttgpuir.add_fuse_nested_loops, []] - pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] - pass_config["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] - pass_config["ttgpuir.add_optimize_accumulator_init"] = [passes.ttgpuir.add_optimize_accumulator_init, []] - pass_config["ttgpuir.add_hoist_tmem_alloc"] = [passes.ttgpuir.add_hoist_tmem_alloc, [False]] - pass_config["ttnvgpuir.add_promote_lhs_to_tmem"] = [nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem, []] - pass_config["ttgpuir.add_assign_latencies"] = [passes.ttgpuir.add_assign_latencies, [num_stages]] - pass_config["ttgpuir.add_schedule_loops"] = [passes.ttgpuir.add_schedule_loops, []] - pass_config["ttgpuir.add_warp_specialize"] = [passes.ttgpuir.add_warp_specialize, [num_stages]] - pass_config["ttgpuir.add_pipeline"] = [passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]] - pass_config["ttgpuir.add_combine_tensor_select_and_if"] = [passes.ttgpuir.add_combine_tensor_select_and_if, []] + pass_config.append(pass_config_table["ttgpuir.add_fuse_nested_loops"]) + pass_config.append(pass_config_table["common.add_canonicalizer"]) + pass_config.append(pass_config_table["ttir.add_triton_licm"]) + pass_config.append(pass_config_table["ttgpuir.add_optimize_accumulator_init"]) + pass_config.append(pass_config_table["ttgpuir.add_hoist_tmem_alloc"]) + pass_config.append(pass_config_table["ttnvgpuir.add_promote_lhs_to_tmem"]) + 
pass_config.append(pass_config_table["ttgpuir.add_assign_latencies"]) + pass_config.append(pass_config_table["ttgpuir.add_schedule_loops"]) + pass_config.append(pass_config_table["ttgpuir.add_warp_specialize"]) + pass_config.append(pass_config_table["ttgpuir.add_pipeline"]) + pass_config.append(pass_config_table["ttgpuir.add_combine_tensor_select_and_if"]) # hoist again and allow hoisting out of if statements - pass_config["ttgpuir.add_hoist_tmem_alloc"] = [passes.ttgpuir.add_hoist_tmem_alloc, [True]] - pass_config["ttnvgpuir.add_remove_tmem_tokens"] = [nvidia.passes.ttnvgpuir.add_remove_tmem_tokens, []] + pass_config.append(pass_config_table["ttgpuir.add_hoist_tmem_alloc"]) + pass_config.append(pass_config_table["ttnvgpuir.add_remove_tmem_tokens"]) else: - pass_config["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] - pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] - pass_config["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] - pass_config["ttgpuir.add_prefetch"] = [passes.ttgpuir.add_prefetch, []] - pass_config["ttgpuir.add_optimize_dot_operands"] = [passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]] - pass_config["ttgpuir.add_coalesce_async_copy"] = [passes.ttgpuir.add_coalesce_async_copy, []] - pass_config["ttnvgpuir.add_optimize_tmem_layouts"] = [nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts, []] - pass_config["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] - pass_config["ttnvgpuir.add_interleave_tmem"] = [nvidia.passes.ttnvgpuir.add_interleave_tmem, []] - pass_config["ttgpuir.add_reduce_data_duplication"] = [passes.ttgpuir.add_reduce_data_duplication, []] - pass_config["ttgpuir.add_reorder_instructions"] = [passes.ttgpuir.add_reorder_instructions, []] - pass_config["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] - pass_config["common.add_symbol_dce"] = [passes.common.add_symbol_dce, []] + pass_config.append(pass_config_table["ttir.add_triton_licm"]) + pass_config.append(pass_config_table["common.add_canonicalizer"]) + pass_config.append(pass_config_table["ttir.add_loop_aware_cse"]) + pass_config.append(pass_config_table["ttgpuir.add_prefetch"]) + pass_config.append(pass_config_table["ttgpuir.add_optimize_dot_operands"]) + pass_config.append(pass_config_table["ttgpuir.add_coalesce_async_copy"]) + pass_config.append(pass_config_table["ttnvgpuir.add_optimize_tmem_layouts"]) + pass_config.append(pass_config_table["ttgpuir.add_remove_layout_conversions"]) + pass_config.append(pass_config_table["ttnvgpuir.add_interleave_tmem"]) + pass_config.append(pass_config_table["ttgpuir.add_reduce_data_duplication"]) + pass_config.append(pass_config_table["ttgpuir.add_reorder_instructions"]) + pass_config.append(pass_config_table["ttir.add_loop_aware_cse"]) + pass_config.append(pass_config_table["common.add_symbol_dce"]) if capability // 10 >= 9: - pass_config["ttnvgpuir.add_tma_lowering"] = [nvidia.passes.ttnvgpuir.add_tma_lowering, []] - pass_config["ttnvgpuir.add_fence_insertion"] = [nvidia.passes.ttnvgpuir.add_fence_insertion, [capability]] - pass_config["ttnvgpuir.add_lower_mma"] = [nvidia.passes.ttnvgpuir.add_lower_mma, []] - pass_config["common.add_sccp"] = [passes.common.add_sccp, []] - pass_config["common.add_cse"] = [passes.common.add_cse, []] - pass_config["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] + pass_config.append(pass_config_table["ttnvgpuir.add_tma_lowering"]) + 
pass_config.append(pass_config_table["ttnvgpuir.add_fence_insertion"]) + pass_config.append(pass_config_table["ttnvgpuir.add_lower_mma"]) + pass_config.append(pass_config_table["common.add_sccp"]) + pass_config.append(pass_config_table["common.add_cse"]) + pass_config.append(pass_config_table["common.add_canonicalizer"]) - return [pass_config, cluster_info] + return [pass_config_table, pass_config, cluster_info] @staticmethod def make_ttgir(mod, metadata, opt, capability, pass_config, global_config): @@ -326,18 +365,19 @@ def make_ttgir(mod, metadata, opt, capability, pass_config, global_config): dump_enabled = pm.enable_debug() pass_entries = pass_config[0] + pass_list = pass_config[1] config_pass_entries = global_config["ttgir"]["passes"] - cluster_info = pass_config[1] + cluster_info = pass_config[2] - new_pass_entries = dict() + new_pass_list = list() if(global_config == None): - new_pass_entries = pass_entries + new_pass_list = pass_list else: for e in config_pass_entries: if(e in list(pass_entries.keys())): - new_pass_entries[e] = [pass_entries[e][0], pass_entries[e][1]] + new_pass_list.append([pass_entries[e][0], pass_entries[e][1]]) - for e in new_pass_entries.values(): + for e in new_pass_list: p = e[0] args = e[1] if len(args) == 0: From df75e55f6af6d20aa1647a08a7fb5617221a96fd Mon Sep 17 00:00:00 2001 From: Corbin Robeck Date: Mon, 25 Aug 2025 11:46:54 -0700 Subject: [PATCH 07/11] clean up path logic for config file --- include/triton/Tools/Sys/GetEnv.hpp | 1 + third_party/nvidia/backend/compiler.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index efe3d930ef3d..82ce47283cbc 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -23,6 +23,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "LLVM_IR_ENABLE_DUMP", "LLVM_ENABLE_TIMING", "LLVM_PASS_PLUGIN_PATH", + "PASS_MANAGER_CONFIG_PATH", "MLIR_ENABLE_DIAGNOSTICS", "MLIR_ENABLE_DUMP", "MLIR_DUMP_PATH", diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 1ff47fb93026..67ae30967dd6 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -366,13 +366,13 @@ def make_ttgir(mod, metadata, opt, capability, pass_config, global_config): pass_entries = pass_config[0] pass_list = pass_config[1] - config_pass_entries = global_config["ttgir"]["passes"] cluster_info = pass_config[2] new_pass_list = list() if(global_config == None): new_pass_list = pass_list else: + config_pass_entries = global_config["ttgir"]["passes"] for e in config_pass_entries: if(e in list(pass_entries.keys())): new_pass_list.append([pass_entries[e][0], pass_entries[e][1]]) @@ -587,11 +587,13 @@ def make_cubin(self, src, metadata, opt, capability): def add_stages(self, stages, options, language): global_config = None - if os.path.isfile("default.config") and os.access("default.config", os.R_OK): - with open("default.config", "r") as f: + global_config_path = None + if "PASS_MANAGER_CONFIG_PATH" in os.environ: + global_config_path = os.path.realpath(os.environ.get("PASS_MANAGER_CONFIG_PATH")) + if(global_config_path != None and os.access(global_config_path, os.R_OK)): + print(f"Loading global config from {global_config_path}") + with open(global_config_path, "r") as f: global_config = json.load(f) - # if global_config is not None: - # print(global_config["ttgir"]["passes"]) capability = self._parse_arch(options.arch) if language 
== Language.TRITON: ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False) From 5e743b2ee4c4a805a6279946800e379d2d1c775d Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Tue, 26 Aug 2025 00:27:54 -0700 Subject: [PATCH 08/11] First stab at ttgir/mlir pass plugins --- include/triton/Tools/Sys/GetEnv.hpp | 1 + python/src/passes.cc | 16 ++++ test/lib/CMakeLists.txt | 1 + test/lib/Extensions/CMakeLists.txt | 29 +++++++ test/lib/Extensions/ExtensionHello.cpp | 113 +++++++++++++++++++++++++ test/lib/Extensions/Passes.td | 10 +++ 6 files changed, 170 insertions(+) create mode 100644 test/lib/Extensions/CMakeLists.txt create mode 100644 test/lib/Extensions/ExtensionHello.cpp create mode 100644 test/lib/Extensions/Passes.td diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 82ce47283cbc..ded502772a26 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -45,6 +45,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "ALLOW_LHS_TMEM_LAYOUT_CONVERSION", "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", + "MLIR_PASS_PLUGIN_PATH", // clang-format on }; diff --git a/python/src/passes.cc b/python/src/passes.cc index e54da7e73ec6..514229528b56 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -12,6 +12,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonInstrument/Transforms/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "mlir/Tools/Plugins/PassPlugin.h" #include #include @@ -92,6 +94,20 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCoalesceAsyncCopy); ADD_PASS_WRAPPER_0("add_concurrency_sanitizer", createTritonInstrumentConcurrencySanitizer); + + std::string pluginFile = + mlir::triton::tools::getStrEnv("MLIR_PASS_PLUGIN_PATH"); + + if (!pluginFile.empty()) { + auto plugin = mlir::PassPlugin::load(pluginFile); + if (!plugin) { + llvm::Error Err = plugin.takeError(); + std::string ErrMsg = + "Pass Plugin Error: " + llvm::toString(std::move(Err)); + throw std::runtime_error(ErrMsg); + } + plugin.get().registerPassRegistryCallbacks(); + } } void init_triton_passes_convert(py::module &&m) { diff --git a/test/lib/CMakeLists.txt b/test/lib/CMakeLists.txt index ae92295191a2..3c10a85b944a 100644 --- a/test/lib/CMakeLists.txt +++ b/test/lib/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(Analysis) add_subdirectory(Dialect) add_subdirectory(Instrumentation) add_subdirectory(Proton) +add_subdirectory(Extensions) diff --git a/test/lib/Extensions/CMakeLists.txt b/test/lib/Extensions/CMakeLists.txt new file mode 100644 index 000000000000..0b1d9a82d61c --- /dev/null +++ b/test/lib/Extensions/CMakeLists.txt @@ -0,0 +1,29 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Extensions) +add_public_tablegen_target(TritonGPUExtensionIncGen) + +set(GPU_EXTENSION_PASSES + GPUExtensionTestLib + ) + +set(GPUExtensionTestLib_SOURCES + ExtensionHello.cpp + ) + +foreach( plugin ${GPU_EXTENSION_PASSES} ) + add_mlir_library(${plugin} + SHARED + ${${plugin}_SOURCES} + + DEPENDS + TritonGPUExtensionIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRInferTypeOpInterface + MLIRFuncDialect + ) + include_directories(${PROJECT_BINARY_DIR}/test/lib/Extensions) # Tablegen'd files + set_target_properties(${plugin} PROPERTIES LIBRARY_OUTPUT_DIRECTORY + 
${PROJECT_BINARY_DIR}/test/lib/Extensions)
+endforeach()

diff --git a/test/lib/Extensions/ExtensionHello.cpp b/test/lib/Extensions/ExtensionHello.cpp
new file mode 100644
index 000000000000..2c0010e8fccf
--- /dev/null
+++ b/test/lib/Extensions/ExtensionHello.cpp
@@ -0,0 +1,113 @@
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include 
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Tools/Plugins/DialectPlugin.h"
+
+#include "mlir/Tools/Plugins/PassPlugin.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/Support/Compiler.h"
+
+using namespace mlir;
+
+/// Dialect plugin registration mechanism.
+/// Observe that it also allows to register passes.
+/// Necessary symbol to register the dialect plugin.
+// extern "C" LLVM_ATTRIBUTE_WEAK DialectPluginLibraryInfo
+// mlirGetDialectPluginInfo() {
+//   return {MLIR_PLUGIN_API_VERSION, "Standalone", LLVM_VERSION_STRING,
+//           [](DialectRegistry *registry) {
+//             registry->insert();
+//             mlir::standalone::registerPasses();
+//           }};
+// }
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
+
+namespace mlir {
+namespace triton {
+namespace gpu {
+
+#define GEN_PASS_DEF_TRITONGPUHELLOEXTENSION
+#include "Passes.h.inc"
+
+namespace {
+
+class HelloExtension : public OpRewritePattern<DotOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DotOp dotOp,
+                                PatternRewriter &rewriter) const override {
+    return success();
+  }
+};
+
+} // anonymous namespace
+
+struct HelloExtensionPass : public impl::TritonGPUHelloExtensionBase<HelloExtensionPass> {
+  void runOnOperation()
+  // override
+  {
+    // MLIRContext *context = &getContext();
+    // ModuleOp m = getOperation();
+    // RewritePatternSet decomposePatterns(context);
+    // decomposePatterns.add<HelloExtension>(context);
+    // if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
+    //   signalPassFailure();
+    // }
+  }
+};
+
+} // namespace gpu
+} // namespace triton
+} // namespace mlir
+
+
+inline void registerStandaloneSwitchBarFoo() {
+  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
+    return mlir::triton::gpu::createTritonGPUHelloExtension();
+  });
+}
+
+inline void registerPasses() {
+  registerStandaloneSwitchBarFoo();
+}
+
+
+/// Pass plugin registration mechanism.
+/// Necessary symbol to register the pass plugin.
+extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "HelloExtensionPlugin", LLVM_VERSION_STRING, + []() { registerPasses(); }}; +} diff --git a/test/lib/Extensions/Passes.td b/test/lib/Extensions/Passes.td new file mode 100644 index 000000000000..2d838edee9a7 --- /dev/null +++ b/test/lib/Extensions/Passes.td @@ -0,0 +1,10 @@ +#ifndef TRITONGPU_EXTENSION_PASSES +#define TRITONGPU_EXTENSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonGPUHelloExtension : Pass<"tritongpu-HelloExtension", "mlir::ModuleOp"> { + let summary = "Hello World Extension"; + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} +#endif From 337acdeeffc254c38909a9ad89c7f1317f6b7918 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Wed, 27 Aug 2025 11:01:19 -0700 Subject: [PATCH 09/11] Fix the visibility hidden problem --- test/lib/Extensions/CMakeLists.txt | 16 +++++++++++++--- test/lib/Extensions/ExtensionHello.cpp | 21 +++++++-------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/test/lib/Extensions/CMakeLists.txt b/test/lib/Extensions/CMakeLists.txt index 0b1d9a82d61c..e120987582b5 100644 --- a/test/lib/Extensions/CMakeLists.txt +++ b/test/lib/Extensions/CMakeLists.txt @@ -3,17 +3,22 @@ mlir_tablegen(Passes.h.inc -gen-pass-decls -name Extensions) add_public_tablegen_target(TritonGPUExtensionIncGen) set(GPU_EXTENSION_PASSES - GPUExtensionTestLib + GPUExtensionTestLib ) set(GPUExtensionTestLib_SOURCES ExtensionHello.cpp ) +# MODULE + foreach( plugin ${GPU_EXTENSION_PASSES} ) add_mlir_library(${plugin} - SHARED ${${plugin}_SOURCES} + SHARED + + ADDITIONAL_HEADER_DIRS + ${PROJECT_BINARY_DIR}/test/lib/Extensions DEPENDS TritonGPUExtensionIncGen @@ -23,7 +28,12 @@ foreach( plugin ${GPU_EXTENSION_PASSES} ) MLIRInferTypeOpInterface MLIRFuncDialect ) - include_directories(${PROJECT_BINARY_DIR}/test/lib/Extensions) # Tablegen'd files set_target_properties(${plugin} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/test/lib/Extensions) + # This is set to -fvisibility=hidden in the top level CMake file + # which causes the llvmGetPassPluginInfo symbol to be hidden and + # an "entry point not found" error. 
Reset it just for this target
+  if(NOT MSVC)
+    target_compile_options(${plugin} PRIVATE -fvisibility=default)
+  endif()
 endforeach()

diff --git a/test/lib/Extensions/ExtensionHello.cpp b/test/lib/Extensions/ExtensionHello.cpp
index 2c0010e8fccf..749f873cd036 100644
--- a/test/lib/Extensions/ExtensionHello.cpp
+++ b/test/lib/Extensions/ExtensionHello.cpp
@@ -54,6 +54,8 @@ using namespace mlir;
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 
+extern "C" void FOOBAR() { }
+
 namespace mlir {
 namespace triton {
 namespace gpu {
@@ -62,23 +64,18 @@ namespace gpu {
 #include "Passes.h.inc"
 
 namespace {
-
-class HelloExtension : public OpRewritePattern<DotOp> {
-public:
+struct HelloExtension : public OpRewritePattern<DotOp> {
   using OpRewritePattern::OpRewritePattern;
-
   LogicalResult matchAndRewrite(DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
     return success();
   }
 };
-
 } // anonymous namespace
 
-struct HelloExtensionPass : public impl::TritonGPUHelloExtensionBase<HelloExtensionPass> {
-  void runOnOperation()
-  // override
-  {
+struct HelloExtensionPass :
+    public impl::TritonGPUHelloExtensionBase<HelloExtensionPass> {
+  void runOnOperation() override {
     // MLIRContext *context = &getContext();
     // ModuleOp m = getOperation();
     // RewritePatternSet decomposePatterns(context);
@@ -94,16 +91,12 @@ struct HelloExtensionPass : public impl::TritonGPUHelloExtensionBase<HelloExtensionPass>
 std::unique_ptr<::mlir::Pass> {
 return mlir::triton::gpu::createTritonGPUHelloExtension();
 });
 }
 
-inline void registerPasses() {
-  registerStandaloneSwitchBarFoo();
-}
-
 /// Pass plugin registration mechanism.
 /// Necessary symbol to register the pass plugin.

From 71c0289bdadba422a49fb8bb83e4c381312db843 Mon Sep 17 00:00:00 2001
From: Puyan Lotfi
Date: Fri, 29 Aug 2025 14:22:51 -0700
Subject: [PATCH 10/11] Switch to a bring your own compiler.py setup

---
 third_party/nvidia/backend/compiler.py | 213 +++++++++----------------
 1 file changed, 73 insertions(+), 140 deletions(-)

diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 67ae30967dd6..5a6c81b8a4a6 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -244,154 +244,87 @@ def make_ttir(mod, metadata, opt, capability):
         return mod
 
     @staticmethod
-    def make_ttgir_pass_config(num_warps, num_ctas, num_stages, capability, cluster_dims, dump_enabled):
+    def make_ttgir(mod, metadata, opt, capability):
+        # Set maxnreg on all kernels, if it was provided. 
+ if opt.maxnreg is not None: + mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg)) + cluster_info = nvidia.ClusterInfo() - if cluster_dims is not None: - cluster_info.clusterDimX = cluster_dims[0] - cluster_info.clusterDimY = cluster_dims[1] - cluster_info.clusterDimZ = cluster_dims[2] - - pass_config_table = dict() - pass_config_table["ttir.add_convert_to_ttgpuir"] = [passes.ttir.add_convert_to_ttgpuir, [f"cuda:{capability}", num_warps, 32, num_ctas]] - pass_config_table["ttgpuir.add_coalesce"] = [passes.ttgpuir.add_coalesce, []] - pass_config_table["ttgpuir.add_f32_dot_tc"] = [passes.ttgpuir.add_f32_dot_tc, []] - pass_config_table["ttnvgpuir.add_plan_cta"] = [nvidia.passes.ttnvgpuir.add_plan_cta, [cluster_info]] - pass_config_table["ttgpuir.add_remove_layout_conversions"] = [passes.ttgpuir.add_remove_layout_conversions, []] - pass_config_table["ttgpuir.add_optimize_thread_locality"] = [passes.ttgpuir.add_optimize_thread_locality, []] - pass_config_table["ttgpuir.add_accelerate_matmul"] = [passes.ttgpuir.add_accelerate_matmul, []] - pass_config_table["ttgpuir.add_optimize_dot_operands"] = [passes.ttgpuir.add_optimize_dot_operands, [capability >= 80]] - pass_config_table["ttnvgpuir.add_optimize_descriptor_encoding"] = [nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding, []] - pass_config_table["ttir.add_loop_aware_cse"] = [passes.ttir.add_loop_aware_cse, []] - pass_config_table["ttgpuir.add_fuse_nested_loops"] = [passes.ttgpuir.add_fuse_nested_loops, []] - pass_config_table["common.add_canonicalizer"] = [passes.common.add_canonicalizer, []] - pass_config_table["ttir.add_triton_licm"] = [passes.ttir.add_triton_licm, []] - pass_config_table["ttgpuir.add_combine_tensor_select_and_if"] = [passes.ttgpuir.add_combine_tensor_select_and_if, []] - pass_config_table["nvidia.hopper.add_hopper_warpspec"] = [nvidia.passes.hopper.add_hopper_warpspec, [num_stages, dump_enabled]] - pass_config_table["ttgpuir.add_assign_latencies"] = [passes.ttgpuir.add_assign_latencies, [num_stages]] - pass_config_table["ttgpuir.add_schedule_loops"] = [passes.ttgpuir.add_schedule_loops, []] - pass_config_table["ttgpuir.add_pipeline"] = [passes.ttgpuir.add_pipeline, [num_stages, dump_enabled]] - pass_config_table["ttgpuir.add_optimize_accumulator_init"] = [passes.ttgpuir.add_optimize_accumulator_init, []] - pass_config_table["ttgpuir.add_hoist_tmem_alloc{False}"] = [passes.ttgpuir.add_hoist_tmem_alloc, [False]] # TODO: This is the one case when the params for a pass differ in two different instances - pass_config_table["ttnvgpuir.add_promote_lhs_to_tmem"] = [nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem, []] - pass_config_table["ttgpuir.add_warp_specialize"] = [passes.ttgpuir.add_warp_specialize, [num_stages]] - pass_config_table["ttgpuir.add_hoist_tmem_alloc{True}"] = [passes.ttgpuir.add_hoist_tmem_alloc, [True]] # TODO: This is the one case when the params for a pass differ in two different instances - pass_config_table["ttnvgpuir.add_remove_tmem_tokens"] = [nvidia.passes.ttnvgpuir.add_remove_tmem_tokens, []] - pass_config_table["ttgpuir.add_prefetch"] = [passes.ttgpuir.add_prefetch, []] - pass_config_table["ttgpuir.add_coalesce_async_copy"] = [passes.ttgpuir.add_coalesce_async_copy, []] - pass_config_table["ttnvgpuir.add_optimize_tmem_layouts"] = [nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts, []] - pass_config_table["ttnvgpuir.add_interleave_tmem"] = [nvidia.passes.ttnvgpuir.add_interleave_tmem, []] - pass_config_table["ttgpuir.add_reduce_data_duplication"] = 
[passes.ttgpuir.add_reduce_data_duplication, []] - pass_config_table["ttgpuir.add_reorder_instructions"] = [passes.ttgpuir.add_reorder_instructions, []] - pass_config_table["common.add_symbol_dce"] = [passes.common.add_symbol_dce, []] - pass_config_table["ttnvgpuir.add_tma_lowering"] = [nvidia.passes.ttnvgpuir.add_tma_lowering, []] - pass_config_table["ttnvgpuir.add_fence_insertion"] = [nvidia.passes.ttnvgpuir.add_fence_insertion, [capability]] - pass_config_table["ttnvgpuir.add_lower_mma"] = [nvidia.passes.ttnvgpuir.add_lower_mma, []] - pass_config_table["common.add_cse"] = [passes.common.add_cse, []] - pass_config_table["common.add_sccp"] = [passes.common.add_sccp, []] - - pass_config = list() - pass_config.append(pass_config_table["ttir.add_convert_to_ttgpuir"]) - # optimize TTGIR - pass_config.append(pass_config_table["ttgpuir.add_coalesce"]) + if opt.cluster_dims is not None: + cluster_info.clusterDimX = opt.cluster_dims[0] + cluster_info.clusterDimY = opt.cluster_dims[1] + cluster_info.clusterDimZ = opt.cluster_dims[2] + pm = ir.pass_manager(mod.context) + dump_enabled = pm.enable_debug() - if capability // 10 >= 8: - pass_config.append(pass_config_table["ttgpuir.add_f32_dot_tc"]) + if doBYOPassSetup: + byo_make_ttgiir(mod, metadata, opt, capability, passes, nvidia) + pm.run(mod, '.make_ttgir.repro.mlir') + metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) + tensordesc_meta = mod.get_tensordesc_metadata() + metadata["tensordesc_meta"] = tensordesc_meta + return mod + passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) + # optimize TTGIR + passes.ttgpuir.add_coalesce(pm) + if capability // 10 >= 8: + passes.ttgpuir.add_f32_dot_tc(pm) # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass - pass_config.append(pass_config_table["ttnvgpuir.add_plan_cta"]) - pass_config.append(pass_config_table["ttgpuir.add_remove_layout_conversions"]) - pass_config.append(pass_config_table["ttgpuir.add_optimize_thread_locality"]) - pass_config.append(pass_config_table["ttgpuir.add_accelerate_matmul"]) - pass_config.append(pass_config_table["ttgpuir.add_remove_layout_conversions"]) - pass_config.append(pass_config_table["ttgpuir.add_optimize_dot_operands"]) - pass_config.append(pass_config_table["ttnvgpuir.add_optimize_descriptor_encoding"]) - pass_config.append(pass_config_table["ttir.add_loop_aware_cse"]) - + nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + passes.ttgpuir.add_accelerate_matmul(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) + passes.ttir.add_loop_aware_cse(pm) if capability // 10 in [8, 9]: - pass_config.append(pass_config_table["ttgpuir.add_fuse_nested_loops"]) - pass_config.append(pass_config_table["common.add_canonicalizer"]) - pass_config.append(pass_config_table["ttir.add_triton_licm"]) - pass_config.append(pass_config_table["common.add_canonicalizer"]) - pass_config.append(pass_config_table["ttgpuir.add_combine_tensor_select_and_if"]) - pass_config.append(pass_config_table["nvidia.hopper.add_hopper_warpspec"]) - pass_config.append(pass_config_table["ttgpuir.add_assign_latencies"]) - pass_config.append(pass_config_table["ttgpuir.add_schedule_loops"]) - pass_config.append(pass_config_table["ttgpuir.add_pipeline"]) + 
passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_triton_licm(pm) + passes.common.add_canonicalizer(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled) + passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) + passes.ttgpuir.add_schedule_loops(pm) + passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) elif capability // 10 >= 10: - pass_config.append(pass_config_table["ttgpuir.add_fuse_nested_loops"]) - pass_config.append(pass_config_table["common.add_canonicalizer"]) - pass_config.append(pass_config_table["ttir.add_triton_licm"]) - pass_config.append(pass_config_table["ttgpuir.add_optimize_accumulator_init"]) - pass_config.append(pass_config_table["ttgpuir.add_hoist_tmem_alloc"]) - pass_config.append(pass_config_table["ttnvgpuir.add_promote_lhs_to_tmem"]) - pass_config.append(pass_config_table["ttgpuir.add_assign_latencies"]) - pass_config.append(pass_config_table["ttgpuir.add_schedule_loops"]) - pass_config.append(pass_config_table["ttgpuir.add_warp_specialize"]) - pass_config.append(pass_config_table["ttgpuir.add_pipeline"]) - pass_config.append(pass_config_table["ttgpuir.add_combine_tensor_select_and_if"]) + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_triton_licm(pm) + passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.ttgpuir.add_hoist_tmem_alloc(pm, False) + nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm) + passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) + passes.ttgpuir.add_schedule_loops(pm) + passes.ttgpuir.add_warp_specialize(pm, opt.num_stages) + passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) # hoist again and allow hoisting out of if statements - pass_config.append(pass_config_table["ttgpuir.add_hoist_tmem_alloc"]) - pass_config.append(pass_config_table["ttnvgpuir.add_remove_tmem_tokens"]) + passes.ttgpuir.add_hoist_tmem_alloc(pm, True) + nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm) else: - pass_config.append(pass_config_table["ttir.add_triton_licm"]) - pass_config.append(pass_config_table["common.add_canonicalizer"]) - pass_config.append(pass_config_table["ttir.add_loop_aware_cse"]) - pass_config.append(pass_config_table["ttgpuir.add_prefetch"]) - pass_config.append(pass_config_table["ttgpuir.add_optimize_dot_operands"]) - pass_config.append(pass_config_table["ttgpuir.add_coalesce_async_copy"]) - pass_config.append(pass_config_table["ttnvgpuir.add_optimize_tmem_layouts"]) - pass_config.append(pass_config_table["ttgpuir.add_remove_layout_conversions"]) - pass_config.append(pass_config_table["ttnvgpuir.add_interleave_tmem"]) - pass_config.append(pass_config_table["ttgpuir.add_reduce_data_duplication"]) - pass_config.append(pass_config_table["ttgpuir.add_reorder_instructions"]) - pass_config.append(pass_config_table["ttir.add_loop_aware_cse"]) - pass_config.append(pass_config_table["common.add_symbol_dce"]) + passes.ttir.add_triton_licm(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.ttgpuir.add_coalesce_async_copy(pm) + nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + nvidia.passes.ttnvgpuir.add_interleave_tmem(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + 
passes.ttgpuir.add_reorder_instructions(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.common.add_symbol_dce(pm) if capability // 10 >= 9: - pass_config.append(pass_config_table["ttnvgpuir.add_tma_lowering"]) - pass_config.append(pass_config_table["ttnvgpuir.add_fence_insertion"]) - pass_config.append(pass_config_table["ttnvgpuir.add_lower_mma"]) - pass_config.append(pass_config_table["common.add_sccp"]) - pass_config.append(pass_config_table["common.add_cse"]) - pass_config.append(pass_config_table["common.add_canonicalizer"]) - - return [pass_config_table, pass_config, cluster_info] - - @staticmethod - def make_ttgir(mod, metadata, opt, capability, pass_config, global_config): - # Set maxnreg on all kernels, if it was provided. - if opt.maxnreg is not None: - mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg)) - pm = ir.pass_manager(mod.context) - dump_enabled = pm.enable_debug() - - pass_entries = pass_config[0] - pass_list = pass_config[1] - cluster_info = pass_config[2] - - new_pass_list = list() - if(global_config == None): - new_pass_list = pass_list - else: - config_pass_entries = global_config["ttgir"]["passes"] - for e in config_pass_entries: - if(e in list(pass_entries.keys())): - new_pass_list.append([pass_entries[e][0], pass_entries[e][1]]) - - for e in new_pass_list: - p = e[0] - args = e[1] - if len(args) == 0: - p(pm) - elif len(args) == 1: - p(pm, args[0]) - elif len(args) == 2: - p(pm, args[0], args[1]) - elif len(args) == 3: - p(pm, args[0], args[1], args[2]) - elif len(args) == 4: - p(pm, args[0], args[1], args[2], args[3]) - else: - raise Exception("Bad Arg Count") + nvidia.passes.ttnvgpuir.add_tma_lowering(pm) + nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability) + nvidia.passes.ttnvgpuir.add_lower_mma(pm) + passes.common.add_sccp(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) pm.run(mod, '.make_ttgir.repro.mlir') metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) @@ -596,7 +529,7 @@ def add_stages(self, stages, options, language): global_config = json.load(f) capability = self._parse_arch(options.arch) if language == Language.TRITON: - ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False) + # ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False) stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability) stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability, ttgir_pass_config, global_config) elif language == Language.GLUON: From 4f92836d01e25655b5c745a74e57829c57865268 Mon Sep 17 00:00:00 2001 From: Puyan Lotfi Date: Fri, 29 Aug 2025 14:42:57 -0700 Subject: [PATCH 11/11] Bring your own compiler.py --- byo_compiler.py | 65 ++++++++++++++++++++++++++ third_party/nvidia/backend/compiler.py | 16 +++++-- 2 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 byo_compiler.py diff --git a/byo_compiler.py b/byo_compiler.py new file mode 100644 index 000000000000..83467a013885 --- /dev/null +++ b/byo_compiler.py @@ -0,0 +1,65 @@ + +@staticmethod +def byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info, dump_enabled, passes, nvidia): + + passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas) + # optimize TTGIR + 
passes.ttgpuir.add_coalesce(pm) + if capability // 10 >= 8: + passes.ttgpuir.add_f32_dot_tc(pm) + # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass + nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_thread_locality(pm) + passes.ttgpuir.add_accelerate_matmul(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm) + passes.ttir.add_loop_aware_cse(pm) + if capability // 10 in [8, 9]: + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_triton_licm(pm) + passes.common.add_canonicalizer(pm) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled) + passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) + passes.ttgpuir.add_schedule_loops(pm) + passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + elif capability // 10 >= 10: + passes.ttgpuir.add_fuse_nested_loops(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_triton_licm(pm) + passes.ttgpuir.add_optimize_accumulator_init(pm) + passes.ttgpuir.add_hoist_tmem_alloc(pm, False) + nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm) + passes.ttgpuir.add_assign_latencies(pm, opt.num_stages) + passes.ttgpuir.add_schedule_loops(pm) + passes.ttgpuir.add_warp_specialize(pm, opt.num_stages) + passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled) + passes.ttgpuir.add_combine_tensor_select_and_if(pm) + # hoist again and allow hoisting out of if statements + passes.ttgpuir.add_hoist_tmem_alloc(pm, True) + nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm) + else: + passes.ttir.add_triton_licm(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.ttgpuir.add_prefetch(pm) + passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) + passes.ttgpuir.add_coalesce_async_copy(pm) + nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm) + passes.ttgpuir.add_remove_layout_conversions(pm) + nvidia.passes.ttnvgpuir.add_interleave_tmem(pm) + passes.ttgpuir.add_reduce_data_duplication(pm) + passes.ttgpuir.add_reorder_instructions(pm) + passes.ttir.add_loop_aware_cse(pm) + passes.common.add_symbol_dce(pm) + if capability // 10 >= 9: + nvidia.passes.ttnvgpuir.add_tma_lowering(pm) + nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability) + nvidia.passes.ttnvgpuir.add_lower_mma(pm) + passes.common.add_sccp(pm) + passes.common.add_cse(pm) + passes.common.add_canonicalizer(pm) + diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 5a6c81b8a4a6..72560c8bdf72 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -16,6 +16,9 @@ from pathlib import Path import json +from importlib.util import spec_from_file_location, module_from_spec +import sys + def min_dot_size(target: GPUTarget): @@ -257,8 +260,15 @@ def make_ttgir(mod, metadata, opt, capability): pm = ir.pass_manager(mod.context) dump_enabled = pm.enable_debug() - if doBYOPassSetup: - byo_make_ttgiir(mod, metadata, opt, capability, passes, nvidia) + file_path = '/home/plotfi/opt/dev/Triton-MetaGPU-Clean/triton/byo_compiler.py' + module_name = 'byo_compiler_setup' + + spec = spec_from_file_location(module_name, file_path) + if spec: + module = module_from_spec(spec) + sys.modules[module_name] = module # Add to sys.modules if 
you want it discoverable + spec.loader.exec_module(module) + module.byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info, dump_enabled, passes, nvidia) pm.run(mod, '.make_ttgir.repro.mlir') metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ) tensordesc_meta = mod.get_tensordesc_metadata() @@ -531,7 +541,7 @@ def add_stages(self, stages, options, language): if language == Language.TRITON: # ttgir_pass_config = self.make_ttgir_pass_config(options.num_warps, options.num_ctas, options.num_stages, capability, options.cluster_dims, dump_enabled=False) stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability, ttgir_pass_config, global_config) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) elif language == Language.GLUON: stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
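
For reference, below is a minimal, self-contained sketch of the importlib-based "bring your own compiler.py" loading pattern that patch 11/11 introduces. The file path, module name, and hook name are placeholders for illustration; the patch itself hard-codes a local byo_compiler.py path and calls its byo_make_ttgir() hook directly.

import sys
from importlib.util import spec_from_file_location, module_from_spec


def load_byo_module(file_path, module_name="byo_compiler_setup"):
    """Load an arbitrary Python file as a module and return it, or None on failure."""
    spec = spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        return None
    module = module_from_spec(spec)
    sys.modules[module_name] = module  # register so later imports can find it by name
    spec.loader.exec_module(module)
    return module


# Usage sketch: invoke the user-supplied pass-pipeline hook if the file defines one.
# byo = load_byo_module("/path/to/byo_compiler.py")
# if byo is not None and hasattr(byo, "byo_make_ttgir"):
#     byo.byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info,
#                        dump_enabled, passes, nvidia)

Registering the loaded module in sys.modules is optional; it only matters if other code should later be able to import the user-supplied module by name.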