5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);

/// Patterns to sink tensor.extract_slice ops across other operations.
void populateExtractSliceSinkingPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);

/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
275 changes: 275 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};

/// Information about a partially sliced dimension of a tensor.extract_slice:
/// the slice offset, the slice size, and the full (un-sliced) source size.
struct SliceDimInfo {
OpFoldResult offset;
OpFoldResult sliceSize;
OpFoldResult outputSize;
};

/// Return the first input operand of the generic op that is produced by a
/// tensor.extract_slice, or failure if there is none.
static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
OpOperand *sliceOperand = nullptr;
for (auto operand : genericOp.getDpsInputOperands()) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
sliceOperand = operand;
break;
}
if (!sliceOperand) {
return failure();
}
return sliceOperand;
}

/// Return a map from iteration-space dimension to SliceDimInfo for every
/// dimension that carries a partial slice, so that the other operands can be
/// adjusted accordingly. Fail if a partially sliced dimension is used by a
/// non-AffineDimExpr in any other operand.
static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
assert(producerSliceOp && "expect a valid ExtractSliceOp");
llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();

SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
genericOp.getContext(), producerSliceOp.getSourceType().getShape());

for (auto [idx, expr] : llvm::enumerate(
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
// If a dimension is fully sliced, it does not need to be added to the
// partial slice map.
if (isConstantIntValue(offsets[idx], 0) &&
isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
continue;
}
// Only partial slices along plain AffineDimExprs are supported, so bail out
// otherwise.
if (!isa<AffineDimExpr>(expr)) {
return failure();
}
SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
partialSliceDimMap[dimPos] = sliceDimInfo;
}
// Next, check whether any dim with partial slice info is used inside a
// non-AffineDimExpr in another operand; if so, bail out.
for (OpOperand &operand : genericOp->getOpOperands()) {
if (operand == *sliceOperand) {
continue;
}
AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
if (llvm::any_of(indexingMap.getResults(), [&](AffineExpr expr) {
if (isa<AffineDimExpr>(expr)) {
return false;
}
WalkResult status = expr.walk([&](AffineExpr expr) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
if (partialSliceDimMap.contains(dimExpr.getPosition())) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (status.wasInterrupted()) {
return true;
}
return false;
})) {
return failure();
}
}
return partialSliceDimMap;
}

static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ControlPropagationFn controlFn) {
if (genericOp.getNumResults() != 1)
return rewriter.notifyMatchFailure(
genericOp, "propagation through multi-result generic is unsupported.");
if (hasGatherSemantics(genericOp))
return rewriter.notifyMatchFailure(
genericOp,
"propagation through generic with gather semantics is unsupported.");
// Collect the sliced operand, if present.
auto maybeSliceOperand = getSliceOperand(genericOp);
if (failed(maybeSliceOperand))
return failure();
OpOperand *sliceOperand = *maybeSliceOperand;
unsigned operandIndex = sliceOperand->getOperandNumber();

if (!controlFn(sliceOperand))
return failure();

tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
assert(producerSliceOp && "expect a valid ExtractSliceOp");

if (producerSliceOp.getSource().getType().getRank() !=
producerSliceOp.getResult().getType().getRank()) {
return rewriter.notifyMatchFailure(
genericOp,
"propagation of rank-reducing extract slice is unsupported.");
}

SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
if (!areAllConstantIntValue(strides, 1))
return rewriter.notifyMatchFailure(
genericOp, "propagation of strided extract slice is unsupported.");

// Check whether propagation of this extract_slice through the generic op is
// supported and, if so, collect the dimensions that carry partial slices.

auto maybePartialSliceDimMap =
getPartialSliceDimInfo(genericOp, sliceOperand);

if (failed(maybePartialSliceDimMap)) {
return failure();
}

auto partialSliceDimMap = *maybePartialSliceDimMap;

SmallVector<utils::IteratorType> iterators =
genericOp.getIteratorTypesArray();
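// If a reduction dimension is only partially sliced, the slice cannot be
// folded away on the sliced operand; in that case the operand is padded
// below like every other input.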
bool hasPartialReductionDimSlice =
llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
int64_t sliceDim = slice.first;
return iterators[sliceDim] == utils::IteratorType::reduction;
});

// Helper computing `v1 - v2` as a folded affine.apply; used to derive the
// high padding amounts below.
Location loc = genericOp->getLoc();
AffineExpr dim0, dim1;
bindDims(rewriter.getContext(), dim0, dim1);
auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
{v1, v2});
};

MLIRContext *ctx = genericOp.getContext();
SmallVector<Value> paddedInputs;
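// Pad each input back up to the un-sliced sizes of the partially sliced
// dimensions, using a poison padding value.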
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
if (idx == operandIndex && !hasPartialReductionDimSlice) {
paddedInputs.push_back(producerSliceOp.getSource());
continue;
}
AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
for (auto [idx, expr] : llvm::enumerate(indexingMap.getResults())) {
if (!isa<AffineDimExpr>(expr)) {
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
continue;
}
SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
operandLowPads[idx] = sliceDimInfo.offset;
operandHighPads[idx] =
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
sliceDimInfo.sliceSize);
}
auto paddingValue = ub::PoisonOp::create(
rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
auto paddedOperand = tensor::PadOp::create(
rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
paddingValue, /*nofold=*/false);
paddedInputs.push_back(paddedOperand);
}
AffineMap outputIndexingMap =
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));

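// Compute the full output shape together with the offsets and sizes needed
// to re-extract the original result from the enlarged generic op.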
auto outputShapeType =
llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
outputShapeType.getShape(),
[&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
SmallVector<OpFoldResult> newSizes = outputShape;
SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 1));
for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
if (!isa<AffineDimExpr>(expr)) {
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
continue;
}
SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
outputLowPads[idx] = sliceDimInfo.offset;
outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
sliceDimInfo.sliceSize);
outputShape[idx] = sliceDimInfo.outputSize;
newSizes[idx] = sliceDimInfo.sliceSize;
}
Value newPadOutput;
auto outputElType =
getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
if (isGenericOutsNotUsed(genericOp)) {
newPadOutput =
tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
} else {
auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
newPadOutput = tensor::PadOp::create(
rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
outputHighPads, paddingValue, /*nofold=*/false);
}

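// Recreate the generic op on the padded operands and clone the original body
// into it.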
auto newGenericOp = linalg::GenericOp::create(
rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
newGenericOp.getRegion().begin());

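// Slice the original result shape back out of the enlarged result so the
// replacement value matches the type of the op being rewritten.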
auto extractOp = tensor::ExtractSliceOp::create(
rewriter, loc,
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
outputLowPads, newSizes, newStrides);
Value extractRes = extractOp.getResult();

return std::make_tuple(newGenericOp, extractRes);
}

class PushDownExtractSliceOpThroughGenericOp final
: public OpRewritePattern<GenericOp> {
public:
PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
auto genericAndRepl =
pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
return success();
}

private:
ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
patterns.getContext(), controlPackUnPackPropagation);
}
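
For context, a minimal sketch of how a downstream pass might exercise the new entry point; the helper function, the accept-all control callback, and the greedy-driver wiring are illustrative assumptions, not part of this patch:

// Illustrative sketch only, not part of this patch.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sink tensor.extract_slice ops through every candidate linalg.generic
// nested under `root`, accepting every propagation opportunity.
static mlir::LogicalResult sinkExtractSlices(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateExtractSliceSinkingPatterns(
      patterns, [](mlir::OpOperand *) { return true; });
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}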