Skip to content

Commit f06e43d

Browse files
committed
add tptr-to-llvm pass
1 parent 920d006 commit f06e43d

File tree

14 files changed

+608
-8
lines changed

14 files changed

+608
-8
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ if(TRITON_SHARED_BUILD_CPU_BACKEND)
4040
MLIRControlFlowToLLVM
4141
MLIRMemRefToLLVM
4242
MLIRVectorToSCF
43+
TPtrToLLVM
4344
MLIRUBToLLVM
4445
# MLIRMathToLibm
4546

backend/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def _ttsharedir_to_llir(ttsharedir: str):
100100
triton_shared.to_llir.add_memref_expand(pm)
101101
triton_shared.to_llir.add_finalize_memref_to_llvm(pm)
102102
triton_shared.to_llir.add_convert_func_to_llvm(pm)
103+
# triton_shared.debug.enable_mlir_debug("tptr-to-llvm")
104+
triton_shared.to_llir.add_convert_tptr_to_llvm(pm)
103105

104106
triton_shared.to_llir.add_convert_cf_to_llvm(pm)
105107
triton_shared.to_llir.add_lower_affine(pm)

include/triton-shared/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ add_subdirectory(TritonPtrToMemref)
66
add_subdirectory(TritonToUnstructured)
77
add_subdirectory(StructuredToMemref)
88
add_subdirectory(UnstructuredToMemref)
9+
add_subdirectory(TPtrToLLVM)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TPtrToLLVM)
3+
add_public_tablegen_target(TPtrToLLVMConversionPassIncGen)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TPTR_TO_LLVM_CONVERSION_PASSES_H
2+
#define TPTR_TO_LLVM_CONVERSION_PASSES_H
3+
4+
#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"
5+
6+
namespace mlir {
7+
namespace tptr {
8+
9+
#define GEN_PASS_REGISTRATION
10+
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"
11+
12+
} // namespace triton
13+
} // namespace mlir
14+
15+
#endif
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef TPTR_TO_LLVM_CONVERSION_PASSES
2+
#define TPTR_TO_LLVM_CONVERSION_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def TPtrToLLVM : Pass<"tptr-to-llvm", "mlir::ModuleOp"> {
7+
let summary = "Convert Tptr operations into LLVM";
8+
let options = [
9+
// Option<"addptrToLLVM", "ptradd-to-llvm", "bool", /*default*/"true",
10+
// "Convert tptr.ptradd on tensors to llvm">,
11+
// Option<"fromMemrefToLLVM", "from-memref-to-llvm", "bool", /*default*/"true",
12+
// "Convert tptr.from_memref on tensors to llvm">,
13+
// Option<"toMemrefToLLVM", "to-memref-to-llvm", "bool", /*default*/"true",
14+
// "Convert tptr.to_memref on tensors to llvm">,
15+
// Option<"typeoffsetToConst", "typeoffset-to-const", "bool", /*default*/"true",
16+
// "Convert tptr.typeoffset to llvm.mlir.const">,
17+
];
18+
let dependentDialects = ["mlir::tptr::TPtrDialect", "mlir::LLVM::LLVMDialect"];
19+
}
20+
#endif
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H
2+
#define TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H
3+
4+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "mlir/Transforms/DialectConversion.h"
7+
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
8+
9+
namespace mlir {
10+
namespace tptr {
11+
12+
#define GEN_PASS_DECL
13+
// #define GEN_PASS_CLASSES
14+
// #define GEN_PASS_DECL_TPTRTOLLVM
15+
// #define GEN_PASS_DEF_TPTRTOLLVM
16+
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"
17+
18+
// void populateTPtrToLLVMCanonicalizationPatterns(RewritePatternSet &patterns);
19+
20+
void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
21+
TypeConverter &typeconverter);
22+
23+
std::unique_ptr<OperationPass<ModuleOp>> createTPtrToLLVMPass();
24+
25+
static bool isOneToOneCast(UnrealizedConversionCastOp op) {
26+
return (op.getInputs().size() == 1 && op->getNumResults() == 1);
27+
}
28+
29+
} // namespace tptr
30+
} // namespace mlir
31+
32+
#endif

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ add_subdirectory(TritonArithToLinalg)
66
add_subdirectory(StructuredToMemref)
77
add_subdirectory(TritonPtrToMemref)
88
add_subdirectory(UnstructuredToMemref)
9+
add_subdirectory(TPtrToLLVM)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_triton_library(TPtrToLLVM
2+
TPtrToLLVMPass.cpp
3+
TPtrToLLVM.cpp
4+
5+
DEPENDS
6+
TPtrToLLVMConversionPassIncGen
7+
TPtrTableGen
8+
9+
LINK_LIBS PUBLIC
10+
MLIRIR
11+
MLIRPass
12+
MLIRTransforms
13+
MLIRSupport
14+
MLIRReconcileUnrealizedCasts
15+
TPtrIR
16+
MLIRDialectUtils
17+
)

0 commit comments

Comments
 (0)