Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);

/// Return true if `ofr` is constant integer equal to `value`.
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
/// Return true if every element of `ofrs` is a constant integer equal to
/// `value`. An empty `ofrs` yields true.
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
/// Return true if all of `ofrs` are constant integers equal to the
/// corresponding value in `values`. Return false if `ofrs` and `values` do not
/// have the same number of elements.
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
ArrayRef<int64_t> values);

/// Return true if ofr1 and ofr2 are the same integer constant attribute
/// values or the same SSA value. Ignore integer bitwidth and type mismatch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

Expand Down Expand Up @@ -636,6 +637,28 @@ struct InsertOpInterface
}
};

template <typename InsertOpTy>
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
OpOperand &opOperand) {
// The source is always read.
if (opOperand == insertSliceOp.getSourceMutable())
return true;

// For the destination, it depends...
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
bool allOffsetsZero =
llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
RankedTensorType destType = insertSliceOp.getDestType();
bool sizesMatchDestSizes =
areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
bool allStridesOne =
areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
}

/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
/// certain circumstances, this op can also be a no-op.
///
Expand All @@ -646,32 +669,8 @@ struct InsertSliceOpInterface
tensor::InsertSliceOp> {
/// Return true if `opOperand` is read by the tensor.insert_slice op `op`.
/// The source is always read; the destination is read unless the insertion
/// overwrites it entirely (all offsets 0, sizes equal to the destination
/// sizes, all strides 1).
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
RankedTensorType destType = insertSliceOp.getDestType();

// The source is always read.
if (opOperand == insertSliceOp.getSourceMutable())
return true;

// For the destination, it depends...
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");

// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
bool allOffsetsZero =
llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
return isConstantIntValue(ofr, 0);
});
bool sizesMatchDestSizes = llvm::all_of(
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
return getConstantIntValue(it.value()) ==
destType.getDimSize(it.index());
});
bool allStridesOne =
llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return isConstantIntValue(ofr, 1);
});
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
// NOTE(review): this second return follows an unconditional return. It looks
// like the post-change form that delegates to insertSliceOpRequiresRead and
// replaces the inline logic above (both are visible in this diff view) --
// confirm against the merged file.
return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
opOperand);
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
Expand Down Expand Up @@ -931,7 +930,8 @@ struct ParallelInsertSliceOpInterface

/// Report whether `opOperand` is read by the tensor.parallel_insert_slice op
/// `op`.
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// NOTE(review): an unconditional `return true;` is immediately followed by a
// return that delegates to insertSliceOpRequiresRead. Presumably the former
// is the pre-change line and the latter its replacement (both are visible in
// this diff view) -- confirm against the merged file.
return true;
return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
opOperand);
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ namespace mlir {
namespace tensor {
namespace {

/// Returns true when each entry of `ofrs` is a constant integer equal to
/// `value`.
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
  for (OpFoldResult entry : ofrs) {
    if (!isConstantIntValue(entry, value))
      return false;
  }
  return true;
}

/// Returns the number of shape sizes that is either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
return llvm::count_if(
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
Expand Down Expand Up @@ -137,6 +138,22 @@ bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
return val && *val == value;
}

/// Return true if every element of `ofrs` is a constant integer equal to
/// `value`. An empty `ofrs` yields true.
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
  // Bail out on the first element that is not the expected constant.
  for (OpFoldResult candidate : ofrs)
    if (!isConstantIntValue(candidate, value))
      return false;
  return true;
}

/// Return true if `ofrs` and `values` have the same length and each element
/// of `ofrs` is a constant integer equal to the element of `values` at the
/// same position.
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
                          ArrayRef<int64_t> values) {
  const size_t numEntries = values.size();
  if (ofrs.size() != numEntries)
    return false;

  // Compare position by position; a non-constant entry or a differing
  // constant makes the whole comparison fail.
  for (size_t idx = 0; idx < numEntries; ++idx)
    if (!isConstantIntValue(ofrs[idx], values[idx]))
      return false;
  return true;
}

/// Return true if ofr1 and ofr2 are the same integer constant attribute values
/// or the same SSA value.
/// Ignore integer bitwidth and type mismatch that come from the fact there is
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso

// -----

// The tensor.parallel_insert_slice below covers the whole destination
// (offset 0, size 2 into a tensor<2xf32>, stride 1). The check after the
// label requires that no memref.alloc() shows up in the output for this
// function.
// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
// CHECK-NOT: memref.alloc()
func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
%cst = arith.constant 0.000000e+00 : f32
%3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
%fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
}
} {mapping = [#gpu.thread<linear_dim_0>]}
return %3 : tensor<2xf32>
}

// -----

// This test case could bufferize in-place with a better analysis. However, it
// is simpler to let the canonicalizer fold away the tensor.insert_slice.

Expand Down
Loading