Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(TritonToUnstructured)
add_subdirectory(StructuredToMemref)
add_subdirectory(TritonPtrToMemref)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonPtrToMemref)
add_public_tablegen_target(TritonPtrToMemrefConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonPtrToMemref/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H
#define TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H

#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
11 changes: 11 additions & 0 deletions include/triton-shared/Conversion/TritonPtrToMemref/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#ifndef TRITON_PTR_TO_MEMREF_CONVERSION_PASSES
#define TRITON_PTR_TO_MEMREF_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TritonPtrToMemref : Pass<"triton-ptr-to-memref", "mlir::ModuleOp"> {
let summary = "Convert triton pointer to unranked memref";
let constructor = "triton::createTritonPtrToMemrefPass()";
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
#define TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createTritonPtrToMemrefPass();

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ add_subdirectory(TritonToStructured)
add_subdirectory(TritonToUnstructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
add_subdirectory(TritonPtrToMemref)
18 changes: 18 additions & 0 deletions lib/Conversion/TritonPtrToMemref/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_triton_library(TritonPtrToMemref
TritonPtrToMemrefPass.cpp

DEPENDS
TritonPtrToMemrefConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRMathDialect
MLIRPass
MLIRTensorDialect
MLIRTransforms
MLIRSupport
MLIRReconcileUnrealizedCasts
TritonIR
)
124 changes: 124 additions & 0 deletions lib/Conversion/TritonPtrToMemref/TritonPtrToMemrefPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "triton-shared/Analysis/OpFoldResultUtils.h"
#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "triton/Dialect/Triton/IR/Types.h"

#define DEBUG_TYPE "triton-ptr-to-memref"

using namespace mlir;
using namespace triton;

#define GEN_PASS_CLASSES
#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h.inc"

namespace {

class TritonFunctionSignatureConverter : public TypeConverter {
public:
TritonFunctionSignatureConverter() {
// The order of type conversion is important: later ones are tried earlier.
addConversion([](Type type) { return type; });
addConversion([](triton::PointerType ptrType) {
return UnrankedMemRefType::get(ptrType.getPointeeType(),
/*memorySpace=*/0);
});
addConversion([](RankedTensorType tensorType) -> std::optional<Type> {
if (auto ptrType =
dyn_cast<triton::PointerType>(tensorType.getElementType())) {
return MemRefType::get(tensorType.getShape(), ptrType.getPointeeType());
}
return std::nullopt;
});

auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
addSourceMaterialization(createUnrealizedCast);
addArgumentMaterialization(createUnrealizedCast);
}
};

class TritonPtrToMemrefPass
: public TritonPtrToMemrefBase<TritonPtrToMemrefPass> {

public:
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, math::MathDialect, affine::AffineDialect,
scf::SCFDialect, tensor::TensorDialect, triton::TritonDialect,
tts::TritonStructuredDialect>();
}

void runOnOperation() override {
auto moduleOp = getOperation();

RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
TritonFunctionSignatureConverter typeConverter;

// Update function signature and call ops to use memrefs
target.addDynamicallyLegalOp<func::FuncOp, triton::FuncOp>([&](auto op) {
return typeConverter.isSignatureLegal(
cast<FunctionType>(cast<FunctionOpInterface>(op).getFunctionType()));
});

target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return typeConverter.isLegal(op.getResultTypes()) &&
typeConverter.isLegal(op.getOperandTypes());
});

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
patterns, typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
}

PassManager pm(&getContext(), moduleOp.getOperationName());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
if (failed(runPipeline(pm, getOperation()))) {
signalPassFailure();
}
}
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>> triton::createTritonPtrToMemrefPass() {
return std::make_unique<TritonPtrToMemrefPass>();
}
23 changes: 23 additions & 0 deletions test/Conversion/TritonPtrToMemref/call.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: triton-shared-opt --triton-arith-to-linalg --triton-ptr-to-memref %s | FileCheck %s

