[mlir][linalg] Use ub.poison in data layout propagation if a packed operand requires padding. #159467
Changes from 2 commits
@@ -14,6 +14,7 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -220,9 +221,10 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
/// inner_dims_pos = [0]
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
static FailureOr<std::tuple<Value, AffineMap>>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
                               GenericOp genericOp, OpOperand *opOperand,
                               bool poisonPaddingOk) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -286,12 +288,26 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
  // The operand does not have dimensions that relates to pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
  auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
  if (!maybeIntInnerTileSizes.has_value()) {
    return failure();
  }
  if (!poisonPaddingOk &&
      linalg::PackOp::requirePaddingValueStrict(
          inputType.getShape(), innerDimsPos,
          linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
                                          innerDimsPos, outerDimsPerm)
              .getShape(),
          outerDimsPerm, innerTileSizes))
    return failure();
  auto empty = linalg::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = linalg::PackOp::create(
      b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  auto poison = ub::PoisonOp::create(
      b, loc, getElementTypeOrSelf(opOperand->get().getType()));
Comment on lines +342 to +343:

Reviewer: Rather than explicitly creating the pad value by the user, why not take an approach similar to #146088? (i.e., make the padding value "optional" and let the builder worry about it.)

Response: I think it is okay to have a separate builder that does what you shared, and it makes sense. However, we don't do it in this builder, because users can decide not to use padding values on dynamic shapes at their own risk. I think it is okay to create a new builder with that behavior, but we should leave the current builder as it is. See the explanation below for details. (One of the differences is that padding is optional for pack ops, but a padding value is required by vector.transfer_read.)

It is hard for the builder to check whether a padding value is required for dynamic shapes. To me, there is a soft check and a hard check for the padding-value requirement. E.g., the existing method is a soft check that returns false in dynamic cases. I'm proposing a hard check (in IREE's issue) and using it with the control that @nirvedhmeshram is working on. (The hard-check version returns true in dynamic cases.)

```cpp
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile) {
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}
```

Follow-up: For another example -- which I don't see in practice so far, so this is just my hypothesis -- you may drop the padding value in this case:
  Value packedOperand =
      linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                             innerTileSizes, poison, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}
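For illustration, a minimal MLIR sketch (assumed shapes, not taken from this PR's tests) of the IR the helper now builds when the packed operand needs padding: the pad value is a `ub.poison` of the element type rather than a user-supplied constant. Per the review discussion above, this is also what lets propagation proceed when divisibility by the inner tile cannot be proven statically, provided the caller opted in via `poisonPaddingOk`.

```mlir
// Assumed example: a 13-element dim packed with inner tile 8 needs 3 padded
// elements; poison is used because the padded elements carry no meaningful data.
func.func @pack_with_poison_pad(%operand: tensor<13xf32>) -> tensor<2x8xf32> {
  %poison = ub.poison : f32
  %empty = tensor.empty() : tensor<2x8xf32>
  %packed = linalg.pack %operand padding_value(%poison : f32)
      inner_dims_pos = [0] inner_tiles = [8]
      into %empty : tensor<13xf32> -> tensor<2x8xf32>
  return %packed : tensor<2x8xf32>
}
```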
@@ -301,10 +317,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
/// around it. Implicitly this will only work when a packInfo can be obtained.
/// This make sure that we are only using this function on parallel permuted
/// dimensions.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo,
                               bool isFoldableUnpackPack) {
static FailureOr<GenericOp>
packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
              AffineMap packedOutIndexingMap, const PackInfo &packInfo,
              bool isFoldableUnpackPack, bool poisonPaddingOk) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<Value> inputOperandsFromUnpackedSource;
@@ -315,8 +331,13 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
           llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
  };
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
    if (failed(mayBepackedOperandAndIndexing)) {
      return failure();
    }
    auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
    auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
    auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
    auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
    if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -407,7 +428,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
/// } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
                               const ControlPropagationFn &controlFn,
                               bool poisonPaddingOk) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();
@@ -470,9 +492,14 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
  }

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
  auto mayBepackedOperandAndIndexing =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);
                                     genericOp, opOperand, poisonPaddingOk);
  if (failed(mayBepackedOperandAndIndexing)) {
    return failure();
  }
  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
  // Forward the new tensor.empty as a destination if it is one of the following
  // situations:
@@ -488,21 +515,24 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
  // pack(unpack) isn't naively foldable because the unpack op can be from
  // an arbitrary domain so we need to keep both.
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo, /*isFoldableUnpackPack=*/false);
                       *packInfo, /*isFoldableUnpackPack=*/false,
                       poisonPaddingOk);
}
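To make the overall effect concrete, here is a sketch of the rewritten IR after bubble-up (an assumed element-wise example, not one of the PR's test cases): the generic's input gets repacked with a poison padding value because 13 is not a multiple of the tile size 8, and the generic then runs directly on the packed layout.

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
// After bubble-up (assumed IR): the pack that followed the generic has been
// replaced by a generic that operates on operands packed with ub.poison padding.
func.func @after_bubble_up(%a: tensor<13xf32>, %dest: tensor<2x8xf32>) -> tensor<2x8xf32> {
  %poison = ub.poison : f32
  %empty = tensor.empty() : tensor<2x8xf32>
  %packed = linalg.pack %a padding_value(%poison : f32)
      inner_dims_pos = [0] inner_tiles = [8]
      into %empty : tensor<13xf32> -> tensor<2x8xf32>
  %res = linalg.generic {indexing_maps = [#map, #map],
                         iterator_types = ["parallel", "parallel"]}
      ins(%packed : tensor<2x8xf32>) outs(%dest : tensor<2x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %e = arith.addf %in, %in : f32
    linalg.yield %e : f32
  } -> tensor<2x8xf32>
  return %res : tensor<2x8xf32>
}
```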
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<linalg::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
                                        ControlPropagationFn fun,
                                        bool poisonPaddingOk)
      : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
        poisonPaddingOk(std::move(poisonPaddingOk)) {}
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
                                                    poisonPaddingOk);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
@@ -511,6 +541,7 @@ struct BubbleUpPackOpThroughGenericOpPattern

private:
  ControlPropagationFn controlFn;
  bool poisonPaddingOk;
};

/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
@@ -1080,7 +1111,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
                                 ControlPropagationFn controlFn,
                                 bool poisonPaddingOk) {
  if (genericOp.getNumResults() != 1)
    return failure();
@@ -1107,9 +1139,14 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
      rewriter, genericOp.getLoc(), *packInfo, genericOp,
      genericOp.getDpsInitOperand(0), poisonPaddingOk);
  if (failed(mayBepackedOperandAndIndexing)) {
    return failure();
  }
  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
  auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

  // Forward the new tensor.empty as a destination if it is one of the following
@@ -1129,9 +1166,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
  // pack(unpack) is foldable in this case. This is because in pushing down the
  // unpack, by default we will populate an additional pack op after the unpack.
  // This guarantees them to be foldable.
  GenericOp newGenericOp =
  auto maybeGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
                    /*isFoldableUnpackPack=*/true);
                    /*isFoldableUnpackPack=*/true, poisonPaddingOk);
  if (failed(maybeGenericOp))
    return failure();
  GenericOp newGenericOp = *maybeGenericOp;
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
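Mirroring the bubble-up direction, a sketch (again an assumed example, not from the PR) of the push-down result: the generic keeps computing on the packed layout and a new unpack is re-created below it, so downstream users still see the original type. As the comment in the diff notes, the pack/unpack pair that propagation inserts on the init operand is foldable by construction.

```mlir
#map = affine_map<(d0, d1) -> (d0, d1)>
// After push-down (assumed IR): the generic consumes the packed tensor and a
// new linalg.unpack below it restores the original tensor<13xf32> result.
func.func @after_push_down(%p: tensor<2x8xf32>, %out: tensor<13xf32>) -> tensor<13xf32> {
  %dest = tensor.empty() : tensor<2x8xf32>
  %res = linalg.generic {indexing_maps = [#map, #map],
                         iterator_types = ["parallel", "parallel"]}
      ins(%p : tensor<2x8xf32>) outs(%dest : tensor<2x8xf32>) {
  ^bb0(%in: f32, %o: f32):
    %e = arith.mulf %in, %in : f32
    linalg.yield %e : f32
  } -> tensor<2x8xf32>
  %unpacked = linalg.unpack %res inner_dims_pos = [0] inner_tiles = [8]
      into %out : tensor<2x8xf32> -> tensor<13xf32>
  return %unpacked : tensor<13xf32>
}
```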
@@ -1157,13 +1197,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
                                   ControlPropagationFn fun,
                                   bool poisonPaddingOk)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
        poisonPaddingOk(std::move(poisonPaddingOk)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
        rewriter, genericOp, controlFn, poisonPaddingOk);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1172,6 +1214,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {

private:
  ControlPropagationFn controlFn;
  bool poisonPaddingOk;
};

/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
@@ -1522,12 +1565,14 @@ class PushDownExtractSliceOpThroughGenericOp final

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
    const ControlPropagationFn &controlPackUnPackPropagation,
    bool PoisonPaddingOk) {
  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
      patterns.getContext(), controlPackUnPackPropagation);
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp>(
      patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(