5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);

/// Patterns to sink extract slice across other operations.
void populateExtractSliceSinkingPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);

/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
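A minimal usage sketch for the new entry point (illustrative only; it assumes a hypothetical pass class and the greedy pattern driver, and is not part of this change):

  void SinkExtractSlicePass::runOnOperation() {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    // Accept every candidate sliced operand.
    linalg::populateExtractSliceSinkingPatterns(
        patterns, [](OpOperand *opOperand) { return true; });
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }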
275 changes: 275 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};

// This struct contains information about extract_slice dims.
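// For a sliced dim: `offset` and `sliceSize` are the slice's offset and size
// on that dim, and `outputSize` is the full size of the slice source.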
struct SliceDimInfo {
OpFoldResult offset;
OpFoldResult sliceSize;
OpFoldResult outputSize;
};

/// Return the first input extract slice operand, if present, for the current
/// generic op.
static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
OpOperand *sliceOperand = nullptr;
for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
sliceOperand = operand;
break;
}
if (!sliceOperand) {
return failure();
}
return sliceOperand;
}

// Return a map from iteration-space dim position to SliceDimInfo for the dims
// that carry a partial (non-full) slice, so that other operands can use this
// information. Bail out if a partially sliced dim is used inside a
// non-AffineDimExpr result of another operand's indexing map.
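//
// For example (hypothetical IR, not from this change): with indexing map
// (d0, d1) -> (d0, d1) and an input produced by
//   tensor.extract_slice %src[0, 4] [16, 8] [1, 1]
//       : tensor<16x32xf32> to tensor<16x8xf32>
// d0 is a full slice and is skipped, while d1 maps to
// SliceDimInfo{offset = 4, sliceSize = 8, outputSize = 32}.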
static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
assert(producerSliceOp && "expect a valid ExtractSliceOp");
llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();

SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
genericOp.getContext(), producerSliceOp.getSourceType().getShape());

for (auto [idx, expr] : llvm::enumerate(
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
// If we have a full slice in a dimension then we don't need to add it to
// the partial slice map.
if (isConstantIntValue(offsets[idx], 0) &&
isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
continue;
}
// We only support partial slices on AffineDimExprs, so bail out if that's not
// the case.
if (!isa<AffineDimExpr>(expr)) {
return failure();
}
SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
partialSliceDimMap[dimPos] = sliceDimInfo;
}
// Next, check whether the partially sliced dims appear inside results of other
// operands' indexing maps that are not AffineDimExprs; if so, bail out.
for (OpOperand &operand : genericOp->getOpOperands()) {
if (operand == *sliceOperand) {
continue;
}
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
if (isa<AffineDimExpr>(expr)) {
return false;
}
WalkResult status = expr.walk([&](AffineExpr expr) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
if (partialSliceDimMap.contains(dimExpr.getPosition())) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (status.wasInterrupted()) {
return true;
}
return false;
})) {
return failure();
}
}
return partialSliceDimMap;
}

static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ControlPropagationFn controlFn) {
if (genericOp.getNumResults() != 1)
return rewriter.notifyMatchFailure(
genericOp, "propagation through multi-result generic is unsupported.");
if (hasGatherSemantics(genericOp))
return rewriter.notifyMatchFailure(
genericOp,
"propagation through generic with gather semantics is unsupported.");
// Collect the sliced operand, if present.
auto maybeSliceOperand = getSliceOperand(genericOp);
if (failed(maybeSliceOperand))
return failure();
OpOperand *sliceOperand = *maybeSliceOperand;
unsigned OperandIndex = sliceOperand->getOperandNumber();

if (!controlFn(sliceOperand))
return failure();

tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
assert(producerSliceOp && "expect a valid ExtractSliceOp");

if (producerSliceOp.getSource().getType().getRank() !=
producerSliceOp.getResult().getType().getRank()) {
return rewriter.notifyMatchFailure(
genericOp,
"propagation of rank-reducing extract slice is unsupported.");
}

SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
if (!areAllConstantIntValue(strides, 1))
return rewriter.notifyMatchFailure(
genericOp, "propagation of strided extract slice is unsupported.");

// Check if we can support the propagation of this extract_slice through the
// generic op and, if so, collect the dims that carry partial slices.

auto maybePartialSliceDimMap =
getPartialSliceDimInfo(genericOp, sliceOperand);

if (failed(maybePartialSliceDimMap)) {
return failure();
}

auto partialSliceDimMap = *maybePartialSliceDimMap;

SmallVector<utils::IteratorType> iterators =
genericOp.getIteratorTypesArray();
bool hasPartialReductionDimSlice =
llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
int64_t sliceDim = slice.first;
return iterators[sliceDim] == utils::IteratorType::reduction;
});

// Helper to compute OpFoldResult subtraction with an affine map; used below to
// derive the high padding as (outputSize - offset) - sliceSize.
Location loc = genericOp->getLoc();
AffineExpr dim0, dim1;
bindDims(rewriter.getContext(), dim0, dim1);
auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
{v1, v2});
};
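// For example (values from the hypothetical slice above): with outputSize = 32,
// offset = 4, and sliceSize = 8, the high padding is (32 - 4) - 8 = 20.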

MLIRContext *ctx = genericOp.getContext();
SmallVector<Value> paddedInputs;
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
if (idx == OperandIndex && !hasPartialReductionDimSlice) {
paddedInputs.push_back(producerSliceOp.getSource());
continue;
}
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
if (!isa<AffineDimExpr>(expr)) {
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
continue;
}
SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
operandLowPads[idx] = sliceDimInfo.offset;
operandHighPads[idx] =
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
sliceDimInfo.sliceSize);
}
auto paddingValue = ub::PoisonOp::create(
rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
auto paddedOperand = tensor::PadOp::create(
rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
paddingValue, /*nofold=*/false);
paddedInputs.push_back(paddedOperand);
}
AffineMap outputIndexingMap =
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));

auto outputShapeType =
llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
outputShapeType.getShape(),
[&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
SmallVector<OpFoldResult> newSizes = OutputShape;
SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 1));
for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
if (!isa<AffineDimExpr>(expr)) {
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
continue;
}
SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
outputLowPads[idx] = sliceDimInfo.offset;
outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
sliceDimInfo.sliceSize);
OutputShape[idx] = sliceDimInfo.outputSize;
newSizes[idx] = sliceDimInfo.sliceSize;
}
Value newPadOutput;
auto outputElType =
getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
if (isGenericOutsNotUsed(genericOp)) {
newPadOutput =
tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
} else {
auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
newPadOutput = tensor::PadOp::create(
rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
outputHighPads, paddingValue, /*nofold=*/false);
}

auto newGenericOp = linalg::GenericOp::create(
rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
newGenericOp.getRegion().begin());

auto extractOp = tensor::ExtractSliceOp::create(
rewriter, loc,
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
outputLowPads, newSizes, newStrides);
Value extractRes = extractOp.getResult();

return std::make_tuple(newGenericOp, extractRes);
}
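// Illustrative sketch (hypothetical IR, not from this change) of the rewrite
// above, when the partially sliced dim is parallel and the init is unused:
//
//   %s = tensor.extract_slice %src[0, 4] [16, 8] [1, 1]
//       : tensor<16x32xf32> to tensor<16x8xf32>
//   %r = linalg.generic ... ins(%s : tensor<16x8xf32>)
//                           outs(%init : tensor<16x8xf32>) -> tensor<16x8xf32>
//
// becomes
//
//   %empty = tensor.empty() : tensor<16x32xf32>
//   %g = linalg.generic ... ins(%src : tensor<16x32xf32>)
//                           outs(%empty : tensor<16x32xf32>) -> tensor<16x32xf32>
//   %r = tensor.extract_slice %g[0, 4] [16, 8] [1, 1]
//       : tensor<16x32xf32> to tensor<16x8xf32>
//
// Any other input indexed by the partially sliced dim is padded back to the
// full size with ub.poison via tensor.pad before feeding the new generic.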

class PushDownExtractSliceOpThroughGenericOp final
: public OpRewritePattern<GenericOp> {
public:
PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
auto genericAndRepl =
pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
return success();
}

private:
ControlPropagationFn controlFn;
};

} // namespace

@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
patterns.getContext(), controlPackUnPackPropagation);
}