Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

// Same as above function but here dynamic dimensions are assumed
// to require padding.
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;

/// Patterns to bubble up or down data layout ops across other operations.
/// The function also has an option to allow the patterns to propagate with
/// poison padding if requested by the caller.
void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
const ControlPropagationFn &controlPackUnPackPropagation,
bool PoisonPaddingOk = false);

/// Patterns to sink extract slice across other operations.
void populateExtractSliceSinkingPatterns(
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5310,6 +5310,32 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
return false;
}

bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles) {
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
assert(outerDimsPerm.size() == outputTileSizes.size() &&
"expected output and outer_dims_perm to have same size");
applyPermutationToVector(outputTileSizes,
invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
if (ShapedType::isDynamic(inputShape[pos]) ||
ShapedType::isDynamic(outputTileSizes[pos]))
return true;
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
if (!constantTile)
return true;
if (inputShape[pos] % (*constantTile) != 0)
return true;
}
return false;
}

LogicalResult PackOp::verify() {
if (failed(commonVerifierPackAndUnPackOp(*this)))
return failure();
Expand Down
Loading
Loading