diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt
index dfa72232..14b72b1a 100644
--- a/include/triton-shared/Conversion/CMakeLists.txt
+++ b/include/triton-shared/Conversion/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(TritonToStructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(TritonToUnstructured)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(TritonPtrToMemref)
diff --git a/include/triton-shared/Conversion/TritonPtrToMemref/CMakeLists.txt b/include/triton-shared/Conversion/TritonPtrToMemref/CMakeLists.txt
new file mode 100644
index 00000000..07f9ad33
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonPtrToMemref/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonPtrToMemref)
+add_public_tablegen_target(TritonPtrToMemrefConversionPassIncGen)
diff --git a/include/triton-shared/Conversion/TritonPtrToMemref/Passes.h b/include/triton-shared/Conversion/TritonPtrToMemref/Passes.h
new file mode 100644
index 00000000..e1f6f33b
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonPtrToMemref/Passes.h
@@ -0,0 +1,15 @@
+#ifndef TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H
+#define TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H
+
+#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
+
+namespace mlir {
+namespace triton {
+
+#define GEN_PASS_REGISTRATION
+#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h.inc"
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H
diff --git a/include/triton-shared/Conversion/TritonPtrToMemref/Passes.td b/include/triton-shared/Conversion/TritonPtrToMemref/Passes.td
new file mode 100644
index 00000000..c027b098
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonPtrToMemref/Passes.td
@@ -0,0 +1,11 @@
+#ifndef TRITON_PTR_TO_MEMREF_CONVERSION_PASSES
+#define TRITON_PTR_TO_MEMREF_CONVERSION_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def TritonPtrToMemref : Pass<"triton-ptr-to-memref", "mlir::ModuleOp"> {
+  let summary = "Convert triton pointer to unranked memref";
+  let constructor = "triton::createTritonPtrToMemrefPass()";
+}
+
+#endif // TRITON_PTR_TO_MEMREF_CONVERSION_PASSES
diff --git a/include/triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h b/include/triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h
new file mode 100644
index 00000000..4476f7d6
--- /dev/null
+++ b/include/triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h
@@ -0,0 +1,17 @@
+#ifndef TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
+#define TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir {
+namespace triton {
+
+std::unique_ptr> createTritonPtrToMemrefPass();
+
+} // namespace triton
+} // namespace mlir
+
+#endif // TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 2e4d1f2f..972beace 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(TritonToStructured)
 add_subdirectory(TritonToUnstructured)
 add_subdirectory(TritonArithToLinalg)
 add_subdirectory(StructuredToMemref)
