diff --git a/.gitignore b/.gitignore index 544caa81..54728277 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ build/* .vscode/* /python/examples/test_core.py /python/examples/test_annotations.py + +.clangd/* \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ae68d38..35f23839 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,17 +22,28 @@ add_subdirectory(lib) add_subdirectory(test) add_subdirectory(tools/triton-shared-opt) -if (TRITON_SHARED_BUILD_CPU_BACKEND) +if(TRITON_SHARED_BUILD_CPU_BACKEND) add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonToLinalg TritonTilingExtIR) - target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers) + target_link_libraries(TritonShared + PRIVATE + + MLIRRegisterAllDialects + TPtrToLLVM + + Python3::Module + pybind11::headers + ) endif() -# Add symlinks to selected pytest files in triton. The tests are imported into triton-shared’s test folder to -# run under triton-shared's conftest configuration. +# Add symlinks to selected pytest files and the clang-format setting in triton. The tests are imported into triton-shared’s test folder to +# run under triton-shared's conftest configuration, and the clang-format link ensures consistent code style enforcement across both repositories. cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_core.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_CORE) cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_annotations.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_ANNOTATIONS) +cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_SHARED_CLANG_FORMAT_SETTING) cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_core.py" OUTPUT_VARIABLE TRITON_TEST_CORE) cmake_path(APPEND CMAKE_SOURCE_DIR "python" "test" "unit" "language" "test_annotations.py" OUTPUT_VARIABLE TRITON_TEST_ANNOTATIONS) +cmake_path(APPEND CMAKE_SOURCE_DIR ".clang-format" OUTPUT_VARIABLE TRITON_CLANG_FORMAT_SETTING) add_symlink(${TRITON_TEST_CORE} ${TRITON_SHARED_TEST_CORE}) add_symlink(${TRITON_TEST_ANNOTATIONS} ${TRITON_SHARED_TEST_ANNOTATIONS}) +add_symlink(${TRITON_CLANG_FORMAT_SETTING} ${TRITON_SHARED_CLANG_FORMAT_SETTING}) \ No newline at end of file diff --git a/backend/compiler.py b/backend/compiler.py index 9efdb48c..55028c92 100644 --- a/backend/compiler.py +++ b/backend/compiler.py @@ -1,5 +1,5 @@ from triton.backends.compiler import BaseBackend, GPUTarget -from triton._C.libtriton import ir, passes +from triton._C.libtriton import ir, passes, triton_shared from dataclasses import dataclass from typing import Any, Dict, Tuple from types import ModuleType @@ -42,7 +42,7 @@ def _get_sanitizer_type(): if sanitizer_type != "" and sanitizer_type != "asan" and sanitizer_type != "tsan": # throw error raise Exception(f"TRITON_SHARED_SANITIZER_TYPE {sanitizer_type} is invalid.") - + return sanitizer_type def _ttir_to_ttsharedir(mod): @@ -78,41 +78,47 @@ def _ttsharedir_to_llir(ttsharedir: str): llmlir_path = os.path.join(tmpdir, "ll.mlir") llir_path = os.path.join(tmpdir, "ll.ir") Path(ttshared_path).write_text(ttsharedir) - mlir_opt_path = _get_llvm_bin_path("mlir-opt") + context = ir.context() + triton_shared.ir.load_dialects(context) + mod = ir.parse_mlir_module(ttshared_path, context) # TritonShared-MLIR to LLVM-MLIR - subprocess.check_call([mlir_opt_path, ttshared_path, - "--convert-linalg-to-affine-loops", - # Note: eliminate-empty-tensors fails when there are 
multiple func.return ops - # in a single kernel which are the results of early returns. - # See python/examples/test_early_return.py for examples. - # We disable this pass for now since performance on CPU isn't the main - # focus at the moment. - # "--eliminate-empty-tensors", - "--empty-tensor-to-alloc-tensor", - "--one-shot-bufferize=allow-return-allocs-from-loops=true", - "--lower-affine", - "--convert-linalg-to-loops", - "--expand-strided-metadata", - "--convert-scf-to-cf", - "--convert-arith-to-llvm", - "--convert-math-to-llvm", - "--convert-complex-to-llvm", - "--convert-vector-to-llvm", - "--convert-index-to-llvm", - "--memref-expand", - "--finalize-memref-to-llvm", - "--convert-func-to-llvm", - "--convert-cf-to-llvm", - # Lowering memrefs creates more affine.apply ops. - # Lowering these affine ops again creates further arith ops, - # so we have to run these two passes again here. - "--lower-affine", - "--convert-arith-to-llvm", - # Remove all unrealized casts created - "--reconcile-unrealized-casts", - "--mlir-print-debuginfo", - "-o", - llmlir_path]) + + pm = ir.pass_manager(context) + pm.enable_debug() + triton_shared.to_llir.add_convert_linalg_to_affine_loops(pm) + # Note: eliminate-empty-tensors fails when there are multiple func.return ops + # in a single kernel which are the results of early returns. + # See python/examples/test_early_return.py for examples. + # We disable this pass for now since performance on CPU isn't the main + # focus at the moment. + # triton_shared.to_llir.add_eliminate_empty_tensors(pm) + triton_shared.to_llir.add_empty_tensor_to_alloc_tensor(pm) + triton_shared.to_llir.add_one_shot_bufferize_with_options( + pm, allow_return_allocs_from_loops=True) + triton_shared.to_llir.add_lower_affine(pm) + triton_shared.to_llir.add_convert_linalg_to_loops(pm) + triton_shared.to_llir.add_expand_strided_metadata(pm) + triton_shared.to_llir.add_convert_scf_to_cf(pm) + triton_shared.to_llir.add_convert_tptr_to_llvm(pm) + triton_shared.to_llir.add_convert_arith_to_llvm(pm) + triton_shared.to_llir.add_convert_math_to_llvm(pm) + triton_shared.to_llir.add_convert_complex_to_llvm(pm) + triton_shared.to_llir.add_convert_vector_to_llvm(pm) + triton_shared.to_llir.add_convert_index_to_llvm(pm) + triton_shared.to_llir.add_memref_expand(pm) + triton_shared.to_llir.add_finalize_memref_to_llvm(pm) + triton_shared.to_llir.add_convert_func_to_llvm(pm) + triton_shared.to_llir.add_convert_cf_to_llvm(pm) + # Lowering memrefs creates more affine.apply ops. + # Lowering these affine ops again creates further arith ops, + # so we have to run these two passes again here. 
+ triton_shared.to_llir.add_lower_affine(pm) + triton_shared.to_llir.add_convert_arith_to_llvm(pm) + # Remove all unrealized casts created + triton_shared.to_llir.add_reconcile_unrealized_casts(pm) + pm.run(mod) + + Path(llmlir_path).write_text(str(mod)) # LLVM-MLIR to LLVM-IR mlir_translate_path = _get_llvm_bin_path("mlir-translate") @@ -145,7 +151,7 @@ def _llir_to_bin(llir: str, metadata): # using a sanitizer # invoke pass to append sanitizer attributes instrumented_src_path = os.path.join(tmpdir, "kernel-instrumented.ll") - + opt_path = _get_llvm_bin_path("opt") top_level_triton_path = os.path.dirname(triton.__file__) sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None)) @@ -153,8 +159,8 @@ def _llir_to_bin(llir: str, metadata): if not sanitizer_attributes_pass_path: raise Exception(f"libSanitizerAttributes.so does not exist.") - subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path, - "-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path, + subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path, + "-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path, "-o", instrumented_src_path]) # compile to object file @@ -166,12 +172,12 @@ def _llir_to_bin(llir: str, metadata): subprocess_args.extend(["-g", "-fsanitize=address", "-mllvm", "-asan-stack=0"]) elif sanitizer_type == "tsan": subprocess_args.extend(["-g", "-fsanitize=thread"]) - + subprocess.check_call(subprocess_args) else: llc_path = _get_llvm_bin_path("llc") subprocess.check_call([llc_path, src_path, "-filetype=obj", "-relocation-model=pic", "-o", dst_path]) - + return Path(dst_path).read_bytes() @@ -265,7 +271,6 @@ def add_stages(self, stages, options, language): stages["llir"] = lambda src, metadata: _optimize_llir(_ttsharedir_to_llir(src)) stages["obj"] = lambda src, metadata: _llir_to_bin(src, metadata) - @functools.lru_cache() def hash(self): return self.target @@ -273,3 +278,4 @@ def hash(self): # The CPU backend does not use any extra python modules, return an empty dictionary def get_module_map(self) -> Dict[str, ModuleType]: return {} + diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt index a4a03949..e30b1675 100644 --- a/include/triton-shared/Conversion/CMakeLists.txt +++ b/include/triton-shared/Conversion/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(TritonPtrToMemref) add_subdirectory(TritonToUnstructured) add_subdirectory(StructuredToMemref) add_subdirectory(UnstructuredToMemref) +add_subdirectory(TPtrToLLVM) diff --git a/include/triton-shared/Conversion/TPtrToLLVM/CMakeLists.txt b/include/triton-shared/Conversion/TPtrToLLVM/CMakeLists.txt new file mode 100644 index 00000000..dde21b4c --- /dev/null +++ b/include/triton-shared/Conversion/TPtrToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TPtrToLLVM) +add_public_tablegen_target(TPtrToLLVMConversionPassIncGen) \ No newline at end of file diff --git a/include/triton-shared/Conversion/TPtrToLLVM/Passes.h b/include/triton-shared/Conversion/TPtrToLLVM/Passes.h new file mode 100644 index 00000000..b1c1ab01 --- /dev/null +++ b/include/triton-shared/Conversion/TPtrToLLVM/Passes.h @@ -0,0 +1,15 @@ +#ifndef TPTR_TO_LLVM_CONVERSION_PASSES_H +#define TPTR_TO_LLVM_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h" + +namespace mlir { 
+namespace tptr { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc" + +} // namespace tptr +} // namespace mlir + +#endif diff --git a/include/triton-shared/Conversion/TPtrToLLVM/Passes.td b/include/triton-shared/Conversion/TPtrToLLVM/Passes.td new file mode 100644 index 00000000..3380626d --- /dev/null +++ b/include/triton-shared/Conversion/TPtrToLLVM/Passes.td @@ -0,0 +1,10 @@ +#ifndef TPTR_TO_LLVM_CONVERSION_PASSES +#define TPTR_TO_LLVM_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TPtrToLLVM : Pass<"tptr-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Tptr operations into LLVM"; + let dependentDialects = ["mlir::tptr::TPtrDialect", "mlir::LLVM::LLVMDialect", "mlir::ptr::PtrDialect"]; +} +#endif diff --git a/include/triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h b/include/triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h new file mode 100644 index 00000000..0e9f1e18 --- /dev/null +++ b/include/triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h @@ -0,0 +1,23 @@ +#ifndef TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H +#define TRITON_CONVERSION_TPTR_TO_LLVM_TPTRTOLLVM_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" + +namespace mlir { +namespace tptr { + +#define GEN_PASS_DECL +#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc" + +void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeconverter); + +std::unique_ptr<OperationPass<ModuleOp>> createTPtrToLLVMPass(); + +} // namespace tptr +} // namespace mlir + +#endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 358b4f92..a08ba81d 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(TritonArithToLinalg) add_subdirectory(StructuredToMemref) add_subdirectory(TritonPtrToMemref) add_subdirectory(UnstructuredToMemref) +add_subdirectory(TPtrToLLVM) diff --git a/lib/Conversion/TPtrToLLVM/CMakeLists.txt b/lib/Conversion/TPtrToLLVM/CMakeLists.txt new file mode 100644 index 00000000..c72e97be --- /dev/null +++ b/lib/Conversion/TPtrToLLVM/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TPtrToLLVM + TPtrToLLVMPass.cpp + TPtrToLLVM.cpp + + DEPENDS + TPtrToLLVMConversionPassIncGen + TPtrTableGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + MLIRSupport + MLIRDialectUtils + TPtrIR +) diff --git a/lib/Conversion/TPtrToLLVM/TPtrToLLVM.cpp b/lib/Conversion/TPtrToLLVM/TPtrToLLVM.cpp new file mode 100644 index 00000000..cdda0786 --- /dev/null +++ b/lib/Conversion/TPtrToLLVM/TPtrToLLVM.cpp @@ -0,0 +1,646 @@ +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Transforms/DialectConversion.h" +#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h" +#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" + +namespace mlir { +namespace tptr { + +#define DEBUG_TYPE "tptr-to-llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +static
bool isOneToOneCast(UnrealizedConversionCastOp op) { + return (op.getInputs().size() == 1 && op->getNumResults() == 1); +} + +// PtrAddOp -> llvm.getelementptr conversion +struct PtrAddConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + Type convertPtrPointerType(ptr::PtrType type) const { + auto ctx = type.getContext(); + return LLVM::LLVMPointerType::get(ctx); + } + + LogicalResult + matchAndRewrite(tptr::PtrAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: ptradd " << op); + + if (!isa(op.getType())) { + return rewriter.notifyMatchFailure(op, "expected ptr type"); + } + auto ptrTy = cast(op.getType()); + + // Infer element type + Type elemTy = nullptr; + auto origOffset = op.getOffset(); + if (auto mulOp = origOffset.getDefiningOp()) { + if (auto typeOffsetOp = + mulOp.getRhs().getDefiningOp()) { + elemTy = typeOffsetOp.getBaseType(); + } + } else if (auto mulOp = origOffset.getDefiningOp()) { + if (auto typeOffsetOp = + mulOp.getRhs().getDefiningOp()) { + elemTy = typeOffsetOp.getBaseType(); + } + } + + if (!elemTy) { + elemTy = rewriter.getIntegerType(8); // default to i8 + } + + Type resTy = convertPtrPointerType(ptrTy); + + Value elementIndex; + if (auto mulOp = adaptor.getOffset().getDefiningOp()) { + elementIndex = mulOp.getLhs(); + } else if (auto mulOp = + adaptor.getOffset().getDefiningOp()) { + elementIndex = mulOp.getLhs(); + } else { + LDBG("Warning: ptradd offset is not MulOp pattern, using raw offset"); + elementIndex = adaptor.getOffset(); + } + + auto gep = rewriter.create(op.getLoc(), resTy, elemTy, + adaptor.getBase(), elementIndex); + rewriter.replaceOp(op, gep); + LDBG("matchAndRewrite: ptradd done " << gep); + return success(); + } +}; + +// ToMemrefOp -> build LLVM memref struct +struct ToMemrefConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tptr::ToMemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: to_memref (before) " << op); + + auto input = adaptor.getArg(); + + // Try to extract real pointer from UnrealizedConversionCast + if (Operation *defOp = input.getDefiningOp()) { + if (auto unrealizedCast = dyn_cast(defOp)) { + if (unrealizedCast.getInputs().size() == 1) { + Value castInput = unrealizedCast.getInputs()[0]; + if (isa(castInput.getType())) { + input = castInput; + } + } + } + } + + Type targetType = + getTypeConverter()->convertType(cast(op.getType())); + LDBG("matchAndRewrite: to_memref (typeconverted) " + << cast(op.getType()) << " -> " << targetType); + if (!targetType) { + return rewriter.notifyMatchFailure(op, "failed to convert memref type"); + } + + auto loc = op.getLoc(); + auto i64Ty = rewriter.getIntegerType(64); + auto shape = cast(op.getType()).getShape(); + auto rank = shape.size(); + + Value result = rewriter.create(loc, targetType); + result = + rewriter.create(loc, result, input, 0); // base_ptr + result = rewriter.create(loc, result, input, + 1); // aligned_ptr + + Value zeroOffset = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, 0)); + result = rewriter.create(loc, result, zeroOffset, 2); + + SmallVector strides(rank, 1); + for (int i = rank - 2; i >= 0; --i) { + if (shape[i + 1] != ShapedType::kDynamic) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + } + + for (auto [i, size] : llvm::enumerate(shape)) { + Value sizeVal = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, size)); 
+ result = rewriter.create( + loc, result, sizeVal, ArrayRef{3, static_cast(i)}); + + Value strideVal = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, strides[i])); + result = rewriter.create( + loc, result, strideVal, + ArrayRef{4, static_cast(i)}); + } + + rewriter.replaceOp(op, result); + LDBG("matchAndRewrite: to_memref (after) -> " << result); + return success(); + } +}; + +// FromMemrefOp -> llvm.extractvalue +struct FromMemrefConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tptr::FromMemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: from_memref (before) " << op); + + Value input = adaptor.getInput(); + // The input here is expected to have already been converted by the TypeConverter to the target LLVM struct type + if (isa(input.getType())) { + return rewriter.notifyMatchFailure(op, + "expected converted memref descriptor"); + } + + // Extract base_ptr (index 0) + Type resultType = LLVM::LLVMPointerType::get(rewriter.getContext()); + auto extractOp = rewriter.create( + op.getLoc(), resultType, input, rewriter.getDenseI64ArrayAttr({0})); + + rewriter.replaceOp(op, extractOp); + LDBG("matchAndRewrite: from_memref (after) -> " << extractOp); + return success(); + } +}; + +// Clean up unused UnrealizedConversionCast +struct UnrealizedCastConverter + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isOneToOneCast(op)) { + return failure(); + } + + auto input = adaptor.getInputs().front(); + auto inputType = input.getType(); + auto outputType = op.getOutputs().front().getType(); + + // Same type, directly remove cast + if (inputType == outputType) { + rewriter.replaceOp(op, input); + return success(); + } + + if (isa(outputType) || + (isa(inputType) && + isa(outputType))) { + LDBG("UnrealizedCast (reject): unsafe pointer conversion " << op); + return rewriter.notifyMatchFailure(op, "unsafe pointer conversion"); + } + + if ((isa(inputType) && isa(outputType)) || + (isa(inputType) && isa(outputType))) { + LDBG("matchAndRewrite: UnrealizedCast (after) " << op << " -> " << input); + rewriter.replaceOp(op, input); + return success(); + } + + return failure(); + } +}; + +// Basic block argument type legalization +static LogicalResult legalizeBlockArguments(Block &block, Operation *op, + PatternRewriter &rewriter, + const TypeConverter &converter) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + for (unsigned i = 0; i < block.getNumArguments(); ++i) { + BlockArgument arg = block.getArgument(i); + if (converter.isLegal(arg.getType())) { + continue; + } + + Type newTy = converter.convertType(arg.getType()); + if (!newTy) { + return rewriter.notifyMatchFailure(op, "failed to convert argument type"); + } + + unsigned argNum = arg.getArgNumber(); + Value newArg = block.insertArgument(argNum, newTy, arg.getLoc()); + arg.replaceAllUsesWith(newArg); + block.eraseArgument(argNum + 1); + } + return success(); +} + +// Conditional branch conversion +struct ConvertControlFlowOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: cond_branch (before) " << op); + + if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
*getTypeConverter())) || + failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter, + *getTypeConverter()))) { + return failure(); + } + + auto newOp = rewriter.replaceOpWithNewOp( + op, adaptor.getCondition(), op.getTrueDest(), + adaptor.getTrueDestOperands(), op.getFalseDest(), + adaptor.getFalseDestOperands()); + + LDBG("matchAndRewrite: cond_branch (after) -> " << newOp); + return success(); + } +}; + +// Unconditional branch conversion +struct ConvertBranchOp : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: cf.br (before) " << op); + + if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter, + *getTypeConverter()))) { + return failure(); + } + + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getDest(), adaptor.getDestOperands()); + LDBG("matchAndRewrite: cf.br (after) -> " << newOp); + return success(); + } +}; + +// MemRef allocation with pointer element types -> LLVM malloc + struct +// construction +struct MemRefAllocConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: memref.alloc (before) " << op); + + auto oldMemRefType = op.getType(); + auto elementType = oldMemRefType.getElementType(); + + // Only handle memref with pointer element types + if (!isa(elementType)) { + return failure(); + } + + // Convert to LLVM struct type + Type llvmStructType = getTypeConverter()->convertType(oldMemRefType); + if (!llvmStructType) { + return rewriter.notifyMatchFailure(op, "failed to convert memref type"); + } + + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto i64Ty = rewriter.getIntegerType(64); + auto ptrTy = LLVM::LLVMPointerType::get(ctx); + auto shape = oldMemRefType.getShape(); + auto rank = shape.size(); + + // Calculate total size (number of elements * pointer size) + int64_t totalElements = 1; + for (auto dim : shape) { + if (dim == ShapedType::kDynamic) { + return rewriter.notifyMatchFailure(op, + "dynamic shapes not supported yet"); + } + totalElements *= dim; + } + + // Compute total allocation size in bytes = numElements * sizeof(element) + Value numElementsVal = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, totalElements)); + + // Query pointer size from DataLayout + DataLayout dl = DataLayout::closest(op); + auto ptrSize = dl.getTypeSize(ptrTy); + if (ptrSize.isScalable()) { + return rewriter.notifyMatchFailure(op, + "scalable pointer size unsupported"); + } + auto fixedPtrSize = static_cast(ptrSize.getFixedValue()); + Value ptrSizeVal = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, fixedPtrSize)); + + Value totalBytes = + rewriter.create(loc, numElementsVal, ptrSizeVal); + + // Declare or lookup malloc: ptr (i64) + ModuleOp module = op->getParentOfType(); + auto mallocName = StringRef("malloc"); + LLVM::LLVMFuncOp mallocFunc = + module.lookupSymbol(mallocName); + if (!mallocFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto mallocType = + LLVM::LLVMFunctionType::get(ptrTy, {i64Ty}, /*isVarArg=*/false); + mallocFunc = + rewriter.create(loc, mallocName, mallocType); + } + + auto mallocCallee = SymbolRefAttr::get(mallocFunc); + Value allocatedPtr = + rewriter + .create(loc, TypeRange{ptrTy}, 
mallocCallee, + ValueRange{totalBytes}) + .getResult(); + + // Build memref descriptor struct + Value result = rewriter.create(loc, llvmStructType); + result = rewriter.create(loc, result, allocatedPtr, + 0); // base_ptr + result = rewriter.create(loc, result, allocatedPtr, + 1); // aligned_ptr + + Value zeroOffset = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, 0)); + result = rewriter.create(loc, result, zeroOffset, 2); + + // Set sizes and strides + SmallVector strides(rank, 1); + for (int i = rank - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + + for (auto [i, size] : llvm::enumerate(shape)) { + Value sizeVal = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, size)); + result = rewriter.create( + loc, result, sizeVal, ArrayRef{3, static_cast(i)}); + + Value strideVal = rewriter.create( + loc, i64Ty, rewriter.getIntegerAttr(i64Ty, strides[i])); + result = rewriter.create( + loc, result, strideVal, + ArrayRef{4, static_cast(i)}); + } + + rewriter.replaceOp(op, result); + LDBG("matchAndRewrite: memref.alloc (after) -> " << result); + return success(); + } +}; + +// MemRef store with pointer element types -> LLVM GEP + store +struct MemRefStoreConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: memref.store (before) " << op); + + auto memrefType = op.getMemRef().getType(); + if (auto memrefTy = dyn_cast(memrefType)) { + auto elementType = memrefTy.getElementType(); + + // Only handle memref with pointer element types + if (!isa(elementType)) { + return failure(); + } + + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(ctx); + auto i64Ty = rewriter.getIntegerType(64); + + // Extract aligned pointer and offset from memref descriptor + // aligned_ptr at index 1, offset at index 2 + Value memrefDescriptor = adaptor.getMemref(); + Value alignedPtr = rewriter.create( + loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({1})); + Value baseOffset = rewriter.create( + loc, i64Ty, memrefDescriptor, rewriter.getDenseI64ArrayAttr({2})); + + // Calculate linear index from multi-dimensional indices + Value linearIndex = nullptr; + if (adaptor.getIndices().size() == 1) { + // Single dimension case + Value index = adaptor.getIndices()[0]; + // Convert index to i64 if needed + if (index.getType() != i64Ty) { + if (isa(index.getType())) { + index = + rewriter.create(loc, i64Ty, index) + .getResult(0); + } + } + linearIndex = rewriter.create(loc, baseOffset, index); + } else { + // Multi-dimensional: linearIndex = i0*stride0 + i1*stride1 + ... 
+ linearIndex = baseOffset; + + for (auto [i, index] : llvm::enumerate(adaptor.getIndices())) { + // Convert index to i64 if needed + Value convertedIndex = index; + if (index.getType() != i64Ty) { + if (isa(index.getType())) { + convertedIndex = + rewriter.create(loc, i64Ty, index) + .getResult(0); + } + } + + Value stride = rewriter.create( + loc, i64Ty, memrefDescriptor, + rewriter.getDenseI64ArrayAttr({4, static_cast(i)})); + Value contribution = + rewriter.create(loc, convertedIndex, stride); + linearIndex = + rewriter.create(loc, linearIndex, contribution); + } + } + + // GEP to get the address of the element + Value elementPtr = rewriter.create(loc, ptrTy, ptrTy, + alignedPtr, linearIndex); + + // Store the value + auto storeOp = + rewriter.create(loc, adaptor.getValue(), elementPtr); + rewriter.eraseOp(op); + LDBG("matchAndRewrite: memref.store (after) -> " << storeOp); + return success(); + } + + return failure(); + } +}; + +// MemRef load with pointer element types -> LLVM GEP + load +struct MemRefLoadConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: memref.load (before) " << op); + + auto memrefType = op.getMemRef().getType(); + if (auto memrefTy = dyn_cast(memrefType)) { + auto elementType = memrefTy.getElementType(); + + // Only handle memref with pointer element types + if (!isa(elementType)) { + return failure(); + } + + // Convert result type through type converter + Type newResultType = getTypeConverter()->convertType(op.getType()); + if (!newResultType) { + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + } + + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto ptrTy = LLVM::LLVMPointerType::get(ctx); + auto i64Ty = rewriter.getIntegerType(64); + + // Extract aligned pointer and offset from memref descriptor + // aligned_ptr at index 1, offset at index 2 + Value memrefDescriptor = adaptor.getMemref(); + LDBG("memrefDescriptor " << memrefDescriptor); + Value alignedPtr = rewriter.create( + loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({1})); + LDBG("basePtr " << rewriter.create( + loc, ptrTy, memrefDescriptor, rewriter.getDenseI64ArrayAttr({0}))); + LDBG("alignedPtr " << alignedPtr); + Value baseOffset = rewriter.create( + loc, i64Ty, memrefDescriptor, rewriter.getDenseI64ArrayAttr({2})); + LDBG("baseOffset " << baseOffset); + // Calculate linear index from multi-dimensional indices + Value linearIndex = nullptr; + if (adaptor.getIndices().size() == 1) { + // Single dimension case + Value index = adaptor.getIndices()[0]; + LDBG("if index " << index); + // Convert index to i64 if needed + if (index.getType() != i64Ty) { + if (isa(index.getType())) { + index = + rewriter.create(loc, i64Ty, index) + .getResult(0); + } + } + linearIndex = rewriter.create(loc, baseOffset, index); + } else { + // Multi-dimensional: linearIndex = i0*stride0 + i1*stride1 + ... 
+ linearIndex = baseOffset; + LDBG("else index " << linearIndex); + + for (auto [i, index] : llvm::enumerate(adaptor.getIndices())) { + // Convert index to i64 if needed + Value convertedIndex = index; + if (index.getType() != i64Ty) { + if (isa(index.getType())) { + convertedIndex = + rewriter.create(loc, i64Ty, index) + .getResult(0); + } + } + + Value stride = rewriter.create( + loc, i64Ty, memrefDescriptor, + rewriter.getDenseI64ArrayAttr({4, static_cast(i)})); + Value contribution = + rewriter.create(loc, convertedIndex, stride); + linearIndex = + rewriter.create(loc, linearIndex, contribution); + } + } + + // GEP to get the address of the element + Value elementPtr = rewriter.create(loc, ptrTy, ptrTy, + alignedPtr, linearIndex); + + // Load the value + Value loadedValue = + rewriter.create(loc, newResultType, elementPtr); + rewriter.replaceOp(op, loadedValue); + + LDBG("matchAndRewrite: memref.load (after) -> " << loadedValue); + return success(); + } + + return failure(); + } +}; + +// TypeOffsetOp -> constant conversion +struct TypeOffsetConverter : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + llvm::TypeSize + getTypeSize(tptr::TypeOffsetOp op, + std::optional layout = std::nullopt) const { + if (layout) + return layout->getTypeSize(op.getBaseType()); + DataLayout dl = DataLayout::closest(op); + return dl.getTypeSize(op.getBaseType()); + } + + LogicalResult + matchAndRewrite(tptr::TypeOffsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LDBG("matchAndRewrite: type_offset (before) " << op); + + auto size = getTypeSize(op); + if (size.isScalable()) { + return rewriter.notifyMatchFailure(op, "scalable type size unsupported"); + } + auto fixedSize = static_cast(size.getFixedValue()); + auto constOp = rewriter.create( + op.getLoc(), op.getType(), + rewriter.getIntegerAttr(op.getType(), fixedSize)); + + rewriter.replaceOp(op, constOp); + LDBG("matchAndRewrite: type_offset (after) -> " << constOp); + return success(); + } +}; + +void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add( + typeConverter, patterns.getContext()); +} + +} // namespace tptr +} // namespace mlir diff --git a/lib/Conversion/TPtrToLLVM/TPtrToLLVMPass.cpp b/lib/Conversion/TPtrToLLVM/TPtrToLLVMPass.cpp new file mode 100644 index 00000000..f1dc6857 --- /dev/null +++ b/lib/Conversion/TPtrToLLVM/TPtrToLLVMPass.cpp @@ -0,0 +1,226 @@ +#include + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h" +#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h" + +#define DEBUG_TYPE "tptr-to-llvm" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace tptr; + +namespace mlir { +namespace tptr { +#define GEN_PASS_DEF_TPTRTOLLVM +#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc" +} // namespace tptr +} // namespace mlir + +namespace { + +struct TptrToLLVMTypeConverter : TypeConverter { 
+ TptrToLLVMTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) -> Type { return type; }); + + addConversion([&](MemRefType type) -> std::optional { + auto elementType = type.getElementType(); + auto ctx = type.getContext(); + auto rank = type.getShape().size(); + auto i64Ty = IntegerType::get(ctx, 64); + + SmallVector types{ + LLVM::LLVMPointerType::get(ctx), // base_ptr + LLVM::LLVMPointerType::get(ctx), // aligned_ptr + i64Ty, // offset + LLVM::LLVMArrayType::get(ctx, i64Ty, rank), // sizes + LLVM::LLVMArrayType::get(ctx, i64Ty, rank)}; // strides + + return LLVM::LLVMStructType::getLiteral(ctx, types); + }); + addTypeAttributeConversion( + [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace) + -> TypeConverter::AttributeConversionResult { + if (type.getMemorySpace() != memorySpace) + return TypeConverter::AttributeConversionResult::na(); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0); + }); + addTypeAttributeConversion( + [&](PtrLikeTypeInterface type, tptr::DefaultMemorySpaceAttr memorySpace) + -> TypeConverter::AttributeConversionResult { + if (type.getMemorySpace() != memorySpace) + return TypeConverter::AttributeConversionResult::na(); + // Default memory space maps to LLVM addrspace(0). + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0); + }); + + // Add type conversions. + addConversion([&](ptr::PtrType type) -> Type { + LDBG("MemorySpace " << type.getMemorySpace()); + std::optional maybeAttr = + convertTypeAttribute(type, type.getMemorySpace()); + auto memSpace = + maybeAttr ? dyn_cast_or_null(*maybeAttr) : IntegerAttr(); + if (!memSpace) { + return {}; + } + return LLVM::LLVMPointerType::get(type.getContext(), + memSpace.getValue().getSExtValue()); + }); + + auto createUnrealizedCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, resultType, inputs) + .getResult(0); + }; + addSourceMaterialization(createUnrealizedCast); + addTargetMaterialization(createUnrealizedCast); + } +}; + +class TPtrToLLVMPass : public tptr::impl::TPtrToLLVMBase { + using TPtrToLLVMBase::TPtrToLLVMBase; + +public: + void runOnOperation() override { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + TptrToLLVMTypeConverter typeConverter(&getContext()); + + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addDynamicallyLegalOp([&](cf::CondBranchOp op) { + for (auto operand : op.getOperands()) { + if (isa(operand.getType())) { + LDBG("CondBranchOp marked illegal due to operand type: " + << operand.getType()); + return false; + } + } + for (auto dest : {op.getTrueDest(), op.getFalseDest()}) { + for (auto arg : dest->getArguments()) { + if (isa(arg.getType())) { + LDBG("CondBranchOp marked illegal due to block arg type: " + << arg.getType()); + return false; + } + } + } + return true; + }); + + target.addDynamicallyLegalOp([&](cf::BranchOp op) { + for (auto operand : op.getOperands()) { + if (isa(operand.getType())) { + LDBG("BranchOp marked illegal due to operand type: " + << operand.getType()); + return false; + } + } + + for (auto arg : op.getDest()->getArguments()) { + if (isa(arg.getType())) { + LDBG("BranchOp marked illegal due to block arg type: " + << arg.getType()); + return false; + } + } + return true; + }); + + target.addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + for (auto type : op.getResultTypes()) { 
+ if (isa(type)) { + return false; + } + } + for (auto operand : op.getOperands()) { + if (isa(operand.getType())) { + return false; + } + } + return true; + }); + + target.addDynamicallyLegalOp([&](memref::AllocOp op) { + auto memrefType = op.getType(); + auto elementType = memrefType.getElementType(); + if (isa(elementType)) { + LDBG("AllocOp marked illegal due to pointer element type: " + << elementType); + return false; + } + return true; + }); + + target.addDynamicallyLegalOp([&](memref::StoreOp op) { + auto memrefType = op.getMemRef().getType(); + if (auto memrefTy = dyn_cast(memrefType)) { + auto elementType = memrefTy.getElementType(); + if (isa(elementType)) { + LDBG("StoreOp marked illegal due to pointer element type: " + << elementType); + return false; + } + } + return true; + }); + + target.addDynamicallyLegalOp([&](memref::LoadOp op) { + auto memrefType = op.getMemRef().getType(); + if (auto memrefTy = dyn_cast(memrefType)) { + auto elementType = memrefTy.getElementType(); + if (isa(elementType)) { + LDBG("LoadOp marked illegal due to pointer element type: " + << elementType); + return false; + } + } + return true; + }); + + target.addLegalOp(); + + populateTPtrToLLVMConversionPatterns(patterns, typeConverter); + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + signalPassFailure(); + return; + } + LDBG("runOnOperation: after conversion\n" << moduleOp); + + { + mlir::PassManager pm(&getContext(), getOperation().getOperationName()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + return; + } + } + + LDBG("runOnOperation: done\n" << moduleOp); + } +}; + +} // namespace + +std::unique_ptr> tptr::createTPtrToLLVMPass() { + return std::make_unique(); +} diff --git a/test/Conversion/TPtrToLLVM/from_bitcast.mlir b/test/Conversion/TPtrToLLVM/from_bitcast.mlir new file mode 100644 index 00000000..08237bbb --- /dev/null +++ b/test/Conversion/TPtrToLLVM/from_bitcast.mlir @@ -0,0 +1,141 @@ +// RUN: triton-shared-opt --tptr-to-llvm %s | FileCheck %s + +module { + func.func @bitcast_ptr_as_src(%arg0: memref<*xi32>, %arg1: memref<*xi32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c0 = arith.constant 0 : index + %0 = tptr.type_offset i64 : i32 + %1 = tptr.type_offset i32 : i32 + %c2_i32 = arith.constant 2 : i32 + %c1_i32 = arith.constant 1 : i32 + %cast = memref.cast %arg1 : memref<*xi32> to memref<1xi32> + %2 = tptr.from_memref %cast : memref<1xi32> to <#tptr.default_memory_space> + %cast_0 = memref.cast %arg0 : memref<*xi32> to memref<1xi32> + %3 = tptr.from_memref %cast_0 : memref<1xi32> to <#tptr.default_memory_space> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<16xi32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<16xi32> + cf.br ^bb1(%c0 : index) + ^bb1(%4: index): // 2 preds: ^bb0, ^bb2 + %5 = arith.cmpi slt, %4, %c16 : index + cf.cond_br %5, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + memref.store %c2_i32, %alloc_1[%4] : memref<16xi32> + %6 = arith.addi %4, %c1 : index + cf.br ^bb1(%6 : index) + ^bb3: // pred: ^bb1 + %7 = arith.muli %c1_i32, %1 : i32 + %8 = tptr.ptradd %3 %7 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + cf.br ^bb4(%c0 : index) + ^bb4(%9: index): // 2 preds: ^bb3, ^bb5 + %10 = arith.cmpi slt, %9, %c16 : index + cf.cond_br %10, ^bb5, ^bb6 + ^bb5: // pred: ^bb4 + %11 = arith.index_cast %9 : 
index to i32 + memref.store %11, %alloc[%9] : memref<16xi32> + %12 = arith.addi %9, %c1 : index + cf.br ^bb4(%12 : index) + ^bb6: // pred: ^bb4 + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<16x!ptr.ptr<#tptr.default_memory_space>> + cf.br ^bb7(%c0 : index) + ^bb7(%13: index): // 2 preds: ^bb6, ^bb8 + %14 = arith.cmpi slt, %13, %c16 : index + cf.cond_br %14, ^bb8, ^bb9 + ^bb8: // pred: ^bb7 + memref.store %8, %alloc_2[%13] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %15 = arith.addi %13, %c1 : index + cf.br ^bb7(%15 : index) + ^bb9: // pred: ^bb7 + cf.br ^bb10(%c0 : index) + ^bb10(%16: index): // 2 preds: ^bb9, ^bb11 + %17 = arith.cmpi slt, %16, %c16 : index + cf.cond_br %17, ^bb11, ^bb12 + ^bb11: // pred: ^bb10 + %18 = memref.load %alloc_2[%16] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %19 = memref.load %alloc[%16] : memref<16xi32> + %20 = arith.muli %19, %0 : i32 + %21 = tptr.ptradd %18 %20 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + memref.store %21, %alloc_2[%16] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %22 = arith.addi %16, %c1 : index + cf.br ^bb10(%22 : index) + ^bb12: // pred: ^bb10 + cf.br ^bb13(%c0 : index) + ^bb13(%23: index): // 2 preds: ^bb12, ^bb14 + %24 = arith.cmpi slt, %23, %c16 : index + cf.cond_br %24, ^bb14, ^bb15 + ^bb14: // pred: ^bb13 + %25 = memref.load %alloc_2[%23] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %26 = memref.load %alloc_1[%23] : memref<16xi32> + %27 = arith.muli %26, %0 : i32 + %28 = tptr.ptradd %25 %27 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + memref.store %28, %alloc_2[%23] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %29 = arith.addi %23, %c1 : index + cf.br ^bb13(%29 : index) + ^bb15: // pred: ^bb13 + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<16xi64> + cf.br ^bb16(%c0 : index) + ^bb16(%30: index): // 2 preds: ^bb15, ^bb17 + %31 = arith.cmpi slt, %30, %c16 : index + cf.cond_br %31, ^bb17, ^bb18 + ^bb17: // pred: ^bb16 + %32 = memref.load %alloc_2[%30] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %33 = tptr.to_memref %32 : <#tptr.default_memory_space> to memref<1xi64> + %34 = memref.load %33[%c0] : memref<1xi64> + memref.store %34, %alloc_3[%30] : memref<16xi64> + %35 = arith.addi %30, %c1 : index + cf.br ^bb16(%35 : index) + ^bb18: // pred: ^bb16 + %36 = tptr.ptradd %2 %7 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + cf.br ^bb19(%c0 : index) + ^bb19(%37: index): // 2 preds: ^bb18, ^bb20 + %38 = arith.cmpi slt, %37, %c16 : index + cf.cond_br %38, ^bb20, ^bb21 + ^bb20: // pred: ^bb19 + memref.store %36, %alloc_2[%37] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %39 = arith.addi %37, %c1 : index + cf.br ^bb19(%39 : index) + ^bb21: // pred: ^bb19 + cf.br ^bb22(%c0 : index) + ^bb22(%40: index): // 2 preds: ^bb21, ^bb23 + %41 = arith.cmpi slt, %40, %c16 : index + cf.cond_br %41, ^bb23, ^bb24 + ^bb23: // pred: ^bb22 + %42 = memref.load %alloc_2[%40] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %43 = memref.load %alloc[%40] : memref<16xi32> + %44 = arith.muli %43, %0 : i32 + %45 = tptr.ptradd %42 %44 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + memref.store %45, %alloc_2[%40] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %46 = arith.addi %40, %c1 : index + cf.br ^bb22(%46 : index) + ^bb24: // pred: ^bb22 + cf.br ^bb25(%c0 : index) + ^bb25(%47: index): // 2 preds: ^bb24, ^bb26 + %48 = arith.cmpi slt, %47, %c16 : index + cf.cond_br %48, ^bb26, ^bb27 + ^bb26: 
// pred: ^bb25 + %49 = memref.load %alloc_2[%47] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %50 = memref.load %alloc_1[%47] : memref<16xi32> + %51 = arith.muli %50, %0 : i32 + %52 = tptr.ptradd %49 %51 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + memref.store %52, %alloc_2[%47] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %53 = arith.addi %47, %c1 : index + cf.br ^bb25(%53 : index) + ^bb27: // pred: ^bb25 + cf.br ^bb28(%c0 : index) + ^bb28(%54: index): // 2 preds: ^bb27, ^bb29 + %55 = arith.cmpi slt, %54, %c16 : index + cf.cond_br %55, ^bb29, ^bb30 + ^bb29: // pred: ^bb28 + %56 = memref.load %alloc_2[%54] : memref<16x!ptr.ptr<#tptr.default_memory_space>> + %57 = memref.load %alloc_3[%54] : memref<16xi64> + %58 = tptr.to_memref %56 : <#tptr.default_memory_space> to memref<1xi64> + memref.store %57, %58[%c0] : memref<1xi64> + %59 = arith.addi %54, %c1 : index + cf.br ^bb28(%59 : index) + ^bb30: // pred: ^bb28 + return + } +} + +// CHECK-NOT: tptr. +// CHECK-NOT: ptr.ptr \ No newline at end of file diff --git a/test/Conversion/TPtrToLLVM/ptr_type_in_memref.mlir b/test/Conversion/TPtrToLLVM/ptr_type_in_memref.mlir new file mode 100644 index 00000000..c9c0e61d --- /dev/null +++ b/test/Conversion/TPtrToLLVM/ptr_type_in_memref.mlir @@ -0,0 +1,683 @@ +// RUN: triton-shared-opt --tptr-to-llvm %s | FileCheck %s + +module { + func.func @_fwd_grouped_kernel_stage1(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: memref<*xf32> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: memref<*xi32> {tt.divisibility = 16 : i32}, %arg5: memref<*xi32> {tt.divisibility = 16 : i32}, %arg6: memref<*xf32> {tt.divisibility = 16 : i32}, %arg7: memref<*xf32> {tt.divisibility = 16 : i32}, %arg8: memref<*xf32> {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32, %arg22: i32, %arg23: i32, %arg24: i32, %arg25: i32, %arg26: i32) { + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %0 = tptr.type_offset f32 : i32 + %c0 = arith.constant 0 : index + %1 = tptr.type_offset i32 : i32 + %c32_i32 = arith.constant 32 : i32 + %c2_i32 = arith.constant 2 : i32 + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %cst = arith.constant 0.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %c16_i32 = arith.constant 16 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i32 = arith.constant 1 : i32 + %c4_i32 = arith.constant 4 : i32 + %cast = memref.cast %arg5 : memref<*xi32> to memref<1xi32> + %2 = tptr.from_memref %cast : memref<1xi32> to <#tptr.default_memory_space> + %cast_0 = memref.cast %arg4 : memref<*xi32> to memref<1xi32> + %3 = tptr.from_memref %cast_0 : memref<1xi32> to <#tptr.default_memory_space> + %cast_1 = memref.cast %arg1 : memref<*xf32> to memref<1xf32> + %4 = tptr.from_memref %cast_1 : memref<1xf32> to <#tptr.default_memory_space> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64x32xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<64x32xf32> + cf.br ^bb1(%c0 : index) + ^bb1(%5: index): // 2 preds: ^bb0, ^bb5 + %6 = arith.cmpi slt, 
%5, %c64 : index + cf.cond_br %6, ^bb2, ^bb6 + ^bb2: // pred: ^bb1 + cf.br ^bb3(%c0 : index) + ^bb3(%7: index): // 2 preds: ^bb2, ^bb4 + %8 = arith.cmpi slt, %7, %c32 : index + cf.cond_br %8, ^bb4, ^bb5 + ^bb4: // pred: ^bb3 + memref.store %cst, %alloc_2[%5, %7] : memref<64x32xf32> + %9 = arith.addi %7, %c1 : index + cf.br ^bb3(%9 : index) + ^bb5: // pred: ^bb3 + %10 = arith.addi %5, %c1 : index + cf.br ^bb1(%10 : index) + ^bb6: // pred: ^bb1 + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb7(%c0 : index) + ^bb7(%11: index): // 2 preds: ^bb6, ^bb8 + %12 = arith.cmpi slt, %11, %c32 : index + cf.cond_br %12, ^bb8, ^bb9 + ^bb8: // pred: ^bb7 + memref.store %c0_i32, %alloc_4[%11] : memref<32xi32> + %13 = arith.addi %11, %c1 : index + cf.br ^bb7(%13 : index) + ^bb9: // pred: ^bb7 + %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb10(%c0 : index) + ^bb10(%14: index): // 2 preds: ^bb9, ^bb11 + %15 = arith.cmpi slt, %14, %c32 : index + cf.cond_br %15, ^bb11, ^bb12 + ^bb11: // pred: ^bb10 + memref.store %c16_i32, %alloc_5[%14] : memref<32xi32> + %16 = arith.addi %14, %c1 : index + cf.br ^bb10(%16 : index) + ^bb12: // pred: ^bb10 + %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<64xi32> + %alloc_7 = memref.alloc() {alignment = 64 : i64} : memref<64xi32> + cf.br ^bb13(%c0 : index) + ^bb13(%17: index): // 2 preds: ^bb12, ^bb14 + %18 = arith.cmpi slt, %17, %c64 : index + cf.cond_br %18, ^bb14, ^bb15 + ^bb14: // pred: ^bb13 + memref.store %c64_i32, %alloc_7[%17] : memref<64xi32> + %19 = arith.addi %17, %c1 : index + cf.br ^bb13(%19 : index) + ^bb15: // pred: ^bb13 + %20 = arith.muli %arg25, %c4_i32 : i32 + %21 = arith.index_cast %20 : i32 to index + %22 = arith.addi %arg25, %c1_i32 : i32 + %23 = arith.muli %22, %c4_i32 : i32 + cf.br ^bb16(%c0 : index) + ^bb16(%24: index): // 2 preds: ^bb15, ^bb17 + %25 = arith.cmpi slt, %24, %c64 : index + cf.cond_br %25, ^bb17, ^bb18 + ^bb17: // pred: ^bb16 + %26 = arith.index_cast %24 : index to i32 + memref.store %26, %alloc_6[%24] : memref<64xi32> + %27 = arith.addi %24, %c1 : index + cf.br ^bb16(%27 : index) + ^bb18: // pred: ^bb16 + %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<64xi1> + cf.br ^bb19(%c0 : index) + ^bb19(%28: index): // 2 preds: ^bb18, ^bb20 + %29 = arith.cmpi slt, %28, %c64 : index + cf.cond_br %29, ^bb20, ^bb21 + ^bb20: // pred: ^bb19 + %30 = memref.load %alloc_6[%28] : memref<64xi32> + %31 = memref.load %alloc_7[%28] : memref<64xi32> + %32 = arith.cmpi slt, %30, %31 : i32 + memref.store %32, %alloc_8[%28] : memref<64xi1> + %33 = arith.addi %28, %c1 : index + cf.br ^bb19(%33 : index) + ^bb21: // pred: ^bb19 + %34 = arith.muli %arg24, %1 : i32 + %35 = tptr.ptradd %2 %34 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + %36 = tptr.to_memref %35 : <#tptr.default_memory_space> to memref<1xi32> + %37 = memref.load %36[%c0] : memref<1xi32> + %38 = arith.muli %arg24, %arg10 : i32 + %39 = arith.index_cast %38 : i32 to index + %40 = arith.index_cast %arg11 : i32 to index + %41 = arith.muli %21, %40 : index + %42 = arith.addi %39, %41 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%42], sizes: [4, 64], strides: [%40, 1] : memref<*xf32> to memref<4x64xf32, strided<[?, 1], offset: ?>> + %43 = arith.addi %21, %c4 : index + %44 = arith.index_cast %23 : i32 to index + %45 = arith.minsi %43, %44 : index + %46 = arith.maxsi %45, %21 : index + %47 = arith.subi %46, %21 : 
index + %48 = arith.minsi %43, %c8 : index + %49 = arith.maxsi %48, %21 : index + %50 = arith.subi %49, %21 : index + %51 = arith.minsi %47, %50 : index + %52 = arith.minsi %51, %c4 : index + %alloc_9 = memref.alloc() : memref<4x64xf32> + %53 = arith.cmpi slt, %52, %c4 : index + cf.cond_br %53, ^bb22, ^bb29 + ^bb22: // pred: ^bb21 + cf.br ^bb23(%c0 : index) + ^bb23(%54: index): // 2 preds: ^bb22, ^bb27 + %55 = arith.cmpi slt, %54, %c4 : index + cf.cond_br %55, ^bb24, ^bb28 + ^bb24: // pred: ^bb23 + cf.br ^bb25(%c0 : index) + ^bb25(%56: index): // 2 preds: ^bb24, ^bb26 + %57 = arith.cmpi slt, %56, %c64 : index + cf.cond_br %57, ^bb26, ^bb27 + ^bb26: // pred: ^bb25 + memref.store %cst, %alloc_9[%54, %56] : memref<4x64xf32> + %58 = arith.addi %56, %c1 : index + cf.br ^bb25(%58 : index) + ^bb27: // pred: ^bb25 + %59 = arith.addi %54, %c1 : index + cf.br ^bb23(%59 : index) + ^bb28: // pred: ^bb23 + cf.br ^bb29 + ^bb29: // 2 preds: ^bb21, ^bb28 + %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %reinterpret_cast : memref<4x64xf32, strided<[?, 1], offset: ?>> -> memref, index, index, index, index, index + %reinterpret_cast_10 = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [%52, 64], strides: [%strides#0, 1] : memref to memref> + %reinterpret_cast_11 = memref.reinterpret_cast %alloc_9 to offset: [0], sizes: [%52, 64], strides: [64, 1] : memref<4x64xf32> to memref> + memref.copy %reinterpret_cast_10, %reinterpret_cast_11 : memref> to memref> + %60 = arith.muli %arg24, %arg19 : i32 + %61 = arith.index_cast %60 : i32 to index + %62 = arith.index_cast %arg20 : i32 to index + %63 = arith.muli %21, %62 : index + %64 = arith.addi %61, %63 : index + %reinterpret_cast_12 = memref.reinterpret_cast %arg7 to offset: [%64], sizes: [4, 64], strides: [%62, 1] : memref<*xf32> to memref<4x64xf32, strided<[?, 1], offset: ?>> + %reinterpret_cast_13 = memref.reinterpret_cast %alloc_9 to offset: [0], sizes: [%52, 64], strides: [64, 1] : memref<4x64xf32> to memref> + %base_buffer_14, %offset_15, %sizes_16:2, %strides_17:2 = memref.extract_strided_metadata %reinterpret_cast_12 : memref<4x64xf32, strided<[?, 1], offset: ?>> -> memref, index, index, index, index, index + %reinterpret_cast_18 = memref.reinterpret_cast %base_buffer_14 to offset: [%offset_15], sizes: [%52, 64], strides: [%strides_17#0, 1] : memref to memref> + memref.copy %reinterpret_cast_13, %reinterpret_cast_18 : memref> to memref> + %65 = arith.addi %37, %c1_i32 : i32 + %66 = arith.divsi %65, %c2_i32 : i32 + %67 = arith.muli %66, %arg26 : i32 + %68 = arith.addi %67, %66 : i32 + %69 = arith.minsi %68, %37 : i32 + %70 = arith.cmpi sgt, %69, %67 : i32 + cf.cond_br %70, ^bb30, ^bb174 + ^bb30: // pred: ^bb29 + %alloc_19 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb31(%c0 : index) + ^bb31(%71: index): // 2 preds: ^bb30, ^bb32 + %72 = arith.cmpi slt, %71, %c32 : index + cf.cond_br %72, ^bb32, ^bb33 + ^bb32: // pred: ^bb31 + %73 = arith.index_cast %71 : index to i32 + memref.store %73, %alloc_19[%71] : memref<32xi32> + %74 = arith.addi %71, %c1 : index + cf.br ^bb31(%74 : index) + ^bb33: // pred: ^bb31 + %alloc_20 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb34(%c0 : index) + ^bb34(%75: index): // 2 preds: ^bb33, ^bb35 + %76 = arith.cmpi slt, %75, %c32 : index + cf.cond_br %76, ^bb35, ^bb36 + ^bb35: // pred: ^bb34 + memref.store %69, %alloc_20[%75] : memref<32xi32> + %77 = arith.addi %75, %c1 : index + cf.br ^bb34(%77 : index) + ^bb36: // pred: ^bb34 + %78 = 
arith.muli %arg9, %arg24 : i32 + %79 = arith.muli %78, %1 : i32 + %80 = tptr.ptradd %3 %79 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + %alloc_21 = memref.alloc() {alignment = 64 : i64} : memref<32x!ptr.ptr<#tptr.default_memory_space>> + cf.br ^bb37(%c0 : index) + ^bb37(%81: index): // 2 preds: ^bb36, ^bb38 + %82 = arith.cmpi slt, %81, %c32 : index + cf.cond_br %82, ^bb38, ^bb39 + ^bb38: // pred: ^bb37 + memref.store %80, %alloc_21[%81] : memref<32x!ptr.ptr<#tptr.default_memory_space>> + %83 = arith.addi %81, %c1 : index + cf.br ^bb37(%83 : index) + ^bb39: // pred: ^bb37 + %alloc_22 = memref.alloc() {alignment = 64 : i64} : memref<1x32xi32> + %alloc_23 = memref.alloc() {alignment = 64 : i64} : memref<1x32xi32> + cf.br ^bb40(%c0 : index) + ^bb40(%84: index): // 2 preds: ^bb39, ^bb44 + %85 = arith.cmpi slt, %84, %c1 : index + cf.cond_br %85, ^bb41, ^bb45 + ^bb41: // pred: ^bb40 + cf.br ^bb42(%c0 : index) + ^bb42(%86: index): // 2 preds: ^bb41, ^bb43 + %87 = arith.cmpi slt, %86, %c32 : index + cf.cond_br %87, ^bb43, ^bb44 + ^bb43: // pred: ^bb42 + memref.store %arg12, %alloc_23[%84, %86] : memref<1x32xi32> + %88 = arith.addi %86, %c1 : index + cf.br ^bb42(%88 : index) + ^bb44: // pred: ^bb42 + %89 = arith.addi %84, %c1 : index + cf.br ^bb40(%89 : index) + ^bb45: // pred: ^bb40 + %90 = arith.muli %arg25, %arg13 : i32 + %alloc_24 = memref.alloc() {alignment = 64 : i64} : memref<1x32xi32> + cf.br ^bb46(%c0 : index) + ^bb46(%91: index): // 2 preds: ^bb45, ^bb50 + %92 = arith.cmpi slt, %91, %c1 : index + cf.cond_br %92, ^bb47, ^bb51 + ^bb47: // pred: ^bb46 + cf.br ^bb48(%c0 : index) + ^bb48(%93: index): // 2 preds: ^bb47, ^bb49 + %94 = arith.cmpi slt, %93, %c32 : index + cf.cond_br %94, ^bb49, ^bb50 + ^bb49: // pred: ^bb48 + memref.store %90, %alloc_24[%91, %93] : memref<1x32xi32> + %95 = arith.addi %93, %c1 : index + cf.br ^bb48(%95 : index) + ^bb50: // pred: ^bb48 + %96 = arith.addi %91, %c1 : index + cf.br ^bb46(%96 : index) + ^bb51: // pred: ^bb46 + %reinterpret_cast_25 = memref.reinterpret_cast %alloc_6 to offset: [0], sizes: [64, 1], strides: [1, 1] : memref<64xi32> to memref<64x1xi32> + %alloc_26 = memref.alloc() {alignment = 64 : i64} : memref<64x32xi32> + %alloc_27 = memref.alloc() {alignment = 64 : i64} : memref<64x32xi32> + cf.br ^bb52(%c0 : index) + ^bb52(%97: index): // 2 preds: ^bb51, ^bb56 + %98 = arith.cmpi slt, %97, %c64 : index + cf.cond_br %98, ^bb53, ^bb57 + ^bb53: // pred: ^bb52 + cf.br ^bb54(%c0 : index) + ^bb54(%99: index): // 2 preds: ^bb53, ^bb55 + %100 = arith.cmpi slt, %99, %c32 : index + cf.cond_br %100, ^bb55, ^bb56 + ^bb55: // pred: ^bb54 + %101 = memref.load %reinterpret_cast_25[%97, %c0] : memref<64x1xi32> + memref.store %101, %alloc_27[%97, %99] : memref<64x32xi32> + %102 = arith.addi %99, %c1 : index + cf.br ^bb54(%102 : index) + ^bb56: // pred: ^bb54 + %103 = arith.addi %97, %c1 : index + cf.br ^bb52(%103 : index) + ^bb57: // pred: ^bb52 + cf.br ^bb58(%c0 : index) + ^bb58(%104: index): // 2 preds: ^bb57, ^bb62 + %105 = arith.cmpi slt, %104, %c1 : index + cf.cond_br %105, ^bb59, ^bb63 + ^bb59: // pred: ^bb58 + cf.br ^bb60(%c0 : index) + ^bb60(%106: index): // 2 preds: ^bb59, ^bb61 + %107 = arith.cmpi slt, %106, %c32 : index + cf.cond_br %107, ^bb61, ^bb62 + ^bb61: // pred: ^bb60 + memref.store %69, %alloc_22[%104, %106] : memref<1x32xi32> + %108 = arith.addi %106, %c1 : index + cf.br ^bb60(%108 : index) + ^bb62: // pred: ^bb60 + %109 = arith.addi %104, %c1 : index + cf.br ^bb58(%109 : index) + ^bb63: // pred: ^bb58 + %reinterpret_cast_28 = 
memref.reinterpret_cast %alloc_8 to offset: [0], sizes: [64, 1], strides: [1, 1] : memref<64xi1> to memref<64x1xi1> + %alloc_29 = memref.alloc() {alignment = 64 : i64} : memref<64x32xi1> + %alloc_30 = memref.alloc() {alignment = 64 : i64} : memref<64x32xi1> + cf.br ^bb64(%c0 : index) + ^bb64(%110: index): // 2 preds: ^bb63, ^bb68 + %111 = arith.cmpi slt, %110, %c64 : index + cf.cond_br %111, ^bb65, ^bb69 + ^bb65: // pred: ^bb64 + cf.br ^bb66(%c0 : index) + ^bb66(%112: index): // 2 preds: ^bb65, ^bb67 + %113 = arith.cmpi slt, %112, %c32 : index + cf.cond_br %113, ^bb67, ^bb68 + ^bb67: // pred: ^bb66 + %114 = memref.load %reinterpret_cast_28[%110, %c0] : memref<64x1xi1> + memref.store %114, %alloc_30[%110, %112] : memref<64x32xi1> + %115 = arith.addi %112, %c1 : index + cf.br ^bb66(%115 : index) + ^bb68: // pred: ^bb66 + %116 = arith.addi %110, %c1 : index + cf.br ^bb64(%116 : index) + ^bb69: // pred: ^bb64 + %alloc_31 = memref.alloc() {alignment = 64 : i64} : memref<64x32x!ptr.ptr<#tptr.default_memory_space>> + cf.br ^bb70(%c0 : index) + ^bb70(%117: index): // 2 preds: ^bb69, ^bb74 + %118 = arith.cmpi slt, %117, %c64 : index + cf.cond_br %118, ^bb71, ^bb75 + ^bb71: // pred: ^bb70 + cf.br ^bb72(%c0 : index) + ^bb72(%119: index): // 2 preds: ^bb71, ^bb73 + %120 = arith.cmpi slt, %119, %c32 : index + cf.cond_br %120, ^bb73, ^bb74 + ^bb73: // pred: ^bb72 + memref.store %4, %alloc_31[%117, %119] : memref<64x32x!ptr.ptr<#tptr.default_memory_space>> + %121 = arith.addi %119, %c1 : index + cf.br ^bb72(%121 : index) + ^bb74: // pred: ^bb72 + %122 = arith.addi %117, %c1 : index + cf.br ^bb70(%122 : index) + ^bb75: // pred: ^bb70 + %alloc_32 = memref.alloc() {alignment = 64 : i64} : memref<64x32xf32> + memref.copy %alloc_2, %alloc_32 : memref<64x32xf32> to memref<64x32xf32> + cf.br ^bb76(%67, %alloc_32 : i32, memref<64x32xf32>) + ^bb76(%123: i32, %124: memref<64x32xf32>): // 2 preds: ^bb75, ^bb172 + %125 = arith.cmpi slt, %123, %69 : i32 + cf.cond_br %125, ^bb77, ^bb173 + ^bb77: // pred: ^bb76 + %alloc_33 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb78(%c0 : index) + ^bb78(%126: index): // 2 preds: ^bb77, ^bb79 + %127 = arith.cmpi slt, %126, %c32 : index + cf.cond_br %127, ^bb79, ^bb80 + ^bb79: // pred: ^bb78 + memref.store %123, %alloc_33[%126] : memref<32xi32> + %128 = arith.addi %126, %c1 : index + cf.br ^bb78(%128 : index) + ^bb80: // pred: ^bb78 + cf.br ^bb81(%c0 : index) + ^bb81(%129: index): // 2 preds: ^bb80, ^bb82 + %130 = arith.cmpi slt, %129, %c32 : index + cf.cond_br %130, ^bb82, ^bb83 + ^bb82: // pred: ^bb81 + %131 = memref.load %alloc_33[%129] : memref<32xi32> + %132 = memref.load %alloc_19[%129] : memref<32xi32> + %133 = arith.addi %131, %132 : i32 + memref.store %133, %alloc_33[%129] : memref<32xi32> + %134 = arith.addi %129, %c1 : index + cf.br ^bb81(%134 : index) + ^bb83: // pred: ^bb81 + %alloc_34 = memref.alloc() {alignment = 64 : i64} : memref<32xi1> + cf.br ^bb84(%c0 : index) + ^bb84(%135: index): // 2 preds: ^bb83, ^bb85 + %136 = arith.cmpi slt, %135, %c32 : index + cf.cond_br %136, ^bb85, ^bb86 + ^bb85: // pred: ^bb84 + %137 = memref.load %alloc_33[%135] : memref<32xi32> + %138 = memref.load %alloc_20[%135] : memref<32xi32> + %139 = arith.cmpi slt, %137, %138 : i32 + memref.store %139, %alloc_34[%135] : memref<32xi1> + %140 = arith.addi %135, %c1 : index + cf.br ^bb84(%140 : index) + ^bb86: // pred: ^bb84 + %alloc_35 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb87(%c0 : index) + ^bb87(%141: index): // 2 preds: ^bb86, ^bb88 + %142 = 
arith.cmpi slt, %141, %c32 : index + cf.cond_br %142, ^bb88, ^bb89 + ^bb88: // pred: ^bb87 + %143 = memref.load %alloc_33[%141] : memref<32xi32> + %144 = memref.load %alloc_5[%141] : memref<32xi32> + %145 = arith.divsi %143, %144 : i32 + memref.store %145, %alloc_35[%141] : memref<32xi32> + %146 = arith.addi %141, %c1 : index + cf.br ^bb87(%146 : index) + ^bb89: // pred: ^bb87 + %alloc_36 = memref.alloc() {alignment = 64 : i64} : memref<32x!ptr.ptr<#tptr.default_memory_space>> + cf.br ^bb90(%c0 : index) + ^bb90(%147: index): // 2 preds: ^bb89, ^bb91 + %148 = arith.cmpi slt, %147, %c32 : index + cf.cond_br %148, ^bb91, ^bb92 + ^bb91: // pred: ^bb90 + %149 = memref.load %alloc_21[%147] : memref<32x!ptr.ptr<#tptr.default_memory_space>> + %150 = memref.load %alloc_35[%147] : memref<32xi32> + %151 = arith.muli %150, %1 : i32 + %152 = tptr.ptradd %149 %151 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + memref.store %152, %alloc_36[%147] : memref<32x!ptr.ptr<#tptr.default_memory_space>> + %153 = arith.addi %147, %c1 : index + cf.br ^bb90(%153 : index) + ^bb92: // pred: ^bb90 + cf.br ^bb93(%c0 : index) + ^bb93(%154: index): // 2 preds: ^bb92, ^bb98 + %155 = arith.cmpi slt, %154, %c32 : index + cf.cond_br %155, ^bb94, ^bb99 + ^bb94: // pred: ^bb93 + %156 = memref.load %alloc_36[%154] : memref<32x!ptr.ptr<#tptr.default_memory_space>> + %157 = memref.load %alloc_34[%154] : memref<32xi1> + %158 = memref.load %alloc_4[%154] : memref<32xi32> + %159 = tptr.to_memref %156 : <#tptr.default_memory_space> to memref<1xi32> + cf.cond_br %157, ^bb95, ^bb96 + ^bb95: // pred: ^bb94 + %160 = memref.load %159[%c0] : memref<1xi32> + cf.br ^bb97(%160 : i32) + ^bb96: // pred: ^bb94 + cf.br ^bb97(%158 : i32) + ^bb97(%161: i32): // 2 preds: ^bb95, ^bb96 + cf.br ^bb98 + ^bb98: // pred: ^bb97 + memref.store %161, %alloc_3[%154] : memref<32xi32> + %162 = arith.addi %154, %c1 : index + cf.br ^bb93(%162 : index) + ^bb99: // pred: ^bb93 + cf.br ^bb100(%c0 : index) + ^bb100(%163: index): // 2 preds: ^bb99, ^bb101 + %164 = arith.cmpi slt, %163, %c32 : index + cf.cond_br %164, ^bb101, ^bb102 + ^bb101: // pred: ^bb100 + %165 = memref.load %alloc_3[%163] : memref<32xi32> + %166 = memref.load %alloc_5[%163] : memref<32xi32> + %167 = arith.muli %165, %166 : i32 + memref.store %167, %alloc_3[%163] : memref<32xi32> + %168 = arith.addi %163, %c1 : index + cf.br ^bb100(%168 : index) + ^bb102: // pred: ^bb100 + %alloc_37 = memref.alloc() {alignment = 64 : i64} : memref<32xi32> + cf.br ^bb103(%c0 : index) + ^bb103(%169: index): // 2 preds: ^bb102, ^bb104 + %170 = arith.cmpi slt, %169, %c32 : index + cf.cond_br %170, ^bb104, ^bb105 + ^bb104: // pred: ^bb103 + %171 = memref.load %alloc_33[%169] : memref<32xi32> + %172 = memref.load %alloc_5[%169] : memref<32xi32> + %173 = arith.remsi %171, %172 : i32 + memref.store %173, %alloc_37[%169] : memref<32xi32> + %174 = arith.addi %169, %c1 : index + cf.br ^bb103(%174 : index) + ^bb105: // pred: ^bb103 + cf.br ^bb106(%c0 : index) + ^bb106(%175: index): // 2 preds: ^bb105, ^bb107 + %176 = arith.cmpi slt, %175, %c32 : index + cf.cond_br %176, ^bb107, ^bb108 + ^bb107: // pred: ^bb106 + %177 = memref.load %alloc_3[%175] : memref<32xi32> + %178 = memref.load %alloc_37[%175] : memref<32xi32> + %179 = arith.addi %177, %178 : i32 + memref.store %179, %alloc_3[%175] : memref<32xi32> + %180 = arith.addi %175, %c1 : index + cf.br ^bb106(%180 : index) + ^bb108: // pred: ^bb106 + %reinterpret_cast_38 = memref.reinterpret_cast %alloc_3 to offset: [0], sizes: [1, 32], strides: [32, 1] : 
memref<32xi32> to memref<1x32xi32> + cf.br ^bb109(%c0 : index) + ^bb109(%181: index): // 2 preds: ^bb108, ^bb113 + %182 = arith.cmpi slt, %181, %c1 : index + cf.cond_br %182, ^bb110, ^bb114 + ^bb110: // pred: ^bb109 + cf.br ^bb111(%c0 : index) + ^bb111(%183: index): // 2 preds: ^bb110, ^bb112 + %184 = arith.cmpi slt, %183, %c32 : index + cf.cond_br %184, ^bb112, ^bb113 + ^bb112: // pred: ^bb111 + %185 = memref.load %reinterpret_cast_38[%181, %183] : memref<1x32xi32> + %186 = memref.load %alloc_23[%181, %183] : memref<1x32xi32> + %187 = arith.muli %185, %186 : i32 + memref.store %187, %reinterpret_cast_38[%181, %183] : memref<1x32xi32> + %188 = arith.addi %183, %c1 : index + cf.br ^bb111(%188 : index) + ^bb113: // pred: ^bb111 + %189 = arith.addi %181, %c1 : index + cf.br ^bb109(%189 : index) + ^bb114: // pred: ^bb109 + cf.br ^bb115(%c0 : index) + ^bb115(%190: index): // 2 preds: ^bb114, ^bb119 + %191 = arith.cmpi slt, %190, %c1 : index + cf.cond_br %191, ^bb116, ^bb120 + ^bb116: // pred: ^bb115 + cf.br ^bb117(%c0 : index) + ^bb117(%192: index): // 2 preds: ^bb116, ^bb118 + %193 = arith.cmpi slt, %192, %c32 : index + cf.cond_br %193, ^bb118, ^bb119 + ^bb118: // pred: ^bb117 + %194 = memref.load %reinterpret_cast_38[%190, %192] : memref<1x32xi32> + %195 = memref.load %alloc_24[%190, %192] : memref<1x32xi32> + %196 = arith.addi %194, %195 : i32 + memref.store %196, %reinterpret_cast_38[%190, %192] : memref<1x32xi32> + %197 = arith.addi %192, %c1 : index + cf.br ^bb117(%197 : index) + ^bb119: // pred: ^bb117 + %198 = arith.addi %190, %c1 : index + cf.br ^bb115(%198 : index) + ^bb120: // pred: ^bb115 + cf.br ^bb121(%c0 : index) + ^bb121(%199: index): // 2 preds: ^bb120, ^bb125 + %200 = arith.cmpi slt, %199, %c64 : index + cf.cond_br %200, ^bb122, ^bb126 + ^bb122: // pred: ^bb121 + cf.br ^bb123(%c0 : index) + ^bb123(%201: index): // 2 preds: ^bb122, ^bb124 + %202 = arith.cmpi slt, %201, %c32 : index + cf.cond_br %202, ^bb124, ^bb125 + ^bb124: // pred: ^bb123 + %203 = memref.load %reinterpret_cast_38[%c0, %201] : memref<1x32xi32> + memref.store %203, %alloc_26[%199, %201] : memref<64x32xi32> + %204 = arith.addi %201, %c1 : index + cf.br ^bb123(%204 : index) + ^bb125: // pred: ^bb123 + %205 = arith.addi %199, %c1 : index + cf.br ^bb121(%205 : index) + ^bb126: // pred: ^bb121 + cf.br ^bb127(%c0 : index) + ^bb127(%206: index): // 2 preds: ^bb126, ^bb131 + %207 = arith.cmpi slt, %206, %c64 : index + cf.cond_br %207, ^bb128, ^bb132 + ^bb128: // pred: ^bb127 + cf.br ^bb129(%c0 : index) + ^bb129(%208: index): // 2 preds: ^bb128, ^bb130 + %209 = arith.cmpi slt, %208, %c32 : index + cf.cond_br %209, ^bb130, ^bb131 + ^bb130: // pred: ^bb129 + %210 = memref.load %alloc_26[%206, %208] : memref<64x32xi32> + %211 = memref.load %alloc_27[%206, %208] : memref<64x32xi32> + %212 = arith.addi %210, %211 : i32 + memref.store %212, %alloc_26[%206, %208] : memref<64x32xi32> + %213 = arith.addi %208, %c1 : index + cf.br ^bb129(%213 : index) + ^bb131: // pred: ^bb129 + %214 = arith.addi %206, %c1 : index + cf.br ^bb127(%214 : index) + ^bb132: // pred: ^bb127 + %reinterpret_cast_39 = memref.reinterpret_cast %alloc_33 to offset: [0], sizes: [1, 32], strides: [32, 1] : memref<32xi32> to memref<1x32xi32> + %alloc_40 = memref.alloc() {alignment = 64 : i64} : memref<1x32xi1> + cf.br ^bb133(%c0 : index) + ^bb133(%215: index): // 2 preds: ^bb132, ^bb137 + %216 = arith.cmpi slt, %215, %c1 : index + cf.cond_br %216, ^bb134, ^bb138 + ^bb134: // pred: ^bb133 + cf.br ^bb135(%c0 : index) + ^bb135(%217: index): // 2 preds: ^bb134, 
^bb136 + %218 = arith.cmpi slt, %217, %c32 : index + cf.cond_br %218, ^bb136, ^bb137 + ^bb136: // pred: ^bb135 + %219 = memref.load %reinterpret_cast_39[%215, %217] : memref<1x32xi32> + %220 = memref.load %alloc_22[%215, %217] : memref<1x32xi32> + %221 = arith.cmpi slt, %219, %220 : i32 + memref.store %221, %alloc_40[%215, %217] : memref<1x32xi1> + %222 = arith.addi %217, %c1 : index + cf.br ^bb135(%222 : index) + ^bb137: // pred: ^bb135 + %223 = arith.addi %215, %c1 : index + cf.br ^bb133(%223 : index) + ^bb138: // pred: ^bb133 + cf.br ^bb139(%c0 : index) + ^bb139(%224: index): // 2 preds: ^bb138, ^bb143 + %225 = arith.cmpi slt, %224, %c64 : index + cf.cond_br %225, ^bb140, ^bb144 + ^bb140: // pred: ^bb139 + cf.br ^bb141(%c0 : index) + ^bb141(%226: index): // 2 preds: ^bb140, ^bb142 + %227 = arith.cmpi slt, %226, %c32 : index + cf.cond_br %227, ^bb142, ^bb143 + ^bb142: // pred: ^bb141 + %228 = memref.load %alloc_40[%c0, %226] : memref<1x32xi1> + memref.store %228, %alloc_29[%224, %226] : memref<64x32xi1> + %229 = arith.addi %226, %c1 : index + cf.br ^bb141(%229 : index) + ^bb143: // pred: ^bb141 + %230 = arith.addi %224, %c1 : index + cf.br ^bb139(%230 : index) + ^bb144: // pred: ^bb139 + cf.br ^bb145(%c0 : index) + ^bb145(%231: index): // 2 preds: ^bb144, ^bb149 + %232 = arith.cmpi slt, %231, %c64 : index + cf.cond_br %232, ^bb146, ^bb150 + ^bb146: // pred: ^bb145 + cf.br ^bb147(%c0 : index) + ^bb147(%233: index): // 2 preds: ^bb146, ^bb148 + %234 = arith.cmpi slt, %233, %c32 : index + cf.cond_br %234, ^bb148, ^bb149 + ^bb148: // pred: ^bb147 + %235 = memref.load %alloc_29[%231, %233] : memref<64x32xi1> + %236 = memref.load %alloc_30[%231, %233] : memref<64x32xi1> + %237 = arith.andi %235, %236 : i1 + memref.store %237, %alloc_29[%231, %233] : memref<64x32xi1> + %238 = arith.addi %233, %c1 : index + cf.br ^bb147(%238 : index) + ^bb149: // pred: ^bb147 + %239 = arith.addi %231, %c1 : index + cf.br ^bb145(%239 : index) + ^bb150: // pred: ^bb145 + %alloc_41 = memref.alloc() {alignment = 64 : i64} : memref<64x32x!ptr.ptr<#tptr.default_memory_space>> + cf.br ^bb151(%c0 : index) + ^bb151(%240: index): // 2 preds: ^bb150, ^bb155 + %241 = arith.cmpi slt, %240, %c64 : index + cf.cond_br %241, ^bb152, ^bb156 + ^bb152: // pred: ^bb151 + cf.br ^bb153(%c0 : index) + ^bb153(%242: index): // 2 preds: ^bb152, ^bb154 + %243 = arith.cmpi slt, %242, %c32 : index + cf.cond_br %243, ^bb154, ^bb155 + ^bb154: // pred: ^bb153 + %244 = memref.load %alloc_31[%240, %242] : memref<64x32x!ptr.ptr<#tptr.default_memory_space>> + %245 = memref.load %alloc_26[%240, %242] : memref<64x32xi32> + %246 = arith.muli %245, %0 : i32 + %247 = tptr.ptradd %244 %246 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + memref.store %247, %alloc_41[%240, %242] : memref<64x32x!ptr.ptr<#tptr.default_memory_space>> + %248 = arith.addi %242, %c1 : index + cf.br ^bb153(%248 : index) + ^bb155: // pred: ^bb153 + %249 = arith.addi %240, %c1 : index + cf.br ^bb151(%249 : index) + ^bb156: // pred: ^bb151 + cf.br ^bb157(%c0 : index) + ^bb157(%250: index): // 2 preds: ^bb156, ^bb165 + %251 = arith.cmpi slt, %250, %c64 : index + cf.cond_br %251, ^bb158, ^bb166 + ^bb158: // pred: ^bb157 + cf.br ^bb159(%c0 : index) + ^bb159(%252: index): // 2 preds: ^bb158, ^bb164 + %253 = arith.cmpi slt, %252, %c32 : index + cf.cond_br %253, ^bb160, ^bb165 + ^bb160: // pred: ^bb159 + %254 = memref.load %alloc_41[%250, %252] : memref<64x32x!ptr.ptr<#tptr.default_memory_space>> + %255 = memref.load %alloc_29[%250, %252] : memref<64x32xi1> + %256 = 
memref.load %alloc_2[%250, %252] : memref<64x32xf32> + %257 = tptr.to_memref %254 : <#tptr.default_memory_space> to memref<1xf32> + cf.cond_br %255, ^bb161, ^bb162 + ^bb161: // pred: ^bb160 + %258 = memref.load %257[%c0] : memref<1xf32> + cf.br ^bb163(%258 : f32) + ^bb162: // pred: ^bb160 + cf.br ^bb163(%256 : f32) + ^bb163(%259: f32): // 2 preds: ^bb161, ^bb162 + cf.br ^bb164 + ^bb164: // pred: ^bb163 + memref.store %259, %alloc[%250, %252] : memref<64x32xf32> + %260 = arith.addi %252, %c1 : index + cf.br ^bb159(%260 : index) + ^bb165: // pred: ^bb159 + %261 = arith.addi %250, %c1 : index + cf.br ^bb157(%261 : index) + ^bb166: // pred: ^bb157 + cf.br ^bb167(%c0 : index) + ^bb167(%262: index): // 2 preds: ^bb166, ^bb171 + %263 = arith.cmpi slt, %262, %c64 : index + cf.cond_br %263, ^bb168, ^bb172 + ^bb168: // pred: ^bb167 + cf.br ^bb169(%c0 : index) + ^bb169(%264: index): // 2 preds: ^bb168, ^bb170 + %265 = arith.cmpi slt, %264, %c32 : index + cf.cond_br %265, ^bb170, ^bb171 + ^bb170: // pred: ^bb169 + %266 = memref.load %124[%262, %264] : memref<64x32xf32> + %267 = memref.load %alloc[%262, %264] : memref<64x32xf32> + %268 = arith.addf %266, %267 : f32 + memref.store %268, %124[%262, %264] : memref<64x32xf32> + %269 = arith.addi %264, %c1 : index + cf.br ^bb169(%269 : index) + ^bb171: // pred: ^bb169 + %270 = arith.addi %262, %c1 : index + cf.br ^bb167(%270 : index) + ^bb172: // pred: ^bb167 + %271 = arith.addi %123, %c32_i32 : i32 + cf.br ^bb76(%271, %124 : i32, memref<64x32xf32>) + ^bb173: // pred: ^bb76 + cf.br ^bb175(%124 : memref<64x32xf32>) + ^bb174: // pred: ^bb29 + cf.br ^bb175(%alloc_2 : memref<64x32xf32>) + ^bb175(%272: memref<64x32xf32>): // 2 preds: ^bb173, ^bb174 + cf.br ^bb176 + ^bb176: // pred: ^bb175 + %273 = arith.cmpi eq, %arg24, %c0_i32 : i32 + %274 = arith.cmpi eq, %arg25, %c1_i32 : i32 + %275 = arith.cmpi eq, %arg26, %c1_i32 : i32 + %276 = arith.andi %274, %275 : i1 + %277 = arith.andi %273, %276 : i1 + cf.cond_br %277, ^bb177, ^bb178 + ^bb177: // pred: ^bb176 + %reinterpret_cast_42 = memref.reinterpret_cast %arg8 to offset: [0], sizes: [64, 32], strides: [32, 1] : memref<*xf32> to memref<64x32xf32, strided<[32, 1]>> + memref.copy %272, %reinterpret_cast_42 : memref<64x32xf32> to memref<64x32xf32, strided<[32, 1]>> + cf.br ^bb178 + ^bb178: // 2 preds: ^bb176, ^bb177 + return + } +} + +// CHECK-NOT: tptr. 
+// CHECK-NOT: ptr.ptr \ No newline at end of file diff --git a/test/Conversion/TPtrToLLVM/simple_cf_ptradd.mlir b/test/Conversion/TPtrToLLVM/simple_cf_ptradd.mlir new file mode 100644 index 00000000..9824a706 --- /dev/null +++ b/test/Conversion/TPtrToLLVM/simple_cf_ptradd.mlir @@ -0,0 +1,92 @@ +// RUN: triton-shared-opt --tptr-to-llvm %s | FileCheck %s + +module { + func.func @simple_cf_into_structured_load_2(%arg0: memref<*xi64> {tt.divisibility = 16 : i32}, %arg1: memref<*xi64> {tt.divisibility = 16 : i32}, %arg2: memref<*xf32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { + %0 = tptr.type_offset f32 : i32 + %c0 = arith.constant 0 : index + %1 = tptr.type_offset i64 : i32 + %c3_i32 = arith.constant 3 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cast = memref.cast %arg2 : memref<*xf32> to memref<1xf32> + %2 = tptr.from_memref %cast : memref<1xf32> to <#tptr.default_memory_space> + %cast_0 = memref.cast %arg1 : memref<*xi64> to memref<1xi64> + %3 = tptr.from_memref %cast_0 : memref<1xi64> to <#tptr.default_memory_space> + %cast_1 = memref.cast %arg0 : memref<*xi64> to memref<1xi64> + %4 = tptr.from_memref %cast_1 : memref<1xi64> to <#tptr.default_memory_space> + %5 = arith.remsi %arg6, %c2_i32 : i32 + %6 = arith.cmpi eq, %5, %c0_i32 : i32 + cf.cond_br %6, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %7 = arith.muli %c2_i32, %1 : i32 + %8 = tptr.ptradd %4 %7 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + cf.br ^bb3(%8 : !ptr.ptr<#tptr.default_memory_space>) + ^bb2: // pred: ^bb0 + %9 = arith.muli %c3_i32, %1 : i32 + %10 = tptr.ptradd %3 %9 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + cf.br ^bb3(%10 : !ptr.ptr<#tptr.default_memory_space>) + ^bb3(%11: !ptr.ptr<#tptr.default_memory_space>): // 2 preds: ^bb1, ^bb2 + cf.br ^bb4 + ^bb4: // pred: ^bb3 + %12 = arith.muli %arg6, %1 : i32 + %13 = tptr.ptradd %11 %12 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + %14 = tptr.to_memref %13 : <#tptr.default_memory_space> to memref<1xi64> + %15 = memref.load %14[%c0] : memref<1xi64> + %16 = arith.muli %arg6, %0 : i32 + %17 = tptr.ptradd %2 %16 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> + %18 = arith.sitofp %15 : i64 to f32 + %19 = tptr.to_memref %17 : <#tptr.default_memory_space> to memref<1xf32> + memref.store %18, %19[%c0] : memref<1xf32> + return + } +} + + +// CHECK-LABEL: func.func @simple_cf_into_structured_load_2( +// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi64> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xi64> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32) { +// CHECK: %[[VAL_9:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VAL_11:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_13:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_2]] : memref<*xf32> to memref<1xf32> +// CHECK: %[[VAL_16:.*]] = builtin.unrealized_conversion_cast %[[VAL_15]] : memref<1xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_16]][0] : !llvm.struct<(ptr, ptr, 
i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_1]] : memref<*xi64> to memref<1xi64> +// CHECK: %[[VAL_19:.*]] = builtin.unrealized_conversion_cast %[[VAL_18]] : memref<1xi64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_20:.*]] = llvm.extractvalue %[[VAL_19]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_0]] : memref<*xi64> to memref<1xi64> +// CHECK: %[[VAL_22:.*]] = builtin.unrealized_conversion_cast %[[VAL_21]] : memref<1xi64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_23:.*]] = llvm.extractvalue %[[VAL_22]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_24:.*]] = arith.remsi %[[VAL_6]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_25:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_14]] : i32 +// CHECK: cf.cond_br %[[VAL_25]], ^bb1, ^bb2 +// CHECK: ^bb1: +// CHECK: %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_23]][2] : (!llvm.ptr) -> !llvm.ptr, i64 +// CHECK: cf.br ^bb3(%[[VAL_26]] : !llvm.ptr) +// CHECK: ^bb2: +// CHECK: %[[VAL_27:.*]] = llvm.getelementptr %[[VAL_20]][3] : (!llvm.ptr) -> !llvm.ptr, i64 +// CHECK: cf.br ^bb3(%[[VAL_27]] : !llvm.ptr) +// CHECK: ^bb3(%[[VAL_28:.*]]: !llvm.ptr): +// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i32) -> !llvm.ptr, i64 +// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_11]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_31:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_30]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_32:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_31]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_32]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_34:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_33]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_35:.*]] = builtin.unrealized_conversion_cast %[[VAL_34]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xi64> +// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_12]]] : memref<1xi64> +// CHECK: %[[VAL_37:.*]] = llvm.getelementptr %[[VAL_17]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32 +// CHECK: %[[VAL_38:.*]] = arith.sitofp %[[VAL_36]] : i64 to f32 +// CHECK: %[[VAL_39:.*]] = llvm.insertvalue %[[VAL_37]], %[[VAL_11]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_40:.*]] = llvm.insertvalue %[[VAL_37]], %[[VAL_39]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_41:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_40]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_42:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_41]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_43:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_42]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[VAL_44:.*]] = builtin.unrealized_conversion_cast %[[VAL_43]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32> +// CHECK: memref.store %[[VAL_38]], %[[VAL_44]]{{\[}}%[[VAL_12]]] : memref<1xf32> +// CHECK: return +// CHECK: } + + diff --git 
a/tools/triton-shared-opt/CMakeLists.txt b/tools/triton-shared-opt/CMakeLists.txt
index 260c3e5d..e8374b44 100644
--- a/tools/triton-shared-opt/CMakeLists.txt
+++ b/tools/triton-shared-opt/CMakeLists.txt
@@ -14,6 +14,7 @@ target_link_libraries(triton-shared-opt PRIVATE
   ${triton_libs}
   # MLIR core
   MLIROptLib
+  MLIRLLVMDialect
   MLIRPass
   MLIRRegisterAllPasses
   MLIRTransforms
diff --git a/tools/triton-shared-opt/RegisterTritonSharedDialects.h b/tools/triton-shared-opt/RegisterTritonSharedDialects.h
index 6d92953f..c5743a7d 100644
--- a/tools/triton-shared-opt/RegisterTritonSharedDialects.h
+++ b/tools/triton-shared-opt/RegisterTritonSharedDialects.h
@@ -1,16 +1,23 @@
 #pragma once
 
+// Core dialects and passes needed by triton-shared-opt
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/Transforms/Passes.h"
 #include "triton-shared/Conversion/StructuredToMemref/Passes.h"
+#include "triton-shared/Conversion/TPtrToLLVM/Passes.h"
 #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h"
 #include "triton-shared/Conversion/TritonPtrToMemref/Passes.h"
 #include "triton-shared/Conversion/TritonToLinalg/Passes.h"
@@ -38,6 +45,7 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry &registry) {
   mlir::triton::registerTritonArithToLinalgPasses();
   mlir::triton::registerStructuredToMemrefPasses();
   mlir::triton::registerAddLLVMDebugInfoPass();
+  mlir::tptr::registerTPtrToLLVM();
 
   // TODO: register Triton & TritonGPU passes
   registry.insert<
diff --git a/triton_shared.cc b/triton_shared.cc
index e69ed1db..91428dab 100644
--- a/triton_shared.cc
+++ b/triton_shared.cc
@@ -1,8 +1,167 @@
-#include
+// PyBind11
+#include
+#include
+#include
 
 namespace py = pybind11;
 
-// The CPU backend with triton_shared doesn't do compilation from within python
-// but rather externally through triton-shared-opt, so we leave this function
-// blank.
-void init_triton_triton_shared(py::module &&m) {}
+// LLVM
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/TargetSelect.h"
+
+// MLIR: Conversion Passes
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
+
+// MLIR: Dialects
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+
+// MLIR: Core IR and Passes
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+// MLIR: Target and Translation
+// #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+
+// LLVM: Debug
+#include "llvm/Support/Debug.h" // Declares llvm::DebugFlag and llvm::setCurrentDebugType
+
+// MLIR: Top-level Transforms
+#include "mlir/Transforms/Passes.h"
+
+// Triton and other third-party dialects
+#include "triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"
+#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
+
+#define ADD_PASS_WRAPPER_0(name, builder)                                     \
+  m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); })
+
+#define ADD_PASS_WRAPPER_1(name, builder, ty0)                                \
+  m.def(name,                                                                 \
+        [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); })
+
+#define ADD_PASS_WRAPPER_1_ARG(name, builder, ty0, arg0, val0)                \
+  m.def(                                                                      \
+      name,                                                                   \
+      [](mlir::PassManager &pm, ty0 arg0) { pm.addPass(builder(arg0)); },     \
+      py::arg("pm"), py::arg(#arg0) = val0);
+
+// Function to set MLIR/LLVM debug type
+void enable_mlir_debug(const std::string &debug_type) {
+  ::llvm::DebugFlag = true;
+  llvm::setCurrentDebugType(debug_type.c_str());
+}
+
+void init_to_llvm(py::module &&m) {
+  using namespace mlir;
+
+  ADD_PASS_WRAPPER_0("add_eliminate_empty_tensors",
+                     bufferization::createEmptyTensorEliminationPass);
+  ADD_PASS_WRAPPER_0("add_convert_linalg_to_affine_loops",
+                     createConvertLinalgToAffineLoopsPass);
+  ADD_PASS_WRAPPER_0("add_empty_tensor_to_alloc_tensor",
+                     bufferization::createEmptyTensorToAllocTensorPass);
+
+  ADD_PASS_WRAPPER_1_ARG(
+      "add_one_shot_bufferize_with_options",
+      [](bool allowReturnAllocsFromLoops) {
+        mlir::bufferization::OneShotBufferizePassOptions options;
+        options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
+        return mlir::bufferization::createOneShotBufferizePass(options);
+      },
+      bool, allow_return_allocs_from_loops, true);
+
+  ADD_PASS_WRAPPER_0("add_one_shot_bufferize",
+                     bufferization::createOneShotBufferizePass);
+  ADD_PASS_WRAPPER_0("add_lower_affine", createLowerAffinePass);
+  ADD_PASS_WRAPPER_0("add_convert_linalg_to_loops",
+                     createConvertLinalgToLoopsPass);
+  ADD_PASS_WRAPPER_0("add_expand_strided_metadata",
+                     memref::createExpandStridedMetadataPass);
+  ADD_PASS_WRAPPER_0("add_convert_scf_to_cf", createSCFToControlFlowPass);
+  ADD_PASS_WRAPPER_0("add_convert_arith_to_llvm",
+                     createArithToLLVMConversionPass);
+  ADD_PASS_WRAPPER_0("add_convert_math_to_llvm", createConvertMathToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_convert_complex_to_llvm",
+                     createConvertComplexToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_convert_vector_to_llvm",
+                     createConvertVectorToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_convert_index_to_llvm", createConvertIndexToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_memref_expand", memref::createExpandOpsPass);
+  ADD_PASS_WRAPPER_0("add_finalize_memref_to_llvm",
+                     createFinalizeMemRefToLLVMConversionPass);
+  ADD_PASS_WRAPPER_0("add_convert_func_to_llvm", createConvertFuncToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_convert_tptr_to_llvm", tptr::createTPtrToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_convert_cf_to_llvm",
+                     createConvertControlFlowToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_reconcile_unrealized_casts",
+                     createReconcileUnrealizedCastsPass);
+}
+
+void init_triton_shared_ir(py::module &&m) {
+  m.def("load_dialects", [](mlir::MLIRContext &context) {
+    mlir::DialectRegistry registry;
+
+    // Register core dialects
+    registry.insert<::mlir::triton::TritonDialect,
+                    ::mlir::linalg::LinalgDialect,
+                    ::mlir::bufferization::BufferizationDialect,
+                    ::mlir::tptr::TPtrDialect, ::mlir::math::MathDialect,
+                    ::mlir::memref::MemRefDialect, ::mlir::arith::ArithDialect,
+                    ::mlir::scf::SCFDialect, ::mlir::vector::VectorDialect,
+                    ::mlir::cf::ControlFlowDialect, ::mlir::LLVM::LLVMDialect,
+                    ::mlir::ub::UBDialect, ::mlir::func::FuncDialect>();
+
+    // Pull in the remaining upstream dialects as well
+    registerAllDialects(registry);
+    context.appendDialectRegistry(registry);
+    context.loadAllAvailableDialects();
+  });
+}
+
+void init_triton_shared_debug(py::module &&m) {
+  m.def("enable_mlir_debug", enable_mlir_debug,
+        "Enables a specific MLIR/LLVM debug type (e.g., 'pattern-rewrite'). "
+        "Pass an empty string to disable.",
+        py::arg("debug_type"));
+}
+
+void init_triton_triton_shared(py::module &&m) {
+  init_to_llvm(m.def_submodule("to_llir"));
+  init_triton_shared_ir(m.def_submodule("ir"));
+  init_triton_shared_debug(m.def_submodule("debug"));
+}
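
As a usage sketch, the new debug submodule can be driven directly from Python once the plugin is built. The snippet below assumes the plugin is importable from triton._C.libtriton, and "dialect-conversion" is only an example debug type, not something prescribed by this change:

    from triton._C.libtriton import triton_shared

    # Emit LLVM/MLIR debug output for a single debug type while kernels compile,
    # roughly equivalent to passing -debug-only=dialect-conversion to a command-line tool.
    triton_shared.debug.enable_mlir_debug("dialect-conversion")

    # Per the binding's docstring, an empty string narrows the filter again so that
    # named debug output is suppressed.
    triton_shared.debug.enable_mlir_debug("")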