Commit 3f2aacb
[mlir][linalg] Update vectorization of linalg.pack

This patch changes `vectorizeAsTensorPackOp` to require users to specify all write-side vector sizes for `linalg.pack` (not just the outer dimensions). This makes `linalg.pack` vectorization consistent with `linalg.unpack` (see #149293 for a similar change).

Conceptually, `linalg.pack` consists of these high-level steps:
* **Read** from the source tensor using `vector.transfer_read`.
* **Re-associate** dimensions of the value that was read, as specified by the op (via `vector.shape_cast`).
* **Transpose** the re-associated value according to the permutation in the `linalg.pack` op (via `vector.transpose`).
* **Write** the result into the destination tensor via `vector.transfer_write`.

Previously, the vector sizes provided by the user were interpreted as write-vector-sizes for the PackOp _outer_ dims only (i.e. for the final step above). These were used to:
* Infer read-vector-sizes using the `inner_tiles` attribute of PackOp.
* Deduce vector sizes for the transpose and shape-cast operations.
* Ultimately determine the vector shape for the read.

However, this logic breaks when one or more tile sizes are dynamic (*). In such cases, `vectorizePackOpPrecondition` currently fails (see `@pack_with_dynamic_dims_and_dynamic_inner_tile` added in this PR - without this change it would crash).

This patch updates the contract: users now directly specify _all_ the "write-vector-sizes", which inherently encode all inner tile sizes - including dynamic ones. It becomes the user's responsibility to provide valid sizes; see the sketch below for what this looks like from the transform dialect. In practice, since `linalg.pack` is typically constructed, tiled, and vectorized by the same transformation pipeline, the necessary "write-vector-sizes" should be recoverable.

Notes for reviewers:
* See test updates for user-facing impact.
* Review `vectorizeAsTensorPackOp` as a new implementation rather than a diff.
* Comments and variable names were updated to align with `vectorizeAsTensorUnPackOp`.

(*) As a concrete example, "scalable" tile sizes are represented as dynamic values. Note, support for "scalable" vectorization will be added in a separate PR.
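Below is a minimal transform-dialect sketch of the new contract. It is not taken from this commit's tests; the function name, wrapper, and shapes are illustrative and reuse the static example from the `vectorizeAsTensorPackOp` documentation. The destination has rank 5, so `vector_sizes` now lists all five write-side sizes - the three outer dims plus the two inner tiles - whereas previously only the three outer sizes would have been passed:

```mlir
func.func @pack(%src: tensor<32x8x16xf32>, %dst: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
  %pack = linalg.pack %src
    inner_dims_pos = [2, 1]
    inner_tiles = [16, 2]
    into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
  return %pack : tensor<32x4x1x16x2xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.pack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    // All write-side vector sizes, i.e. the shape of the vector.transfer_write
    // produced by vectorization: [32, 4, 1] (outer dims) + [16, 2] (inner tiles).
    transform.structured.vectorize %0 vector_sizes [32, 4, 1, 16, 2] : !transform.any_op
    transform.yield
  }
}
```

For a dynamic inner tile, the corresponding `vector_sizes` entry supplies the concrete size directly, which is what allows cases like `@pack_with_dynamic_dims_and_dynamic_inner_tile` to vectorize.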
1 parent ee12ab2 commit 3f2aacb

6 files changed: +219 -109 lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@ namespace linalg {
 /// This function uses the helper function `computePackUnPackPerm` to get
 /// the permutation vector. Only major difference between UnPack and Pack is
 /// that packOp uses destination rank whereas unpack Uses source rank.
-SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);
+SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
+                                            PackingMetadata &metadatap);
 
 /// Shell function to compute the Source Permutation of unPackOp.
 /// This function, like the getPackInverseDestPerm uses the helper function

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 2 additions & 1 deletion
@@ -234,8 +234,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   // before any outer or inner permutations have been applied.
   PackingMetadata packingMetadata = computePackingMetadata(
       packedTensorType.getRank(), packOp.getInnerDimsPos());
+  PackingMetadata packMetadata;
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      getPackInverseDestPerm(packOp);
+      getPackInverseDestPerm(packOp, packMetadata);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 116 additions & 98 deletions
@@ -1568,7 +1568,9 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 /// permutations.
 static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
                                               ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
+  PackingMetadata metadata;
+  return applyPermutation(destShape,
+                          linalg::getPackInverseDestPerm(packOp, metadata));
 }
 
 /// Determines whether a mask for xfer_write is trivially "all true"
@@ -1761,99 +1763,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
   return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-///         {in_bounds = [true, true, true]} :
-///         tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-///     to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-///     {in_bounds = [true, true, true, true, true]}
-///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        SmallVectorImpl<Value> &newResults) {
-  // TODO: Introduce a parent class that will handle the insertion point update.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  Location loc = packOp.getLoc();
-  std::optional<Value> padValue = packOp.getPaddingValue()
-                                      ? std::optional(packOp.getPaddingValue())
-                                      : std::nullopt;
-
-  // If the input vector sizes are not provided, then the vector sizes are
-  // determined by the result tensor shape. In case the vector sizes aren't
-  // provided, we update the inBounds attribute instead of masking.
-  bool useInBoundsInsteadOfMasking = false;
-  if (inputVectorSizes.empty()) {
-    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
-    useInBoundsInsteadOfMasking = true;
-  }
-
-  // Create masked TransferReadOp.
-  SmallVector<int64_t> inputShape(inputVectorSizes);
-  auto innerTiles = packOp.getStaticInnerTiles();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty())
-    applyPermutationToVector(inputShape,
-                             invertPermutationVector(outerDimsPerm));
-  for (auto [idx, size] : enumerate(innerTiles))
-    inputShape[innerDimsPos[idx]] *= size;
-  auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
-
-  // Create ShapeCastOp.
-  SmallVector<int64_t> destShape(inputVectorSizes);
-  destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
-  auto shapeCastOp =
-      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
-  // Create TransposeOp.
-  auto destPermutation =
-      invertPermutationVector(getPackInverseDestPerm(packOp));
-  auto transposeOp = vector::TransposeOp::create(
-      rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
-  // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), packOp.getDest());
-  newResults.push_back(write->getResult(0));
-  return success();
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1810,119 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+///   * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+///  * the vector sizes are determined from the destination tensor static shape.
+///  * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %pack = tensor.pack %src
+///     inner_dims_pos = [2, 1]
+///     inner_tiles = [16, 2]
+///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+///   %read = vector.transfer_read %src
+///     : tensor<32x7x16xf32>, vector<32x8x16xf32>
+///   %sc = vector.shape_cast %read
+///     : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+///   %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+///   %write = vector.transfer_write %tr into %dest
+///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ destination tensor
+  // shape. In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "Unable to infer vector sizes!");
+
+    writeVectorSizes = destShape;
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // Compute the vector type for the _read_ operation. The required dims are
+  // determined based on the _write_ vector sizes. This is done in two
+  // steps:
+  //  1) Invert the permutation/transposition that's part of the Pack
+  //     operation.
+  //  2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
+  PackingMetadata packMetadata;
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+
+  SmallVector<int64_t> inputVecSizesPrePerm(writeVectorSizes);
+  applyPermutationToVector(inputVecSizesPrePerm, destInvPermutation);
+
+  VectorType readVecType = getCollapsedVecType(
+      VectorType::get(inputVecSizesPrePerm, packOp.getType().getElementType()),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
+
+  // Create masked TransferReadOp.
+  auto maskedRead = vector::createReadOrMaskedRead(
+      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
+
+  // Create ShapeCastOp.
+  auto expandedVecType =
+      VectorType::get(inputVecSizesPrePerm, packOp.getType().getElementType());
+  auto shapeCastOp =
+      vector::ShapeCastOp::create(rewriter, loc, expandedVecType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation = invertPermutationVector(destInvPermutation);
+  auto transposeOp = vector::TransposeOp::create(
+      rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), packOp.getDest());
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize `linalg.unpack` as:
 ///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
 ///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
 ///
 /// If the vector sizes are not provided:
 ///  * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1976,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // In the absence of input-vector-sizes, use the _static_ input tensor shape.
   if (inputVectorSizes.empty()) {
     if (ShapedType::isDynamicShape(sourceShape))
-      return failure();
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "Unable to infer vector sizes!");
 
     readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
     useInBoundsInsteadOfMasking = true;
@@ -2443,6 +2460,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
+  // TODO: Relax this condition
   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG() << "pad value is not constant: " << packOp;
     return failure();

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 4 additions & 4 deletions
@@ -171,9 +171,9 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
 namespace mlir {
 namespace linalg {
 
-SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
+SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
+                                            PackingMetadata &pMetadata) {
 
-  PackingMetadata pMetadata;
   int64_t packedRank = packOp.getDestType().getRank();
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
@@ -189,11 +189,11 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
 
 SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                              PackingMetadata &metadata) {
-  int64_t unpackRank = unpackOp.getSourceType().getRank();
+  int64_t packedRank = unpackOp.getSourceType().getRank();
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
   SmallVector<int64_t> unpackInvSrcPerm =
-      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
   return unpackInvSrcPerm;
 }

mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir

Lines changed: 2 additions & 0 deletions
@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {
 
 ///----------------------------------------------------------------------------------------
 /// Tests for linalg.pack
+///
+/// TODO: Add similar tests for linalg.unpack
 ///----------------------------------------------------------------------------------------
 
 // Note, see a similar test in:
