Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 14 additions & 76 deletions lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "triton/Dialect/Triton/IR/Dialect.h"

#include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"

Expand All @@ -24,12 +25,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#include <optional>

#define DEBUG_TYPE "structured-to-memref"

Expand All @@ -45,67 +41,6 @@ namespace triton {

namespace {

class LoopTypeConverter : public TypeConverter {
public:
LoopTypeConverter(MLIRContext *context) {
// The order of type conversion is important: later ones are tried earlier.
addConversion([](Type type) { return type; });
// addConversion([context](triton::PointerType ptrType) {
// SmallVector<int64_t> strides{1};
// auto layout =
// StridedLayoutAttr::get(context, ShapedType::kDynamic, strides);

// auto elemType = ptrType.getPointeeType();
// auto memrefType = MemRefType::get({1}, elemType, layout);
// return memrefType;
// });

// A tensor of pointers can be passed in as scf.for's init-args, in such
// cases, we convert the type to a memref with dynamic offsets and
// strides.
addConversion(
[context](RankedTensorType tensorType) -> std::optional<MemRefType> {
if (auto ptrType = llvm::dyn_cast<triton::PointerType>(
tensorType.getElementType())) {
auto layout = StridedLayoutAttr::get(
context, ShapedType::kDynamic,
SmallVector<int64_t>(tensorType.getRank(),
ShapedType::kDynamic));
Type elemType = ptrType.getPointeeType();
return MemRefType::get(tensorType.getShape(), elemType, layout);
}

return std::nullopt;
});

// Convert the current memref type to a memref type with dynamic offsets and
// strides through another reinterpret_cast with the same offsets.
// Canonicalization will simplify this sequence by removing the inital
// reinterpret_cast.
addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
ValueRange inputs,
Location loc) -> Value {
auto reinterpretCast =
inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
if (!reinterpretCast) {
return builder
.create<UnrealizedConversionCastOp>(loc, memrefType, inputs)
.getResult(0);
}
return builder.create<memref::ReinterpretCastOp>(
loc, memrefType, inputs[0], reinterpretCast.getMixedOffsets()[0],
reinterpretCast.getMixedSizes(), reinterpretCast.getMixedStrides());
});

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
}
};

class PtrToUnrankedMemrefConverter : public TypeConverter {
public:
PtrToUnrankedMemrefConverter() {
Expand All @@ -120,6 +55,13 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
}
};

Expand All @@ -129,11 +71,12 @@ class StructuredToMemrefPass

public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, arith::ArithDialect, math::MathDialect,
linalg::LinalgDialect, affine::AffineDialect,
scf::SCFDialect, tensor::TensorDialect,
bufferization::BufferizationDialect, triton::TritonDialect,
ttx::TritonTilingExtDialect, memref::MemRefDialect>();
registry
.insert<tptr::TPtrDialect, func::FuncDialect, arith::ArithDialect,
math::MathDialect, linalg::LinalgDialect, affine::AffineDialect,
scf::SCFDialect, tensor::TensorDialect,
bufferization::BufferizationDialect, triton::TritonDialect,
ttx::TritonTilingExtDialect, memref::MemRefDialect>();
}

void runOnOperation() override {
Expand All @@ -158,11 +101,6 @@ class StructuredToMemrefPass
triton::populateStructuredToMemrefConversionPatterns(patterns,
typeConverter);

LoopTypeConverter loopTypeConverter(patterns.getContext());

mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
loopTypeConverter, patterns, target);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,17 @@ class TritonToLinalgExperimentalPass
pm.addPass(createTritonToUnstructuredPass());
pm.addPass(createTritonArithToLinalgPass(/*tensorPtrToLinalg=*/true));