+add_subdirectory(TritonPtrToMemref)
diff --git a/lib/Conversion/TritonPtrToMemref/CMakeLists.txt b/lib/Conversion/TritonPtrToMemref/CMakeLists.txt
new file mode 100644
index 00000000..eeef2832
--- /dev/null
+++ b/lib/Conversion/TritonPtrToMemref/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_triton_library(TritonPtrToMemref
+  TritonPtrToMemrefPass.cpp
+
+  DEPENDS
+  TritonPtrToMemrefConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRDialectUtils
+  MLIRIR
+  MLIRMathDialect
+  MLIRPass
+  MLIRTensorDialect
+  MLIRTransforms
+  MLIRSupport
+  MLIRReconcileUnrealizedCasts
+  TritonIR
+)
diff --git 
a/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp b/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp
new file mode 100644
index 00000000..bb4582cb
--- /dev/null
+++ b/lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp
@@ -0,0 +1,134 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Passes.h"
+#include "triton-shared/Analysis/OpFoldResultUtils.h"
+#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
+#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
+#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
+
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "triton/Dialect/Triton/IR/Types.h"
+
+#define DEBUG_TYPE "triton-ptr-to-memref"
+
+using namespace mlir;
+using namespace triton;
+
+#define GEN_PASS_CLASSES
+#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h.inc"
+
+namespace {
+
+// Type converter for function signatures and call sites. Triton pointers
+// become unranked memrefs of the pointee type, ranked tensors of pointers
+// become memrefs of the same shape, and every other type is kept as is.
+class TritonFunctionSignatureConverter : public TypeConverter {
+public:
+  TritonFunctionSignatureConverter() {
+    // The order of type conversion is important: later ones are tried earlier.
+    addConversion([](Type type) { return type; });
+    addConversion([](triton::PointerType ptrType) {
+      return UnrankedMemRefType::get(ptrType.getPointeeType(),
+                                     /*memorySpace=*/0);
+    });
+    addConversion([](RankedTensorType tensorType) -> std::optional {
+      if (auto ptrType =
+              dyn_cast(tensorType.getElementType())) {
+        return MemRefType::get(tensorType.getShape(), ptrType.getPointeeType());
+      }
+      return std::nullopt;
+    });
+
+    // Source and argument materializations share one cast-building callback.
+    auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType,
+                                    ValueRange inputs,
+                                    Location loc) -> std::optional {
+      return builder.create(loc, resultType, inputs)
+          .getResult(0);
+    };
+    addSourceMaterialization(createUnrealizedCast);
+    addArgumentMaterialization(createUnrealizedCast);
+  }
+};
+
+// Module pass that converts func and call op types with
+// TritonFunctionSignatureConverter, then runs canonicalization and CSE.
+class TritonPtrToMemrefPass
+    : public TritonPtrToMemrefBase {
+
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+    TritonFunctionSignatureConverter typeConverter;
+
+    // Update function signature and call ops to use memrefs
+    target.addDynamicallyLegalOp([&](auto op) {
+      return typeConverter.isSignatureLegal(
+          cast(cast(op).getFunctionType()));
+    });
+
+    target.addDynamicallyLegalOp([&](func::CallOp op) {
+      return typeConverter.isLegal(op.getResultTypes()) &&
+             typeConverter.isLegal(op.getOperandTypes());
+    });
+
+    populateFunctionOpInterfaceTypeConversionPattern(
+        patterns, typeConverter);
+    populateFunctionOpInterfaceTypeConversionPattern(
+        patterns, typeConverter);
+    populateCallOpTypeConversionPattern(patterns, typeConverter);
+
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+      signalPassFailure();
+      // Skip the cleanup pipeline when the conversion itself failed.
+      return;
+    }
+
+    // Tidy up the converted module.
+    PassManager pm(&getContext(), moduleOp.getOperationName());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+    if (failed(runPipeline(pm, getOperation()))) {
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+// Factory referenced as the pass constructor in Passes.td.
+std::unique_ptr> triton::createTritonPtrToMemrefPass() {
+  return std::make_unique();
+}
diff --git a/test/Conversion/TritonPtrToMemref/call.mlir b/test/Conversion/TritonPtrToMemref/call.mlir
new file mode 100644
index 00000000..e19413b6
--- /dev/null
+++ b/test/Conversion/TritonPtrToMemref/call.mlir
@@ -0,0 +1,23 @@
+// RUN: triton-shared-opt --triton-arith-to-linalg --triton-ptr-to-memref %s | FileCheck %s
+
+module {
+  tt.func @_sum_combine__fp32(%arg0: !tt.ptr) -> f32{
+    %0 = arith.constant 42.0 : f32
+    tt.return %0 : f32
+  }
+  tt.func @test(%arg0: !tt.ptr) -> f32{
+    %0 = tt.call @_sum_combine__fp32(%arg0) : (!tt.ptr) -> f32
+    tt.return %0 : f32
+  }
+}
+
+// CHECK: module {
+// CHECK: func.func @_sum_combine__fp32(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
+// CHECK: %cst = arith.constant 4.200000e+01 : f32
+// CHECK: return %cst : f32
+// CHECK: }
+// CHECK: func.func @test(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
+// CHECK: %0 = call @_sum_combine__fp32(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (memref<*xf32>, i32, i32, i32, i32, i32, i32) -> f32
+// CHECK: return %0 : f32
+// CHECK: }
+// CHECK: }
diff --git a/test/Conversion/TritonPtrToMemref/func.mlir b/test/Conversion/TritonPtrToMemref/func.mlir
new file mode 100644
index 00000000..55753894
--- /dev/null
+++ b/test/Conversion/TritonPtrToMemref/func.mlir
@@ -0,0 +1,27 @@
+// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s
+
+module {
+  func.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) {
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+    %3 = tt.splat %1 : i32 -> tensor<1024xi32>
+    %4 = arith.addi %3, %2 : tensor<1024xi32>
+    %5 = 
tt.splat %arg3 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr> + %13 = arith.addf %9, %12 : tensor<1024xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr> + return + } +} + +// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32) diff --git a/test/Conversion/TritonPtrToMemref/post_structured_to_memref.mlir b/test/Conversion/TritonPtrToMemref/post_structured_to_memref.mlir new file mode 100644 index 00000000..53ca403d --- /dev/null +++ b/test/Conversion/TritonPtrToMemref/post_structured_to_memref.mlir @@ -0,0 +1,41 @@ +// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s + +module { + func.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %0 = builtin.unrealized_conversion_cast %arg2 : !tt.ptr to memref<*xf32> + %1 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> + %2 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> + %c1024 = arith.constant 1024 : index + %c1024_i32 = arith.constant 1024 : i32 + %3 = tt.get_program_id x : i32 + %4 = arith.muli %3, %c1024_i32 : i32 + %5 = arith.index_cast %4 : i32 to index + %reinterpret_cast = memref.reinterpret_cast %2 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> + %6 = arith.addi %5, %c1024 : index + %7 = arith.index_cast %arg3 : i32 to index + %8 = arith.minsi %6, %7 : index + %9 = arith.subi %8, %5 : index + %alloc = 
memref.alloc() : memref<1024xf32> + %subview = memref.subview %reinterpret_cast[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + %subview_0 = memref.subview %alloc[0] [%9] [1] : memref<1024xf32> to memref> + memref.copy %subview, %subview_0 : memref> to memref> + %10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> + %reinterpret_cast_1 = memref.reinterpret_cast %1 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> + %alloc_2 = memref.alloc() : memref<1024xf32> + %subview_3 = memref.subview %reinterpret_cast_1[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + %subview_4 = memref.subview %alloc_2[0] [%9] [1] : memref<1024xf32> to memref> + memref.copy %subview_3, %subview_4 : memref> to memref> + %11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32> + %12 = arith.addf %10, %11 : tensor<1024xf32> + %reinterpret_cast_5 = memref.reinterpret_cast %0 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> + %extracted_slice = tensor.extract_slice %12[0] [%9] [1] : tensor<1024xf32> to tensor + %subview_6 = memref.subview %reinterpret_cast_5[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref> + bufferization.materialize_in_destination %extracted_slice in writable %subview_6 : (tensor, memref>) -> () + return + } +} + +// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32) +// CHECK-NOT: builtin.unrealized_conversion_cast %arg2 : memref<*xf32> to !tt.ptr +// CHECK-NOT: builtin.unrealized_conversion_cast %arg1 : memref<*xf32> to !tt.ptr +// CHECK-NOT: builtin.unrealized_conversion_cast %arg0 : memref<*xf32> to !tt.ptr diff --git a/test/Conversion/TritonPtrToMemref/post_triton_load_store_to_memref.mlir b/test/Conversion/TritonPtrToMemref/post_triton_load_store_to_memref.mlir new file mode 100644 index 
00000000..8fd7383e --- /dev/null +++ b/test/Conversion/TritonPtrToMemref/post_triton_load_store_to_memref.mlir @@ -0,0 +1,68 @@ +// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %0 = builtin.unrealized_conversion_cast %arg2 : !tt.ptr to memref<*xf32> + %1 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> + %2 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> + %c1024_i32 = arith.constant 1024 : i32 + %3 = tt.get_program_id x : i32 + %4 = arith.muli %3, %c1024_i32 : i32 + %5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %6 = tt.splat %4 : i32 -> tensor<1024xi32> + %7 = arith.addi %6, %5 : tensor<1024xi32> + %8 = tt.splat %arg3 : i32 -> tensor<1024xi32> + %9 = arith.cmpi slt, %7, %8 : tensor<1024xi32> + %cast = memref.cast %2 : memref<*xf32> to memref + %10 = bufferization.to_tensor %cast restrict : memref + %11 = tensor.empty() : tensor<1024xf32> + %12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%11 : tensor<1024xf32>) { + ^bb0(%in: i32, %in_2: i1, %out: f32): + %17 = scf.if %in_2 -> (f32) { + %18 = arith.index_cast %in : i32 to index + %extracted = tensor.extract %10[%18] : tensor + scf.yield %extracted : f32 + } else { + %cst = arith.constant 0.000000e+00 : f32 + scf.yield %cst : f32 + } + linalg.yield %17 : f32 + } -> tensor<1024xf32> + %cast_0 = memref.cast %1 : memref<*xf32> to memref + %13 = bufferization.to_tensor %cast_0 restrict : memref + %14 = tensor.empty() : tensor<1024xf32> + %15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%14 : tensor<1024xf32>) { + ^bb0(%in: i32, %in_2: i1, %out: f32): + %17 = scf.if %in_2 -> (f32) { + %18 = 
arith.index_cast %in : i32 to index + %extracted = tensor.extract %13[%18] : tensor + scf.yield %extracted : f32 + } else { + %cst = arith.constant 0.000000e+00 : f32 + scf.yield %cst : f32 + } + linalg.yield %17 : f32 + } -> tensor<1024xf32> + %16 = arith.addf %12, %15 : tensor<1024xf32> + %cast_1 = memref.cast %0 : memref<*xf32> to memref + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + affine.for %arg4 = 0 to 1024 { + %extracted = tensor.extract %9[%arg4] : tensor<1024xi1> + scf.if %extracted { + %extracted_2 = tensor.extract %16[%arg4] : tensor<1024xf32> + %extracted_3 = tensor.extract %7[%arg4] : tensor<1024xi32> + %17 = arith.index_cast %extracted_3 : i32 to index + memref.store %extracted_2, %cast_1[%17] : memref + } + } + return + } +} + +// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32) +// CHECK-NOT: builtin.unrealized_conversion_cast %arg2 : memref<*xf32> to !tt.ptr +// CHECK-NOT: builtin.unrealized_conversion_cast %arg1 : memref<*xf32> to !tt.ptr +// CHECK-NOT: builtin.unrealized_conversion_cast %arg0 : memref<*xf32> to !tt.ptr diff --git a/test/Conversion/TritonPtrToMemref/structured_ptr.mlir b/test/Conversion/TritonPtrToMemref/structured_ptr.mlir new file mode 100644 index 00000000..c040bfa2 --- /dev/null +++ b/test/Conversion/TritonPtrToMemref/structured_ptr.mlir @@ -0,0 +1,28 @@ +// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s + +module { + func.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c1024 = arith.constant 1024 : index + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = arith.index_cast %1 : i32 to index + %3 = tts.make_tptr %arg0 to sizes: [1024], strides: [1], offsets: [%2], shape: [0], order: [] : to tensor<1024x!tt.ptr> + %4 = arith.addi %2, %c1024 : index + %5 = 
arith.index_cast %arg3 : i32 to index + %6 = arith.minsi %4, %5 : index + %7 = arith.subi %6, %2 : index + %8 = "tts.load"(%3, %7) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32> + %9 = tts.make_tptr %arg1 to sizes: [1024], strides: [1], offsets: [%2], shape: [0], order: [] : to tensor<1024x!tt.ptr> + %10 = "tts.load"(%9, %7) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<1024x!tt.ptr>, index) -> tensor<1024xf32> + %11 = arith.addf %8, %10 : tensor<1024xf32> + %12 = tts.make_tptr %arg2 to sizes: [1024], strides: [1], offsets: [%2], shape: [0], order: [] : to tensor<1024x!tt.ptr> + "tts.store"(%12, %11, %7) <{static_mask_dims = array}> : (tensor<1024x!tt.ptr>, tensor<1024xf32>, index) -> () + return + } +} + +// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32) +// CHECK-DAG: builtin.unrealized_conversion_cast %arg2 : memref<*xf32> to !tt.ptr +// CHECK-DAG: builtin.unrealized_conversion_cast %arg1 : memref<*xf32> to !tt.ptr +// CHECK-DAG: builtin.unrealized_conversion_cast %arg0 : memref<*xf32> to !tt.ptr diff --git a/test/Conversion/TritonPtrToMemref/triton_func.mlir b/test/Conversion/TritonPtrToMemref/triton_func.mlir new file mode 100644 index 00000000..bfeda3ff --- /dev/null +++ b/test/Conversion/TritonPtrToMemref/triton_func.mlir @@ -0,0 +1,27 @@ +// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s + +module { + tt.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> + %7 = tt.splat %arg0 : 
!tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr> + %13 = arith.addf %9, %12 : tensor<1024xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr> + tt.return + } +} + +// CHECK: tt.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32) diff --git a/tools/RegisterTritonSharedDialects.h b/tools/RegisterTritonSharedDialects.h index 1d58f6a9..f71ea8ad 100644 --- a/tools/RegisterTritonSharedDialects.h +++ b/tools/RegisterTritonSharedDialects.h @@ -16,6 +16,7 @@ #include "triton-shared/Conversion/StructuredToMemref/Passes.h" #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" +#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h" #include "triton-shared/Conversion/TritonToLinalg/Passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h" #include "triton-shared/Conversion/TritonToStructured/Passes.h" @@ -46,6 +47,7 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry &registry) { mlir::triton::registerTritonToLinalgPass(); mlir::triton::registerTritonToLinalgExperimentalPass(); mlir::triton::registerTritonToStructuredPass(); + mlir::triton::registerTritonPtrToMemref(); mlir::triton::registerTritonToUnstructuredPasses(); mlir::triton::registerTritonArithToLinalgPasses(); mlir::triton::registerConvertTritonToTritonGPUPass();