Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions byo_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

@staticmethod
def byo_make_ttgir(pm, mod, metadata, opt, capability, cluster_info, dump_enabled, passes, nvidia):

passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
# optimize TTGIR
passes.ttgpuir.add_coalesce(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_f32_dot_tc(pm)
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
passes.ttir.add_loop_aware_cse(pm)
if capability // 10 in [8, 9]:
passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
elif capability // 10 >= 10:
passes.ttgpuir.add_fuse_nested_loops(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_triton_licm(pm)
passes.ttgpuir.add_optimize_accumulator_init(pm)
passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
passes.ttgpuir.add_schedule_loops(pm)
passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
# hoist again and allow hoisting out of if statements
passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
else:
passes.ttir.add_triton_licm(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_loop_aware_cse(pm)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_coalesce_async_copy(pm)
nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
passes.ttir.add_loop_aware_cse(pm)
passes.common.add_symbol_dce(pm)
if capability // 10 >= 9:
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.common.add_sccp(pm)
passes.common.add_cse(pm)
passes.common.add_canonicalizer(pm)

57 changes: 57 additions & 0 deletions default.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"ttgir" : {
"passes" : [
"ttir.add_convert_to_ttgpuir",
"ttgpuir.add_coalesce",
"ttgpuir.add_f32_dot_tc",
"ttnvgpuir.add_plan_cta",
"ttgpuir.add_remove_layout_conversions",
"ttgpuir.add_optimize_thread_locality",
"ttgpuir.add_accelerate_matmul",
"ttgpuir.add_remove_layout_conversions",
"ttgpuir.add_optimize_dot_operands",
"ttnvgpuir.add_optimize_descriptor_encoding",
"ttir.add_loop_aware_cse",
"ttgpuir.add_fuse_nested_loops",
"common.add_canonicalizer",
"ttir.add_triton_licm",
"common.add_canonicalizer",
"ttgpuir.add_combine_tensor_select_and_if",
"nvidia.hopper.add_hopper_warpspec",
"ttgpuir.add_assign_latencies",
"ttgpuir.add_schedule_loops",
"ttgpuir.add_pipeline",
"ttgpuir.add_fuse_nested_loops",
"common.add_canonicalizer",
"ttir.add_triton_licm",
"ttgpuir.add_optimize_accumulator_init",
"ttgpuir.add_hoist_tmem_alloc{False}",
"ttnvgpuir.add_promote_lhs_to_tmem",
"ttgpuir.add_assign_latencies",
"ttgpuir.add_schedule_loops",
"ttgpuir.add_warp_specialize",
"ttgpuir.add_pipeline",
"ttgpuir.add_combine_tensor_select_and_if",
"ttgpuir.add_hoist_tmem_alloc{True}",
"ttnvgpuir.add_remove_tmem_tokens",
"ttir.add_triton_licm",
"common.add_canonicalizer",
"ttir.add_loop_aware_cse",
"ttgpuir.add_prefetch",
"ttgpuir.add_optimize_dot_operands",
"ttgpuir.add_coalesce_async_copy",
"ttnvgpuir.add_optimize_tmem_layouts",
"ttgpuir.add_remove_layout_conversions",
"ttnvgpuir.add_interleave_tmem",
"ttgpuir.add_reduce_data_duplication",
"ttgpuir.add_reorder_instructions",
"ttir.add_loop_aware_cse",
"common.add_symbol_dce",
"ttnvgpuir.add_tma_lowering",
"ttnvgpuir.add_fence_insertion",
"ttnvgpuir.add_lower_mma",
"common.add_sccp",
"common.add_cse",
"common.add_canonicalizer"
]
}
}
2 changes: 2 additions & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"LLVM_IR_ENABLE_DUMP",
"LLVM_ENABLE_TIMING",
"LLVM_PASS_PLUGIN_PATH",
"PASS_MANAGER_CONFIG_PATH",
"MLIR_ENABLE_DIAGNOSTICS",
"MLIR_ENABLE_DUMP",
"MLIR_DUMP_PATH",
Expand All @@ -44,6 +45,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
"TRITON_F32_DEFAULT",
"TRITON_PREFER_TMEM_16x256_LAYOUT",
"MLIR_PASS_PLUGIN_PATH",
// clang-format on
};

Expand Down
9 changes: 8 additions & 1 deletion python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1816,7 +1816,7 @@ void init_triton_ir(py::module &&m) {
})
.def(
"run",
[](PassManager &self, ModuleOp &mod) {
[](PassManager &self, ModuleOp &mod, std::string repro_suffix) {
// TODO: maybe dump module to file and print error for better
// diagnostics

Expand Down Expand Up @@ -1900,7 +1900,14 @@ void init_triton_ir(py::module &&m) {
}
if (failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");

if (!repro_suffix.empty() && !reproducerPath.empty() && reproducerPath != "-" &&
llvm::sys::fs::copy_file(reproducerPath, reproducerPath + repro_suffix)) {
throw std::runtime_error("PassManager::run failed (repro temp)");
}
},
py::arg("mod"),
py::arg("repro_suffix") = "",
py::call_guard<py::gil_scoped_release>());
}

Expand Down
16 changes: 16 additions & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonInstrument/Transforms/Passes.h"
#include "triton/Target/LLVMIR/Passes.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -92,6 +94,20 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPUCoalesceAsyncCopy);
ADD_PASS_WRAPPER_0("add_concurrency_sanitizer",
createTritonInstrumentConcurrencySanitizer);

std::string pluginFile =
mlir::triton::tools::getStrEnv("MLIR_PASS_PLUGIN_PATH");

if (!pluginFile.empty()) {
auto plugin = mlir::PassPlugin::load(pluginFile);
if (!plugin) {
llvm::Error Err = plugin.takeError();
std::string ErrMsg =
"Pass Plugin Error: " + llvm::toString(std::move(Err));
throw std::runtime_error(ErrMsg);
}
plugin.get().registerPassRegistryCallbacks();
}
}

void init_triton_passes_convert(py::module &&m) {
Expand Down
8 changes: 8 additions & 0 deletions python/test/unit/language/test_reproducer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from triton._internal_testing import is_cuda, is_hip
import triton
import re
import os


def test_triton_reproducer_path(monkeypatch, tmp_path):
Expand All @@ -21,6 +23,12 @@ def triton_():
# matter what the kernel does, just that the PassManager runs its passes.
triton_[(1, )]()

if is_cuda() and not is_hip():
base_path = str(repro_path)
assert os.path.exists(base_path + '.make_ttir.repro.mlir')
assert os.path.exists(base_path + '.make_ttgir.repro.mlir')
assert os.path.exists(base_path + '.make_llir.repro.mlir')

repro = repro_path.read_text()
assert "mlir_reproducer" in repro, f"Expected MLIR reproducer in {repro_path}. Got:\n{repro}"
m = re.search(r"pipeline: \"(.*)\"", repro)
Expand Down
1 change: 1 addition & 0 deletions test/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(Instrumentation)
add_subdirectory(Proton)
add_subdirectory(Extensions)
39 changes: 39 additions & 0 deletions test/lib/Extensions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Builds each test pass plugin as a standalone shared library that can be
# loaded at runtime (e.g. via the MLIR_PASS_PLUGIN_PATH environment variable).
set(LLVM_TARGET_DEFINITIONS Passes.td)
# Generate pass declarations (Passes.h.inc) from Passes.td.
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Extensions)
add_public_tablegen_target(TritonGPUExtensionIncGen)

# Names of the plugin libraries to build; add new plugins to this list.
set(GPU_EXTENSION_PASSES
GPUExtensionTestLib
)

# Source files for each plugin, keyed as <plugin>_SOURCES.
set(GPUExtensionTestLib_SOURCES
ExtensionHello.cpp
)

# MODULE

foreach( plugin ${GPU_EXTENSION_PASSES} )
add_mlir_library(${plugin}
${${plugin}_SOURCES}
SHARED

ADDITIONAL_HEADER_DIRS
${PROJECT_BINARY_DIR}/test/lib/Extensions

DEPENDS
TritonGPUExtensionIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRInferTypeOpInterface
MLIRFuncDialect
)
# Keep the built plugin next to the other Extensions test artifacts so tests
# can locate it with a stable, predictable path.
set_target_properties(${plugin} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
${PROJECT_BINARY_DIR}/test/lib/Extensions)
# This is set to -fvisibility=hidden in the top level CMake file
# which causes the llvmGetPassPluginInfo symbol to be hidden and
# an "entry point not found" error. Reset it just for this target
if(NOT MSVC)
target_compile_options(${plugin} PRIVATE -fvisibility=default)
endif()
endforeach()
106 changes: 106 additions & 0 deletions test/lib/Extensions/ExtensionHello.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <deque>

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"

#include "mlir/Tools/Plugins/PassPlugin.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Compiler.h"

using namespace mlir;

/// Dialect plugin registration mechanism.
/// Observe that it also allows to register passes.
/// Necessary symbol to register the dialect plugin.
// extern "C" LLVM_ATTRIBUTE_WEAK DialectPluginLibraryInfo
// mlirGetDialectPluginInfo() {
// return {MLIR_PLUGIN_API_VERSION, "Standalone", LLVM_VERSION_STRING,
// [](DialectRegistry *registry) {
// registry->insert<mlir::standalone::StandaloneDialect>();
// mlir::standalone::registerPasses();
// }};
// }

#include "mlir/Pass/Pass.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

extern "C" void FOOBAR() { }

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUHELLOEXTENSION
#include "Passes.h.inc"

namespace {
/// Example rewrite pattern anchored on triton DotOp. It intentionally performs
/// no transformation. A pattern that does not mutate the IR must report
/// failure(): returning success() without a rewrite makes the greedy driver
/// believe progress was made, which loops forever (and asserts in debug
/// builds). Currently unused — registration in the pass is commented out.
struct HelloExtension : public OpRewritePattern<DotOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(DotOp dotOp,
                                PatternRewriter &rewriter) const override {
    // No rewrite performed, so signal "did not match" to the driver.
    return failure();
  }
};
} // anonymous namespace

// Skeleton plugin pass: currently a deliberate no-op whose only purpose is to
// prove that an externally-loaded plugin pass can be registered and run. The
// commented-out body shows how a greedy pattern rewrite would be wired in.
struct HelloExtensionPass :
public impl::TritonGPUHelloExtensionBase<HelloExtensionPass> {
void runOnOperation() override {
// MLIRContext *context = &getContext();
// ModuleOp m = getOperation();
// RewritePatternSet decomposePatterns(context);
// decomposePatterns.add<HelloExtension>(context);
// if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
// signalPassFailure();
// }
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir


/// Registers the HelloExtension pass with the global MLIR pass registry so it
/// becomes available once the plugin is loaded.
void registerPasses() {
  auto allocator = []() -> std::unique_ptr<::mlir::Pass> {
    return mlir::triton::gpu::createTritonGPUHelloExtension();
  };
  ::mlir::registerPass(allocator);
}


/// Pass plugin registration mechanism.
/// Necessary symbol to register the pass plugin.
/// MLIR's plugin loader looks this symbol up by name in the shared library;
/// the weak attribute lets hosts probe for it without a hard link dependency.
/// The returned callback runs registerPasses() when the plugin is initialized.
extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo mlirGetPassPluginInfo() {
return {MLIR_PLUGIN_API_VERSION, "HelloExtensionPlugin", LLVM_VERSION_STRING,
[]() { registerPasses(); }};
}
10 changes: 10 additions & 0 deletions test/lib/Extensions/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef TRITONGPU_EXTENSION_PASSES
#define TRITONGPU_EXTENSION_PASSES

include "mlir/Pass/PassBase.td"

// Minimal pass definition used to exercise the MLIR pass-plugin loading path.
// NOTE(review): the CLI flag "tritongpu-HelloExtension" mixes cases; MLIR
// convention is lowercase-with-dashes. The flag is part of the plugin's public
// interface, so confirm with users before renaming it.
def TritonGPUHelloExtension : Pass<"tritongpu-HelloExtension", "mlir::ModuleOp"> {
let summary = "Hello World Extension";
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif
Loading