// TODO: structured-to-memref converts the loop iter-args to memref, while
// triton-to-ptr converts the loop iter-args to ptr. These two passes might
// end up conflicting with each other in cases where we have a mixed of
// structured and unstructured accesses. Fortunately, the structured ops do
// not need to use the loop iter-args at all (see code in PtrAnalysis.cpp),
// so if we run remove-dead-values after structured-to-memref, the memref
// iter-args that are used in structured loads and stores should be removed.
// Running this now may be too invasive and cause many IR changes, so
// leave as a TODO for now.
pm.addPass(createStructuredToMemrefPass());
pm.addPass(createUnstructuredToMemrefPass());
pm.addPass(createTritonPtrToMemrefPass());
pm.addPass(createTritonToPtrPass());
pm.addPass(createReconcileUnrealizedCastsPass());
pm.addPass(createReconcilePtrCastsPass());

// Now that remove-dead-values fully works with linalg ops, clean up the IR
// again, particularly unused loop iter-args that were created
// during triton-to-structured.
pm.addPass(createRemoveDeadValuesPass());
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());

Expand Down
10 changes: 4 additions & 6 deletions test/Conversion/StructuredToMemref/addptr_scalar_for.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
module {
tt.func @kernel (%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) {
%0 = tt.get_program_id x : i32
Expand Down Expand Up @@ -47,7 +47,7 @@ module {
// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<1024xf32>) -> tensor<1024xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32
// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_2_]], [[VAR_arg13_:%.+]] = [[VAR_3_]], [[VAR_arg14_:%.+]] = [[VAR_1_]]) -> (i32, index, tensor<1024xf32>) {
// CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_3_]], [[VAR_arg14_:%.+]] = [[VAR_1_]]) -> (index, tensor<1024xf32>) {
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32>
Expand All @@ -62,14 +62,12 @@ module {
// CHECK: [[VAR_13_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
// CHECK: linalg.yield [[VAR_13_1_]] : f32
// CHECK: } -> tensor<1024xf32>
// CHECK-DAG: [[VAR_10_:%.+]] = arith.index_cast [[VAR_arg11_]] : index to i32
// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg13_]], [[VAR_arg11_]] : index
// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_arg12_]], [[VAR_10_]] : i32
// CHECK: scf.yield [[VAR_12_]], [[VAR_11_]], [[VAR_9_]] : i32, index, tensor<1024xf32>
// CHECK: scf.yield [[VAR_11_]], [[VAR_9_]] : index, tensor<1024xf32>
// CHECK: }
// CHECK: [[VAR_5_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : i32 to index
// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_6_]]{{.}}, sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#2 in writable [[VAR_reinterpret_cast_]] : (tensor<1024xf32>, memref<1024xf32, strided<[1], offset: ?>>) -> ()
// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#1 in writable [[VAR_reinterpret_cast_]] : (tensor<1024xf32>, memref<1024xf32, strided<[1], offset: ?>>) -> ()
// CHECK: return
// CHECK: }
8 changes: 3 additions & 5 deletions test/Conversion/StructuredToMemref/addptr_scalar_for_2d.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
module {
tt.func @kernel (%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) {
%0 = tt.get_program_id x : i32
Expand Down Expand Up @@ -67,7 +67,7 @@ module {
// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<128x128xf32>) -> tensor<128x128xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32
// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg13_:%.+]] = [[VAR_2_]], [[VAR_arg14_:%.+]] = [[VAR_3_]]) -> (tensor<128x128xf32>, i32, index) {
// CHECK-DAG: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg14_:%.+]] = [[VAR_3_]]) -> (tensor<128x128xf32>, index) {
// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg14_]], [[CST_128_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>>
Expand All @@ -84,10 +84,8 @@ module {
// CHECK: [[VAR_15_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
// CHECK: linalg.yield [[VAR_15_1_]] : f32
// CHECK: } -> tensor<128x128xf32>
// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[VAR_arg11_]] : index to i32
// CHECK-DAG: [[VAR_13_:%.+]] = arith.addi [[VAR_arg14_]], [[VAR_arg11_]] : index
// CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_arg13_]], [[VAR_12_]] : i32
// CHECK: scf.yield [[VAR_11_]], [[VAR_14_]], [[VAR_13_]] : tensor<128x128xf32>, i32, index
// CHECK: scf.yield [[VAR_11_]], [[VAR_13_]] : tensor<128x128xf32>, index
// CHECK: }
// CHECK: [[VAR_5_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : i32 to index
Expand Down
4 changes: 1 addition & 3 deletions test/Conversion/StructuredToMemref/nested_loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,4 @@ module {
// CHECK-NOT: tt.addptr
// CHECK-NOT: tt.load
// CHECK-NOT: tt.store

// CHECK-COUNT-20: memref.reinterpret_cast %arg{{[0-9]+}}
// CHECK-NOT: memref.reinterpret_cast %arg{{[0-9]+}}
// CHECK-NOT: unrealized_conversion_cast
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,4 @@ module {
// CHECK-NOT: tt.addptr
// CHECK-NOT: tt.load
// CHECK-NOT: tt.store

// CHECK-COUNT-43: memref.reinterpret_cast %arg{{[0-9]+}}
// CHECK-NOT: memref.reinterpret_cast %arg{{[0-9]+}}
// CHECK-NOT: unrealized_conversion_cast
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,14 @@ module {
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i32
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
// CHECK-DAG: [[VAR_1_:%.+]] = arith.sitofp [[PARAM_7_]] : i32 to f32
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_2_:%.+]]:3 = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_5_]] step [[CST_1_]] iter_args([[VAR_arg17_:%.+]] = [[PARAM_7_]], [[VAR_arg18_:%.+]] = [[VAR_0_]], [[VAR_arg19_:%.+]] = [[VAR_0_]]) -> (i32, index, index) : i32 {
// CHECK-DAG: [[VAR_2_:%.+]] = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_5_]] step [[CST_1_]] iter_args([[VAR_arg17_:%.+]] = [[PARAM_7_]]) -> (i32) : i32 {
// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_arg17_]] : i32 to index
// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
// CHECK: memref.store [[VAR_1_]], [[VAR_reinterpret_cast_]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_arg16_]] : i32 to index
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_4_]] : index
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_arg16_]] : i32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_arg19_]], [[VAR_4_]] : index
// CHECK: scf.yield [[VAR_6_]], [[VAR_5_]], [[VAR_7_]] : i32, index, index
// CHECK: scf.yield [[VAR_6_]] : i32
// CHECK: }
// CHECK: return
// CHECK: }
14 changes: 5 additions & 9 deletions test/Conversion/StructuredToMemref/scalar_store_nested_loop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,22 @@ module {
// CHECK-LABEL: func.func @reduce_kernel_2d_0d
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : i32
// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : i32
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]]:2 = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_8_]] step [[CST_1_1_]] iter_args([[VAR_arg14_:%.+]] = [[PARAM_4_]], [[VAR_arg15_:%.+]] = [[VAR_0_]]) -> (i32, index) : i32 {
// CHECK-DAG: [[VAR_2_:%.+]]:2 = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg17_:%.+]] = [[VAR_arg14_]], [[VAR_arg18_:%.+]] = [[VAR_arg15_]]) -> (i32, index) : i32 {
// CHECK-DAG: [[VAR_1_:%.+]] = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_8_]] step [[CST_1_1_]] iter_args([[VAR_arg14_:%.+]] = [[PARAM_4_]]) -> (i32) : i32 {
// CHECK-DAG: [[VAR_2_:%.+]] = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg17_:%.+]] = [[VAR_arg14_]]) -> (i32) : i32 {
// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_arg13_]], [[VAR_arg16_]] : i32
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = arith.sitofp [[VAR_3_]] : i32 to f32
// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg17_]] : i32 to index
// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
// CHECK: memref.store [[VAR_4_]], [[VAR_reinterpret_cast_]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_arg18_]], [[CST_1_]] : index
// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_arg17_]], [[CST_1_1_]] : i32
// CHECK: scf.yield [[VAR_7_]], [[VAR_6_]] : i32, index
// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg17_]], [[CST_1_1_]] : i32
// CHECK: scf.yield [[VAR_7_]] : i32
// CHECK: }
// CHECK: scf.yield [[VAR_2_]]#0, [[VAR_2_]]#1 : i32, index
// CHECK: scf.yield [[VAR_2_]] : i32
// CHECK: }
// CHECK: return
// CHECK: }
Loading