module {
tt.func @_sum_combine__fp32(%arg0: !tt.ptr<f32>) -> f32{
%0 = arith.constant 42.0 : f32
tt.return %0 : f32
}
tt.func @test(%arg0: !tt.ptr<f32>) -> f32{
%0 = tt.call @_sum_combine__fp32(%arg0) : (!tt.ptr<f32>) -> f32
tt.return %0 : f32
}
}

// CHECK: module {
// CHECK: func.func @_sum_combine__fp32(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
// CHECK: %cst = arith.constant 4.200000e+01 : f32
// CHECK: return %cst : f32
// CHECK: }
// CHECK: func.func @test(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
// CHECK: %0 = call @_sum_combine__fp32(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (memref<*xf32>, i32, i32, i32, i32, i32, i32) -> f32
// CHECK: return %0 : f32
// CHECK: }
// CHECK: }
27 changes: 27 additions & 0 deletions test/Conversion/TritonPtrToMemref/func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s

module {
func.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
%4 = arith.addi %3, %2 : tensor<1024xi32>
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
%13 = arith.addf %9, %12 : tensor<1024xf32>
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
return
}
}

// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32)
41 changes: 41 additions & 0 deletions test/Conversion/TritonPtrToMemref/post_structured_to_memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s

module {
func.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
%0 = builtin.unrealized_conversion_cast %arg2 : !tt.ptr<f32> to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
%2 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%c1024 = arith.constant 1024 : index
%c1024_i32 = arith.constant 1024 : i32
%3 = tt.get_program_id x : i32
%4 = arith.muli %3, %c1024_i32 : i32
%5 = arith.index_cast %4 : i32 to index
%reinterpret_cast = memref.reinterpret_cast %2 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
%6 = arith.addi %5, %c1024 : index
%7 = arith.index_cast %arg3 : i32 to index
%8 = arith.minsi %6, %7 : index
%9 = arith.subi %8, %5 : index
%alloc = memref.alloc() : memref<1024xf32>
%subview = memref.subview %reinterpret_cast[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
%subview_0 = memref.subview %alloc[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
memref.copy %subview, %subview_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
%10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32>
%reinterpret_cast_1 = memref.reinterpret_cast %1 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
%alloc_2 = memref.alloc() : memref<1024xf32>
%subview_3 = memref.subview %reinterpret_cast_1[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
%subview_4 = memref.subview %alloc_2[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
memref.copy %subview_3, %subview_4 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
%11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32>
%12 = arith.addf %10, %11 : tensor<1024xf32>
%reinterpret_cast_5 = memref.reinterpret_cast %0 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
%extracted_slice = tensor.extract_slice %12[0] [%9] [1] : tensor<1024xf32> to tensor<?xf32>
%subview_6 = memref.subview %reinterpret_cast_5[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
bufferization.materialize_in_destination %extracted_slice in writable %subview_6 : (tensor<?xf32>, memref<?xf32, strided<[1], offset: ?>>) -> ()
return
}
}

// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32)
// CHECK-NOT: builtin.unrealized_conversion_cast %arg2 : memref<*xf32> to !tt.ptr<f32>
// CHECK-NOT: builtin.unrealized_conversion_cast %arg1 : memref<*xf32> to !tt.ptr<f32>
// CHECK-NOT: builtin.unrealized_conversion_cast %arg0 : memref<*xf32> to !tt.ptr<f32>
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s

#map = affine_map<(d0) -> (d0)>
module {
func.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
%0 = builtin.unrealized_conversion_cast %arg2 : !tt.ptr<f32> to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
%2 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%c1024_i32 = arith.constant 1024 : i32
%3 = tt.get_program_id x : i32
%4 = arith.muli %3, %c1024_i32 : i32
%5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
%6 = tt.splat %4 : i32 -> tensor<1024xi32>
%7 = arith.addi %6, %5 : tensor<1024xi32>
%8 = tt.splat %arg3 : i32 -> tensor<1024xi32>
%9 = arith.cmpi slt, %7, %8 : tensor<1024xi32>
%cast = memref.cast %2 : memref<*xf32> to memref<?xf32>
%10 = bufferization.to_tensor %cast restrict : memref<?xf32>
%11 = tensor.empty() : tensor<1024xf32>
%12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%11 : tensor<1024xf32>) {
^bb0(%in: i32, %in_2: i1, %out: f32):
%17 = scf.if %in_2 -> (f32) {
%18 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %10[%18] : tensor<?xf32>
scf.yield %extracted : f32
} else {
%cst = arith.constant 0.000000e+00 : f32
scf.yield %cst : f32
}
linalg.yield %17 : f32
} -> tensor<1024xf32>
%cast_0 = memref.cast %1 : memref<*xf32> to memref<?xf32>
%13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32>
%14 = tensor.empty() : tensor<1024xf32>
%15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%14 : tensor<1024xf32>) {
^bb0(%in: i32, %in_2: i1, %out: f32):
%17 = scf.if %in_2 -> (f32) {
%18 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %13[%18] : tensor<?xf32>
scf.yield %extracted : f32
} else {
%cst = arith.constant 0.000000e+00 : f32
scf.yield %cst : f32
}
linalg.yield %17 : f32
} -> tensor<1024xf32>
%16 = arith.addf %12, %15 : tensor<1024xf32>
%cast_1 = memref.cast %0 : memref<*xf32> to memref<?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c1024 = arith.constant 1024 : index
affine.for %arg4 = 0 to 1024 {
%extracted = tensor.extract %9[%arg4] : tensor<1024xi1>
scf.if %extracted {
%extracted_2 = tensor.extract %16[%arg4] : tensor<1024xf32>
%extracted_3 = tensor.extract %7[%arg4] : tensor<1024xi32>
%17 = arith.index_cast %extracted_3 : i32 to index
memref.store %extracted_2, %cast_1[%17] : memref<?xf32>
}
}
return
}
}

// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32)
// CHECK-NOT: builtin.unrealized_conversion_cast %arg2 : memref<*xf32> to !tt.ptr<f32>
// CHECK-NOT: builtin.unrealized_conversion_cast %arg1 : memref<*xf32> to !tt.ptr<f32>
// CHECK-NOT: builtin.unrealized_conversion_cast %arg0 : memref<*xf32> to !tt.ptr<f32>
28 changes: 28 additions & 0 deletions test/Conversion/TritonPtrToMemref/structured_ptr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s

module {
func.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
%c1024 = arith.constant 1024 : index
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = arith.index_cast %1 : i32 to index
%3 = tts.make_tptr %arg0 to sizes: [1024], strides: [1], offsets: [%2], shape: [0], order: [] : <f32> to tensor<1024x!tt.ptr<f32>>
%4 = arith.addi %2, %c1024 : index
%5 = arith.index_cast %arg3 : i32 to index
%6 = arith.minsi %4, %5 : index
%7 = arith.subi %6, %2 : index
%8 = "tts.load"(%3, %7) <{operandSegmentSizes = array<i32: 1, 1, 0>, static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<1024x!tt.ptr<f32>>, index) -> tensor<1024xf32>
%9 = tts.make_tptr %arg1 to sizes: [1024], strides: [1], offsets: [%2], shape: [0], order: [] : <f32> to tensor<1024x!tt.ptr<f32>>
%10 = "tts.load"(%9, %7) <{operandSegmentSizes = array<i32: 1, 1, 0>, static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<1024x!tt.ptr<f32>>, index) -> tensor<1024xf32>
%11 = arith.addf %8, %10 : tensor<1024xf32>
%12 = tts.make_tptr %arg2 to sizes: [1024], strides: [1], offsets: [%2], shape: [0], order: [] : <f32> to tensor<1024x!tt.ptr<f32>>
"tts.store"(%12, %11, %7) <{static_mask_dims = array<i64: -9223372036854775808>}> : (tensor<1024x!tt.ptr<f32>>, tensor<1024xf32>, index) -> ()
return
}
}

// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32)
// CHECK-DAG: builtin.unrealized_conversion_cast %arg2 : memref<*xf32> to !tt.ptr<f32>
// CHECK-DAG: builtin.unrealized_conversion_cast %arg1 : memref<*xf32> to !tt.ptr<f32>
// CHECK-DAG: builtin.unrealized_conversion_cast %arg0 : memref<*xf32> to !tt.ptr<f32>
Loading
Loading