3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -37,7 +37,8 @@ namespace linalg {
/// This function uses the helper function `computePackUnPackPerm` to get
/// the permutation vector. Only major difference between UnPack and Pack is
/// that packOp uses destination rank whereas unpack Uses source rank.
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
PackingMetadata &metadatap);
Contributor
pMetadata? It does not match the implementation; I think packingMetadata looks better, as you are exposing it as a function argument. Or it can just be metadata like the other function, i.e. getUnPackInverseSrcPerm. The doc needs to be updated as well.

Contributor Author
Sorry, that's a typo. Let me update this to match getUnPackInverseSrcPerm. I will also update the docs for both hooks - in fact, I will make them much shorter. Right now, IMHO, they are too long and go into implementation details that should be left for the implementation itself:

/// Shell function to compute the Destination Permutation of PackOp
/// This function uses the helper function `computePackUnPackPerm` to get
/// the permutation vector. Only major difference between UnPack and Pack is
/// that packOp uses destination rank whereas unpack Uses source rank.

I will also remove this helper hook which doesn't seem to be required (at least based on "upstream"):

SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp);
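For concreteness, the shortened doc and the argument name matched to getUnPackInverseSrcPerm might end up looking something like this (a sketch only, not the final wording):

/// Compute the inverse of the `dest` permutation of `packOp`. As a side
/// effect, fill `metadata` with the corresponding packing information.
SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
                                            PackingMetadata &metadata);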


/// Shell function to compute the Source Permutation of unPackOp.
/// This function, like the getPackInverseDestPerm uses the helper function
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,

// 2. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
PackingMetadata packingMetadata;
SmallVector<int64_t> packedToStripMinedShapePerm =
getPackInverseDestPerm(packOp);
getPackInverseDestPerm(packOp, packingMetadata);

// 3. Compute the stripMinedShape: this is the packed shape before any outer
// or inner permutations have been applied.
218 changes: 114 additions & 104 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}

/// Given a linalg::PackOp, return the `dest` shape before any packing
/// permutations.
static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
ArrayRef<int64_t> destShape) {
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
}

/// Determines whether a mask for xfer_write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
return mlir::vector::maskOperation(builder, write, maskForWrite);
}

/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
/// padding value and (3) input vector sizes into:
///
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
///
/// As in the following example:
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
/// %load = vector.mask %mask {
/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
/// {in_bounds = [true, true, true]} :
/// tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
/// to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
/// {in_bounds = [true, true, true, true, true]}
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
///
/// If the (3) input vector sizes are not provided, the vector sizes are
/// determined by the result tensor shape and the `in_bounds`
/// attribute is used instead of masking to mark out-of-bounds accesses.
///
/// NOTE: The input vector sizes specify the dimensions corresponding to the
/// outer dimensions of the output tensor. The remaining dimensions are
/// computed based on, e.g., the static inner tiles.
/// Supporting dynamic inner tiles will require the user to specify the
/// missing vector sizes. This is left as a TODO.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Location loc = packOp.getLoc();
std::optional<Value> padValue = packOp.getPaddingValue()
? std::optional(packOp.getPaddingValue())
: std::nullopt;

// If the input vector sizes are not provided, then the vector sizes are
// determined by the result tensor shape. In case the vector sizes aren't
// provided, we update the inBounds attribute instead of masking.
bool useInBoundsInsteadOfMasking = false;
if (inputVectorSizes.empty()) {
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
useInBoundsInsteadOfMasking = true;
}

// Create masked TransferReadOp.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
auto innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty())
applyPermutationToVector(inputShape,
invertPermutationVector(outerDimsPerm));
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
useInBoundsInsteadOfMasking,
/*inputScalableVecSizes=*/{});

// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp =
vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);

// Create TransposeOp.
auto destPermutation =
invertPermutationVector(getPackInverseDestPerm(packOp));
auto transposeOp = vector::TransposeOp::create(
rewriter, loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), packOp.getDest());
newResults.push_back(write->getResult(0));
return success();
}

/// Given the re-associations, "collapses" the input Vector type
///
/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1801,120 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
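For intuition, here is a minimal sketch of the size computation the collapse performs for static shapes (the helper name collapseShape is illustrative; the real getCollapsedVecType works from affine maps and also merges scalable flags):

// Multiply the sizes within each reassociation group.
// E.g. {32, 4, 2, 1, 16} with groups {{0}, {1, 2}, {3, 4}} -> {32, 8, 16},
// which is exactly the write-to-read shape collapse in the pack example below.
static SmallVector<int64_t> collapseShape(ArrayRef<int64_t> shape,
                                          ArrayRef<ReassociationIndices> groups) {
  SmallVector<int64_t> collapsed;
  for (const ReassociationIndices &group : groups) {
    int64_t size = 1;
    for (int64_t dim : group)
      size *= shape[dim];
    collapsed.push_back(size);
  }
  return collapsed;
}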

/// Vectorize `linalg.pack` as:
/// * xfer_read -> shape_cast -> transpose -> xfer_write
///
/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
/// sizes for the xfer_write operation). This is sufficient to infer the other
/// vector sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the destination tensor static shape.
/// * the inBounds attribute is used instead of masking.
///
/// EXAMPLE (no vector sizes):
/// ```
/// %pack = tensor.pack %src
/// inner_dims_pos = [2, 1]
/// inner_tiles = [16, 2]
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
/// ```
/// is vectorized as:
/// ```
/// %read = vector.transfer_read %src
/// : tensor<32x8x16xf32>, vector<32x8x16xf32>
/// %sc = vector.shape_cast %read
/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %tr into %dst
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
/// ```
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
if (!inputVectorSizes.empty()) {
assert(inputVectorSizes.size() == packOp.getDestRank() &&
"Invalid number of input vector sizes!");
}

// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);

Location loc = packOp.getLoc();
std::optional<Value> padValue = packOp.getPaddingValue()
? std::optional(packOp.getPaddingValue())
: std::nullopt;

SmallVector<int64_t> destShape =
SmallVector<int64_t>(packOp.getDestType().getShape());

// This is just a convenience alias to clearly communicate that the input
// vector sizes determine the _write_ sizes.
ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;

// In the absence of input-vector-sizes, use the _static_ destination tensor
// shape.
// In addition, use the inBounds attribute instead of masking.
bool useInBoundsInsteadOfMasking = false;
if (writeVectorSizes.empty()) {
if (ShapedType::isDynamicShape(destShape))
return rewriter.notifyMatchFailure(packOp,
"Unable to infer vector sizes!");
Contributor
We usually start the first sentence with a lowercase letter and finish the last sentence without a period/exclamation mark.

https://llvm.org/docs/CodingStandards.html#error-and-warning-messages

Contributor Author
@banach-space Oct 25, 2025
Thank you for the reminder!
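For reference, adjusted to the linked standard (lowercase start, no trailing punctuation), the diagnostic would become something like:

return rewriter.notifyMatchFailure(packOp,
                                   "unable to infer vector sizes");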


writeVectorSizes = destShape;
useInBoundsInsteadOfMasking = true;
}

// Compute the vector type for the _read_ operation. The required dims are
// determined based on the _write_ vector sizes. This is done in two
// steps:
// 1) Invert the permutation/transposition that's part of the Pack
// operation.
// 2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
PackingMetadata packMetadata;
auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);

SmallVector<int64_t> writeVecSizesUnpermuted(writeVectorSizes);
applyPermutationToVector(writeVecSizesUnpermuted, destInvPermutation);

VectorType readVecType = getCollapsedVecType(
Contributor
nit: this is a little hard to follow at first glance IMO

VectorType::get(writeVecSizesUnpermuted,
packOp.getType().getElementType()),
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
rewriter.getContext(), packMetadata.reassociations)));
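One possible way to make this easier to follow would be to name the intermediate steps, e.g. (a sketch; expandedVecType mirrors the type built for the shape_cast below):

// Split the nested call so each step is named.
auto expandedVecType = VectorType::get(writeVecSizesUnpermuted,
                                       packOp.getType().getElementType());
auto reassocMaps = getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
    rewriter.getContext(), packMetadata.reassociations));
VectorType readVecType = getCollapsedVecType(expandedVecType, reassocMaps);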

// Create masked TransferReadOp.
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
useInBoundsInsteadOfMasking,
/*inputScalableVecSizes=*/{});

// Create ShapeCastOp.
auto expandedVecType = VectorType::get(writeVecSizesUnpermuted,
packOp.getType().getElementType());
auto shapeCastOp =
vector::ShapeCastOp::create(rewriter, loc, expandedVecType, maskedRead);

// Create TransposeOp.
auto destPermutation = invertPermutationVector(destInvPermutation);
auto transposeOp = vector::TransposeOp::create(
rewriter, loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), packOp.getDest());
newResults.push_back(write->getResult(0));
return success();
}

/// Vectorize `linalg.unpack` as:
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
///
/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
/// for the xfer_read operation). This is sufficient to infer the other vector
/// sizes required here.
/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
/// sizes for the xfer_read operation). This is sufficient to infer the other
/// vector sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1968,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
if (inputVectorSizes.empty()) {
if (ShapedType::isDynamicShape(sourceShape))
return failure();
return rewriter.notifyMatchFailure(unpackOp,
"Unable to infer vector sizes!");

readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
useInBoundsInsteadOfMasking = true;
@@ -2443,6 +2452,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
Attribute cstAttr;
// TODO: Relax this condition.
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
LDBG() << "pad value is not constant: " << packOp;
return failure();
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -171,9 +171,9 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
namespace mlir {
namespace linalg {

SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
PackingMetadata &pMetadata) {

PackingMetadata pMetadata;
int64_t packedRank = packOp.getDestType().getRank();
ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
@@ -189,11 +189,11 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
PackingMetadata &metadata) {
int64_t unpackRank = unpackOp.getSourceType().getRank();
int64_t packedRank = unpackOp.getSourceType().getRank();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
SmallVector<int64_t> unpackInvSrcPerm =
computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
return unpackInvSrcPerm;
}

@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {

///----------------------------------------------------------------------------------------
/// Tests for linalg.pack
///
/// TODO: Add similar tests for linalg.unpack
///----------------------------------------------------------------------------------------

// Note, see a similar test in: