Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ build/*
.vscode/*
/python/examples/test_core.py
/python/examples/test_annotations.py

.clangd/*
19 changes: 15 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,28 @@ add_subdirectory(lib)
add_subdirectory(test)
add_subdirectory(tools/triton-shared-opt)

if (TRITON_SHARED_BUILD_CPU_BACKEND)
if(TRITON_SHARED_BUILD_CPU_BACKEND)
add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonToLinalg TritonTilingExtIR)
target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers)
target_link_libraries(TritonShared
PRIVATE

MLIRRegisterAllDialects
TPtrToLLVM

Python3::Module
pybind11::headers
)
endif()

# Add symlinks to selected pytest files in triton. The tests are imported into triton-shared’s test folder to
# run under triton-shared's conftest configuration.
# Add symlinks to selected pytest files and the clang-format setting in triton. The tests are imported into triton-shared’s test folder to
# run under triton-shared's conftest configuration, and the clang-format link ensures consistent code style enforcement across both repositories.
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_core.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_CORE)
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_annotations.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_ANNOTATIONS)
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_SHARED_CLANG_FORMAT_SETTING)
cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_core.py" OUTPUT_VARIABLE TRITON_TEST_CORE)
cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_annotations.py" OUTPUT_VARIABLE TRITON_TEST_ANNOTATIONS)
cmake_path(APPEND CMAKE_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_CLANG_FORMAT_SETTING)

add_symlink(${TRITON_TEST_CORE} ${TRITON_SHARED_TEST_CORE})
add_symlink(${TRITON_TEST_ANNOTATIONS} ${TRITON_SHARED_TEST_ANNOTATIONS})
add_symlink(${TRITON_CLANG_FORMAT_SETTING} ${TRITON_SHARED_CLANG_FORMAT_SETTING})
90 changes: 48 additions & 42 deletions backend/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes
from triton._C.libtriton import ir, passes, triton_shared
from dataclasses import dataclass
from typing import Any, Dict, Tuple
from types import ModuleType
Expand Down Expand Up @@ -42,7 +42,7 @@ def _get_sanitizer_type():
if sanitizer_type != "" and sanitizer_type != "asan" and sanitizer_type != "tsan":
# throw error
raise Exception(f"TRITON_SHARED_SANITIZER_TYPE {sanitizer_type} is invalid.")

return sanitizer_type

def _ttir_to_ttsharedir(mod):
Expand Down Expand Up @@ -78,41 +78,47 @@ def _ttsharedir_to_llir(ttsharedir: str):
llmlir_path = os.path.join(tmpdir, "ll.mlir")
llir_path = os.path.join(tmpdir, "ll.ir")
Path(ttshared_path).write_text(ttsharedir)
mlir_opt_path = _get_llvm_bin_path("mlir-opt")
context = ir.context()
triton_shared.ir.load_dialects(context)
mod = ir.parse_mlir_module(ttshared_path, context)
# TritonShared-MLIR to LLVM-MLIR
subprocess.check_call([mlir_opt_path, ttshared_path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we preserve all the original comments since they explain why we need certain passes.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not resolved, there are other comments here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't understand your meaning well before, I will add it back

"--convert-linalg-to-affine-loops",
# Note: eliminate-empty-tensors fails when there are multiple func.return ops
# in a single kernel which are the results of early returns.
# See python/examples/test_early_return.py for examples.
# We disable this pass for now since performance on CPU isn't the main
# focus at the moment.
# "--eliminate-empty-tensors",
"--empty-tensor-to-alloc-tensor",
"--one-shot-bufferize=allow-return-allocs-from-loops=true",
"--lower-affine",
"--convert-linalg-to-loops",
"--expand-strided-metadata",
"--convert-scf-to-cf",
"--convert-arith-to-llvm",
"--convert-math-to-llvm",
"--convert-complex-to-llvm",
"--convert-vector-to-llvm",
"--convert-index-to-llvm",
"--memref-expand",
"--finalize-memref-to-llvm",
"--convert-func-to-llvm",
"--convert-cf-to-llvm",
# Lowering memrefs creates more affine.apply ops.
# Lowering these affine ops again creates further arith ops,
# so we have to run these two passes again here.
"--lower-affine",
"--convert-arith-to-llvm",
# Remove all unrealized casts created
"--reconcile-unrealized-casts",
"--mlir-print-debuginfo",
"-o",
llmlir_path])

pm = ir.pass_manager(context)
pm.enable_debug()
triton_shared.to_llir.add_convert_linalg_to_affine_loops(pm)
# Note: eliminate-empty-tensors fails when there are multiple func.return ops
# in a single kernel which are the results of early returns.
# See python/examples/test_early_return.py for examples.
# We disable this pass for now since performance on CPU isn't the main
# focus at the moment.
# triton_shared.to_llir.add_eliminate_empty_tensors(pm)
triton_shared.to_llir.add_empty_tensor_to_alloc_tensor(pm)
triton_shared.to_llir.add_one_shot_bufferize_with_options(
pm, allow_return_allocs_from_loops=True)
triton_shared.to_llir.add_lower_affine(pm)
triton_shared.to_llir.add_convert_linalg_to_loops(pm)
triton_shared.to_llir.add_expand_strided_metadata(pm)
triton_shared.to_llir.add_convert_scf_to_cf(pm)
triton_shared.to_llir.add_convert_tptr_to_llvm(pm)
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
triton_shared.to_llir.add_convert_math_to_llvm(pm)
triton_shared.to_llir.add_convert_complex_to_llvm(pm)
triton_shared.to_llir.add_convert_vector_to_llvm(pm)
triton_shared.to_llir.add_convert_index_to_llvm(pm)
triton_shared.to_llir.add_memref_expand(pm)
triton_shared.to_llir.add_finalize_memref_to_llvm(pm)
triton_shared.to_llir.add_convert_func_to_llvm(pm)
triton_shared.to_llir.add_convert_cf_to_llvm(pm)
# Lowering memrefs creates more affine.apply ops.
# Lowering these affine ops again creates further arith ops,
# so we have to run these two passes again here.
triton_shared.to_llir.add_lower_affine(pm)
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
# Remove all unrealized casts created
triton_shared.to_llir.add_reconcile_unrealized_casts(pm)
pm.run(mod)

Path(llmlir_path).write_text(str(mod))

# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
Expand Down Expand Up @@ -145,16 +151,16 @@ def _llir_to_bin(llir: str, metadata):
# using a sanitizer
# invoke pass to append sanitizer attributes
instrumented_src_path = os.path.join(tmpdir, "kernel-instrumented.ll")

opt_path = _get_llvm_bin_path("opt")
top_level_triton_path = os.path.dirname(triton.__file__)
sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None))

if not sanitizer_attributes_pass_path:
raise Exception(f"libSanitizerAttributes.so does not exist.")

subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
"-o", instrumented_src_path])

# compile to object file
Expand All @@ -166,12 +172,12 @@ def _llir_to_bin(llir: str, metadata):
subprocess_args.extend(["-g", "-fsanitize=address", "-mllvm", "-asan-stack=0"])
elif sanitizer_type == "tsan":
subprocess_args.extend(["-g", "-fsanitize=thread"])

subprocess.check_call(subprocess_args)
else:
llc_path = _get_llvm_bin_path("llc")
subprocess.check_call([llc_path, src_path, "-filetype=obj", "-relocation-model=pic", "-o", dst_path])

return Path(dst_path).read_bytes()


Expand Down Expand Up @@ -265,11 +271,11 @@ def add_stages(self, stages, options, language):
stages["llir"] = lambda src, metadata: _optimize_llir(_ttsharedir_to_llir(src))
stages["obj"] = lambda src, metadata: _llir_to_bin(src, metadata)


@functools.lru_cache()
def hash(self):
return self.target

# The CPU backend does not use any extra python modules, return an empty dictionary
def get_module_map(self) -> Dict[str, ModuleType]:
return {}

1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ add_subdirectory(TritonPtrToMemref)
add_subdirectory(TritonToUnstructured)
add_subdirectory(StructuredToMemref)
add_subdirectory(UnstructuredToMemref)
add_subdirectory(TPtrToLLVM)
3 changes: 3 additions & 0 deletions include/triton-shared/Conversion/TPtrToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TPtrToLLVM)
add_public_tablegen_target(TPtrToLLVMConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TPtrToLLVM/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TPTR_TO_LLVM_CONVERSION_PASSES_H
#define TPTR_TO_LLVM_CONVERSION_PASSES_H

#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"

namespace mlir {
namespace tptr {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
10 changes: 10 additions & 0 deletions include/triton-shared/Conversion/TPtrToLLVM/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef TPTR_TO_LLVM_CONVERSION_PASSES
#define TPTR_TO_LLVM_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TPtrToLLVM : Pass<"tptr-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert Tptr operations into LLVM";
let dependentDialects = ["mlir::tptr::TPtrDialect", "mlir::LLVM::LLVMDialect", "mlir::ptr::PtrDialect"];
}
#endif
23 changes: 23 additions & 0 deletions include/triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H
#define TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"

namespace mlir {
namespace tptr {

#define GEN_PASS_DECL
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"

void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Put a new line between the include and function declaration.

TypeConverter &typeconverter);

std::unique_ptr<OperationPass<ModuleOp>> createTPtrToLLVMPass();

} // namespace tptr
} // namespace mlir

#endif
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
add_subdirectory(TritonPtrToMemref)
add_subdirectory(UnstructuredToMemref)
add_subdirectory(TPtrToLLVM)
16 changes: 16 additions & 0 deletions lib/Conversion/TPtrToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
add_triton_library(TPtrToLLVM
TPtrToLLVMPass.cpp
TPtrToLLVM.cpp

DEPENDS
TPtrToLLVMConversionPassIncGen
TPtrTableGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTransforms
MLIRSupport
MLIRDialectUtils
TPtrIR
)
Loading