8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

// Same as the above function, but dynamic dimensions are assumed
// to require padding.
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;

/// Patterns to bubble up or down data layout ops across other operations.
/// The function also takes an option that lets the patterns propagate with
/// poison padding when requested by the caller.
void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
const ControlPropagationFn &controlPackUnPackPropagation,
bool PoisonPaddingOk = false);
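A minimal usage sketch for the new flag, assuming a standard greedy-rewrite setup; the surrounding pass, `ctx`, and `funcOp` are illustrative and not part of this patch:

  // Hedged sketch: opt in to poison padding during data layout propagation.
  // Assumes "mlir/Dialect/Linalg/Transforms/Transforms.h" and
  // "mlir/Transforms/GreedyPatternRewriteDriver.h" are included and that this
  // fragment runs inside a pass's runOnOperation().
  RewritePatternSet patterns(ctx);
  linalg::populateDataLayoutPropagationPatterns(
      patterns,
      /*controlPackUnPackPropagation=*/[](OpOperand *) { return true; },
      /*PoisonPaddingOk=*/true);
  // applyPatternsGreedily is named applyPatternsAndFoldGreedily in older
  // revisions; either driver works for this sketch.
  if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
    signalPassFailure();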

/// Patterns to sink extract slice across other operations.
void populateExtractSliceSinkingPatterns(
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5310,6 +5310,32 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
return false;
}

bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles) {
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
assert(outerDimsPerm.size() == outputTileSizes.size() &&
"expected output and outer_dims_perm to have same size");
applyPermutationToVector(outputTileSizes,
invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
if (ShapedType::isDynamic(inputShape[pos]) ||
ShapedType::isDynamic(outputTileSizes[pos]))
return true;
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
if (!constantTile)
return true;
if (inputShape[pos] % (*constantTile) != 0)
return true;
}
return false;
}
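To illustrate the intended contrast with the existing soft check, a hedged standalone example; the shapes and tile size are made up, and `b` is assumed to be any OpBuilder in scope:

  // Sketch: outer dimension 0 is dynamic and tiled by 8.
  // requirePaddingValue optimistically assumes a dynamic dim divides evenly;
  // requirePaddingValueStrict assumes it may not and thus needs padding.
  SmallVector<int64_t> inputShape = {ShapedType::kDynamic, 32};
  SmallVector<int64_t> innerDimsPos = {0};
  SmallVector<int64_t> outputShape = {ShapedType::kDynamic, 32, 8};
  SmallVector<int64_t> outerDimsPerm;
  SmallVector<OpFoldResult> innerTiles = {OpFoldResult(b.getIndexAttr(8))};
  assert(!linalg::PackOp::requirePaddingValue(inputShape, innerDimsPos,
                                              outputShape, outerDimsPerm,
                                              innerTiles));
  assert(linalg::PackOp::requirePaddingValueStrict(inputShape, innerDimsPos,
                                                   outputShape, outerDimsPerm,
                                                   innerTiles));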

LogicalResult PackOp::verify() {
if (failed(commonVerifierPackAndUnPackOp(*this)))
return failure();
175 changes: 133 additions & 42 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -220,9 +221,21 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
/// inner_dims_pos = [0]
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
GenericOp genericOp, OpOperand *opOperand) {

struct PackedOperandDetails {
SmallVector<OpFoldResult> innerTileSizes;
SmallVector<int64_t> innerDimsPos;
SmallVector<int64_t> outerDimsPerm;
AffineMap indexingMap;
};

/// Helper function for getOrCreatePackedViewOfOperand that populates
/// the details of the packed operand that needs to be formed and also
/// returns whether the packing would require padding.
static bool getPackedOperandDetails(
OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
PackedOperandDetails currOperandDetails;
int64_t numOrigLoops = genericOp.getNumLoops();
int64_t numInnerLoops = packInfo.getNumTiledLoops();
int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -231,9 +244,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

// If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
if (genericOp.isScalar(opOperand) || exprs.empty())
return std::make_tuple(opOperand->get(),
AffineMap::get(numLoops, 0, exprs, b.getContext()));
if (genericOp.isScalar(opOperand) || exprs.empty()) {
currOperandDetails.indexingMap =
AffineMap::get(numLoops, 0, exprs, b.getContext());
packedOperandMap[opOperand] = currOperandDetails;
return false;
}

// Step 1. Construct the information of packing data dimensions; append inner
// dimensions to the indexing maps for the operand.
@@ -281,18 +297,57 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
exprs = auxVec;
}
}
auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
currOperandDetails.indexingMap =
AffineMap::get(numLoops, 0, exprs, b.getContext());

// The operand does not have dimensions that relates to pack op.
if (innerDimsPos.empty() && outerDimsPerm.empty())
return std::make_tuple(opOperand->get(), indexingMap);
if (innerDimsPos.empty() && outerDimsPerm.empty()) {
packedOperandMap[opOperand] = currOperandDetails;
return false;
}
auto inputType = cast<RankedTensorType>(opOperand->get().getType());

auto maybeIntInnerTileSizes =
llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
return maybeCst.value_or(ShapedType::kDynamic);
});
bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
inputType.getShape(), innerDimsPos,
linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
innerDimsPos, outerDimsPerm)
.getShape(),
outerDimsPerm, innerTileSizes);
currOperandDetails.innerDimsPos = innerDimsPos;
currOperandDetails.innerTileSizes = innerTileSizes;
currOperandDetails.outerDimsPerm = outerDimsPerm;
packedOperandMap[opOperand] = currOperandDetails;

if (requirePadding)
return true;
return false;
}

static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
OpBuilder &b, Location loc, OpOperand *opOperand,
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
assert(packedOperandMap.contains(opOperand) &&
"packed operand details expected to be populated");
auto currOperandDetails = packedOperandMap[opOperand];
auto innerDimsPos = currOperandDetails.innerDimsPos;
auto outerDimsPerm = currOperandDetails.outerDimsPerm;
auto innerTileSizes = currOperandDetails.innerTileSizes;
if (innerDimsPos.empty() && outerDimsPerm.empty()) {
return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
}
auto empty = linalg::PackOp::createDestinationTensor(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto packedOperand = linalg::PackOp::create(
b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
/*padding=*/std::nullopt, outerDimsPerm);
return std::make_tuple(packedOperand, indexingMap);
auto poison = ub::PoisonOp::create(
b, loc, getElementTypeOrSelf(opOperand->get().getType()));
Comment on lines +342 to +343
Contributor
Rather than having the user explicitly create the pad value, why not take an approach similar to #146088? (i.e., make the padding value "optional" and let the builder worry about it)

Contributor Author
I think it is okay to have a separate builder that does what you shared, and it makes sense. However, we don't do it in this builder, because users can decide not to use padding values for dynamic shapes at their own risk. I think it is okay to create a new builder with such behavior, but we should leave the current builder as it is. See the explanation below for some details. (One of the differences is that padding is optional for pack ops, but padding is required by the vector.transfer_read op.)

It is hard to make the builder check whether a padding value is required for dynamic shapes. To me, there are soft and hard checks for the padding-value requirement. E.g., the existing method is a soft check that returns false in dynamic cases. I'm proposing a hard check (in IREE's issue) and using it with the control that @nirvedhmeshram is working on.

(The hard check version returns true in dynamic cases.)

bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}

Contributor Author
For another example -- one that I haven't seen in practice so far, so this is just my hypothesis -- you may drop the padding value in this case:

  • The inner tile size is 4.
  • You have int range analysis that tells you the packing dimension size is a multiple of 4 (which is a dynamic shape).

Value packedOperand =
linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
innerTileSizes, poison, outerDimsPerm);
return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
}

/// This function is a helper subroutine to pack a genericOp and return it. It
@@ -301,10 +356,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
/// around it. Implicitly this will only work when a packInfo can be obtained.
/// This make sure that we are only using this function on parallel permuted
/// dimensions.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
Value dest, AffineMap packedOutIndexingMap,
const PackInfo &packInfo,
bool isFoldableUnpackPack) {
static FailureOr<GenericOp>
packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
AffineMap packedOutIndexingMap, const PackInfo &packInfo,
bool isFoldableUnpackPack, bool poisonPaddingOk) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<Value> inputOperandsFromUnpackedSource;
@@ -314,9 +369,18 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
};
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
bool requiresPadding = false;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
inputOperand, packedOperandMap);
}
if (requiresPadding && !poisonPaddingOk) {
return failure();
}
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packInfo, genericOp, inputOperand);
rewriter, loc, inputOperand, packedOperandMap);
auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -407,7 +471,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
/// } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
const ControlPropagationFn &controlFn) {
const ControlPropagationFn &controlFn,
bool poisonPaddingOk) {
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
if (!genericOp)
return failure();
@@ -470,10 +535,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
}

// Rebuild the indexing map for the corresponding init operand.
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
opOperand, packedOperandMap);
if (requiresPadding && !poisonPaddingOk) {
return failure();
}
auto [packedOutOperand, packedOutIndexingMap] =
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);

getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
packedOperandMap);
// Forward the new tensor.empty as a destination if it is one of the following
// situations:
// 1) The dps init operand is a tensor.empty.
@@ -488,21 +558,24 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
// pack(unpack) isn't naively foldable because the unpack op can be from
// an arbitrary domain so we need to keep both.
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
*packInfo, /*isFoldableUnpackPack=*/false);
*packInfo, /*isFoldableUnpackPack=*/false,
poisonPaddingOk);
}

/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
: public OpRewritePattern<linalg::PackOp> {
public:
BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
ControlPropagationFn fun,
bool poisonPaddingOk)
: OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
poisonPaddingOk(std::move(poisonPaddingOk)) {}

LogicalResult matchAndRewrite(linalg::PackOp packOp,
PatternRewriter &rewriter) const override {
auto genericOp =
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
poisonPaddingOk);
if (failed(genericOp))
return failure();
rewriter.replaceOp(packOp, genericOp->getResults());
@@ -511,6 +584,7 @@ struct BubbleUpPackOpThroughGenericOpPattern

private:
ControlPropagationFn controlFn;
bool poisonPaddingOk;
};

/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
@@ -1080,7 +1154,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
ControlPropagationFn controlFn) {
ControlPropagationFn controlFn,
bool poisonPaddingOk) {
if (genericOp.getNumResults() != 1)
return failure();

@@ -1107,9 +1182,17 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return failure();

// Rebuild the indexing map for the corresponding init operand.
DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
bool requiresPadding =
getPackedOperandDetails(rewriter, *packInfo, genericOp,
genericOp.getDpsInitOperand(0), packedOperandMap);
if (requiresPadding && !poisonPaddingOk) {
return failure();
}
auto [packedOutOperand, packedOutIndexingMap] =
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, genericOp.getDpsInitOperand(0));
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
genericOp.getDpsInitOperand(0),
packedOperandMap);
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

// Forward the new tensor.empty as a destination if it is one of the following
@@ -1129,9 +1212,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
// pack(unpack) is foldable in this case. This is because in pushing down the
// unpack, by default we will populate an additional pack op after the unpack.
// This guarantees them to be foldable.
GenericOp newGenericOp =
auto maybeGenericOp =
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
/*isFoldableUnpackPack=*/true);
/*isFoldableUnpackPack=*/true, poisonPaddingOk);
if (failed(maybeGenericOp))
return failure();
GenericOp newGenericOp = *maybeGenericOp;
Value newResult =
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

@@ -1157,13 +1243,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
PushDownUnPackOpThroughGenericOp(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
ControlPropagationFn fun,
bool poisonPaddingOk)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
poisonPaddingOk(std::move(poisonPaddingOk)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
auto genericAndRepl =
pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
rewriter, genericOp, controlFn, poisonPaddingOk);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1172,6 +1260,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {

private:
ControlPropagationFn controlFn;
bool poisonPaddingOk;
};

/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
@@ -1522,12 +1611,14 @@ class PushDownExtractSliceOpThroughGenericOp final

void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns
.insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
const ControlPropagationFn &controlPackUnPackPropagation,
bool PoisonPaddingOk) {
patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
PushDownUnPackOpThroughGenericOp>(
patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(