Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

// Same as the above function, except that dynamic dimensions
// are conservatively assumed to require padding.
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;

/// Patterns to bubble up or down data layout ops across other operations.
/// Setting `PoisonPaddingOk` allows the patterns to propagate even when
/// doing so requires padding with poison values.
void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
const ControlPropagationFn &controlPackUnPackPropagation,
bool PoisonPaddingOk = false);

/// Patterns to sink extract slice across other operations.
void populateExtractSliceSinkingPatterns(
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5310,6 +5310,35 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
return false;
}

/// Returns true if packing `inputShape` with the given inner tiles may
/// require a padding value. This is the conservative ("strict") variant of
/// requirePaddingValue above: a dynamic input dimension is assumed to
/// require padding, whereas the non-strict variant skips such dims.
bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
                                       ArrayRef<int64_t> innerDimsPos,
                                       ArrayRef<int64_t> outputShape,
                                       ArrayRef<int64_t> outerDimsPerm,
                                       ArrayRef<OpFoldResult> innerTiles) {
  // The leading inputShape.size() dims of the packed type are the outer
  // (tiled-loop) dims; the trailing dims are the inner tile dims.
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    // Undo the outer-dim permutation so outputTileSizes is indexed by the
    // original (input) dimension positions, matching innerDimsPos below.
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    // Strict behavior: a dynamic input dim is conservatively assumed to
    // need padding (the non-strict variant `continue`s here instead).
    if (ShapedType::isDynamic(inputShape[pos]))
      return true;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      // Tile size is not a compile-time constant: fall back to the static
      // outer dim of the result, if known, to test divisibility.
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      // Static input dim not evenly divisible by the constant tile size.
      return true;
    }
    // NOTE(review): a dynamic tile size combined with a dynamic output
    // outer dim is not flagged — confirm whether "strict" should also
    // treat that case as requiring padding.
  }
  return false;
}

LogicalResult PackOp::verify() {
if (failed(commonVerifierPackAndUnPackOp(*this)))
return failure();
Expand Down
110 changes: 76 additions & 34 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,10 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
/// inner_dims_pos = [0]
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
static FailureOr<std::tuple<Value, AffineMap>>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
GenericOp genericOp, OpOperand *opOperand) {
GenericOp genericOp, OpOperand *opOperand,
bool poisonPaddingOk) {
int64_t numOrigLoops = genericOp.getNumLoops();
int64_t numInnerLoops = packInfo.getNumTiledLoops();
int64_t numLoops = numOrigLoops + numInnerLoops;
Expand Down Expand Up @@ -287,12 +288,24 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
// The operand does not have dimensions that relates to pack op.
if (innerDimsPos.empty() && outerDimsPerm.empty())
return std::make_tuple(opOperand->get(), indexingMap);

auto inputType = cast<RankedTensorType>(opOperand->get().getType());
auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
if (!maybeIntInnerTileSizes.has_value()) {
return failure();
}
if (!poisonPaddingOk &&
linalg::PackOp::requirePaddingValueStrict(
inputType.getShape(), innerDimsPos,
linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
innerDimsPos, outerDimsPerm)
.getShape(),
outerDimsPerm, innerTileSizes))
return failure();
auto empty = linalg::PackOp::createDestinationTensor(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto poison = ub::PoisonOp::create(
b, loc, getElementTypeOrSelf(opOperand->get().getType()));
auto packedOperand =
Value packedOperand =
linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
innerTileSizes, poison, outerDimsPerm);
return std::make_tuple(packedOperand, indexingMap);
Expand All @@ -304,10 +317,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
/// around it. Implicitly this will only work when a packInfo can be obtained.
/// This make sure that we are only using this function on parallel permuted
/// dimensions.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
const PackInfo &packInfo,
bool isFoldableUnpackPack) {
static FailureOr<GenericOp>
packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
AffineMap packedOutIndexingMap, const PackInfo &packInfo,
bool isFoldableUnpackPack, bool poisonPaddingOk) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<Value> inputOperandsFromUnpackedSource;
Expand All @@ -318,8 +331,13 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
};
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
if (failed(mayBepackedOperandAndIndexing)) {
return failure();
}
auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
Expand Down Expand Up @@ -410,7 +428,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
/// } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
const ControlPropagationFn &controlFn) {
const ControlPropagationFn &controlFn,
bool poisonPaddingOk) {
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
if (!genericOp)
return failure();
Expand Down Expand Up @@ -473,9 +492,14 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
}

// Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] =
auto mayBepackedOperandAndIndexing =
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);
genericOp, opOperand, poisonPaddingOk);
if (failed(mayBepackedOperandAndIndexing)) {
return failure();
}
auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);

// Forward the new tensor.empty as a destination if it is one of the following
// situations:
Expand All @@ -491,21 +515,24 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
// pack(unpack) isn't naively foldable because the unpack op can be from
// an arbitrary domain so we need to keep both.
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
*packInfo, /*isFoldableUnpackPack=*/false);
*packInfo, /*isFoldableUnpackPack=*/false,
poisonPaddingOk);
}

/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
: public OpRewritePattern<linalg::PackOp> {
public:
BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
ControlPropagationFn fun,
bool poisonPaddingOk)
: OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
poisonPaddingOk(std::move(poisonPaddingOk)) {}

LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
auto genericOp =
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
poisonPaddingOk);
if (failed(genericOp))
return failure();
rewriter.replaceOp(packOp, genericOp->getResults());
Expand All @@ -514,6 +541,7 @@ struct BubbleUpPackOpThroughGenericOpPattern

private:
ControlPropagationFn controlFn;
bool poisonPaddingOk;
};

/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
Expand Down Expand Up @@ -1083,7 +1111,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
ControlPropagationFn controlFn) {
ControlPropagationFn controlFn,
bool poisonPaddingOk) {
if (genericOp.getNumResults() != 1)
return failure();

Expand All @@ -1110,9 +1139,14 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return failure();

// Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] =
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, genericOp.getDpsInitOperand(0));
auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
rewriter, genericOp.getLoc(), *packInfo, genericOp,
genericOp.getDpsInitOperand(0), poisonPaddingOk);
if (failed(mayBepackedOperandAndIndexing)) {
return failure();
}
auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

// Forward the new tensor.empty as a destination if it is one of the following
Expand All @@ -1132,9 +1166,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
// pack(unpack) is foldable in this case. This is because in pushing down the
// unpack, by default we will populate an additional pack op after the unpack.
// This guarantees them to be foldable.
GenericOp newGenericOp =
auto maybeGenericOp =
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
/*isFoldableUnpackPack=*/true);
/*isFoldableUnpackPack=*/true, poisonPaddingOk);
if (failed(maybeGenericOp))
return failure();
GenericOp newGenericOp = *maybeGenericOp;
Value newResult =
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

Expand All @@ -1160,13 +1197,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
PushDownUnPackOpThroughGenericOp(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
ControlPropagationFn fun,
bool poisonPaddingOk)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
poisonPaddingOk(std::move(poisonPaddingOk)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
auto genericAndRepl =
pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
rewriter, genericOp, controlFn, poisonPaddingOk);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
Expand All @@ -1175,6 +1214,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {

private:
ControlPropagationFn controlFn;
bool poisonPaddingOk;
};

/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
Expand Down Expand Up @@ -1525,12 +1565,14 @@ class PushDownExtractSliceOpThroughGenericOp final

void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
const ControlPropagationFn &controlPackUnPackPropagation,
bool PoisonPaddingOk) {
patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
PushDownUnPackOpThroughGenericOp>(
patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
patterns, [](OpOperand *opOperand) { return true; },
/*poisonPaddingOk=*/true);
linalg::ControlPropagationFn controlExtract =
[](OpOperand *opOperand) -> bool {
Operation *producer = opOperand->get().getDefiningOp();
Expand Down