Commit 7557304

[mlir][linalg] Update vectorization of linalg.pack (#163539)
This patch changes `vectorizeAsTensorPackOp` to require users to specify **all** write-side vector sizes for `linalg.pack` (not just the outer dimensions). This makes `linalg.pack` vectorization consistent with `linalg.unpack` (see #149293 for a similar change).

Conceptually, `linalg.pack` consists of these high-level steps:
* **Read** from the source tensor using `vector.transfer_read`.
* **Re-associate** dimensions of the read value, as specified by the op (via `vector.shape_cast`).
* **Transpose** the re-associated value according to the permutation in the `linalg.pack` op (via `vector.transpose`).
* **Write** the result into the destination tensor via `vector.transfer_write`.

Previously, the vector sizes provided by the user were interpreted as write-vector-sizes for the PackOp's **_outer_** dims (i.e. the final step above). These were used to:
* Infer read-vector-sizes using the `inner_tiles` attribute of PackOp.
* Deduce vector sizes for the transpose and shape-cast operations.
* Ultimately determine the vector shape for the read.

However, this logic breaks when one or more tile sizes are dynamic (*). In such cases, `vectorizePackOpPrecondition` would currently fail (see `@pack_with_dynamic_dims_and_dynamic_inner_tile` added in this PR; without this change it will crash).

This patch updates the contract: users now directly specify _all_ the "write-vector-sizes", which inherently encode all inner tile sizes, including dynamic ones. It becomes the user's responsibility to provide valid sizes. In practice, since `linalg.pack` is typically constructed, tiled, and vectorized by the same transformation pipeline, the necessary "write-vector-sizes" should be recoverable.

Notes for reviewers:
* See test updates for user-facing impact.
* Review `vectorizeAsTensorPackOp` as a new implementation rather than a diff.
* Comments and variable names were updated to align with `vectorizeAsTensorUnPackOp`.
(*) As a concrete example, "scalable" tile sizes are represented as dynamic values. Note that support for "scalable" vectorisation will be added in a separate PR.
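The shape bookkeeping implied by the four steps above, and the way every remaining vector size falls out of the write-side sizes, can be sketched in plain Python. This is a toy model only (illustrative function names, restricted to an identity `outer_dims_perm`), not the MLIR API:

```python
def pack_inverse_dest_perm(src_rank, inner_dims_pos):
    """Map packed-domain dims (outer dims first, then inner tiles in
    inner_dims_pos order) to the "strip-mined" order, where each inner
    tile sits right after its outer dim. Identity outer_dims_perm only."""
    perm = []
    for d in range(src_rank):
        perm.append(d)  # outer counterpart of source dim d
        if d in inner_dims_pos:
            perm.append(src_rank + inner_dims_pos.index(d))  # its inner tile
    return perm

def apply_perm(shape, perm):
    return [shape[p] for p in perm]

def invert_perm(perm):
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

def collapse(shape, src_rank, inner_dims_pos):
    """Multiply each (outer, inner-tile) group back into one source dim,
    mirroring what the vector.shape_cast step undoes."""
    out, i = [], 0
    for d in range(src_rank):
        n = 2 if d in inner_dims_pos else 1
        prod = 1
        for s in shape[i:i + n]:
            prod *= s
        out.append(prod)
        i += n
    return out

# linalg.pack of tensor<32x8x16> with inner_dims_pos = [2, 1],
# inner_tiles = [16, 2] into tensor<32x4x1x16x2>. The caller now supplies
# ALL five write-side vector sizes, dynamic tiles included:
write_vector_sizes = [32, 4, 1, 16, 2]
inv_perm = pack_inverse_dest_perm(3, [2, 1])
pre_transpose = apply_perm(write_vector_sizes, inv_perm)  # shape_cast result
read_shape = collapse(pre_transpose, 3, [2, 1])           # xfer_read vector
transpose_perm = invert_perm(inv_perm)                    # vector.transpose

print(pre_transpose)   # [32, 4, 2, 1, 16]
print(read_shape)      # [32, 8, 16]
print(transpose_perm)  # [0, 1, 3, 4, 2]
```

Note how the dynamic-tile case costs nothing here: a `-1` entry in `write_vector_sizes` would simply flow through the permutation and collapse, which is the point of the new contract.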
1 parent 4830e63 commit 7557304

File tree: 6 files changed (+220, −138 lines)


mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 8 additions & 16 deletions
```diff
@@ -33,22 +33,14 @@ namespace linalg {
 //===----------------------------------------------------------------------===//
 // Utilities for inferring various semantics properties of Linalg ops.
 //===----------------------------------------------------------------------===//
-/// Shell function to compute the Destination Permutation of PackOp
-/// This function uses the helper function `computePackUnPackPerm` to get
-/// the permutation vector. Only major difference between UnPack and Pack is
-/// that packOp uses destination rank whereas unpack Uses source rank.
-SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp);
-
-/// Shell function to compute the Source Permutation of unPackOp.
-/// This function, like the getPackInverseDestPerm uses the helper function
-/// computePackUnPackPerm` to get the permutation vector.
-/// Only major difference between UnPack and Pack is that packOp uses
-/// destination rank whereas unpack Uses source rank.
-SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp);
-
-/// Shell function to compute the Source rank permutation for unpackOp
-/// Unpack requires some packing metadata data information, so created
-/// another function where this value is passed by reference.
+
+/// Compute inverse permutation for the destination tensor (i.e. in the packed
+/// domain).
+SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
+                                            PackingMetadata &metadata);
+
+/// Compute inverse permutation for the source tensor (i.e. in the packed
+/// domain).
 SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp,
                                              PackingMetadata &metadata);
```

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 2 additions & 3 deletions
```diff
@@ -232,10 +232,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 
   // 2. Compute the permutation vector to shuffle packed shape into the shape
   // before any outer or inner permutations have been applied.
-  PackingMetadata packingMetadata = computePackingMetadata(
-      packedTensorType.getRank(), packOp.getInnerDimsPos());
+  PackingMetadata packingMetadata;
   SmallVector<int64_t> packedToStripMinedShapePerm =
-      getPackInverseDestPerm(packOp);
+      getPackInverseDestPerm(packOp, packingMetadata);
 
   // 3. Compute the stripMinedShape: this is the packed shape before any outer
   // or inner permutations have been applied.
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 115 additions & 104 deletions
```diff
@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
-/// Given a linalg::PackOp, return the `dest` shape before any packing
-/// permutations.
-static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
-                                              ArrayRef<int64_t> destShape) {
-  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
-}
-
 /// Determines whether a mask for xfer_write is trivially "all true"
 ///
 /// Given all the inputs required to generate a mask (mask sizes and shapes),
```
```diff
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
   return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
 
-/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
-/// padding value and (3) input vector sizes into:
-///
-/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
-///
-/// As in the following example:
-/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
-///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-///
-/// This pack would be vectorized to:
-///
-/// %load = vector.mask %mask {
-///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
-///         {in_bounds = [true, true, true]} :
-///         tensor<32x7x16xf32>, vector<32x8x16xf32>
-/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
-/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
-///     to vector<32x4x2x1x16xf32>
-/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
-///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
-/// %write = vector.transfer_write %transpose,
-///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
-///     {in_bounds = [true, true, true, true, true]}
-///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
-///
-/// If the (3) input vector sizes are not provided, the vector sizes are
-/// determined by the result tensor shape and the `in_bounds`
-/// attribute is used instead of masking to mark out-of-bounds accesses.
-///
-/// NOTE: The input vector sizes specify the dimensions corresponding to the
-/// outer dimensions of the output tensor. The remaining dimensions are
-/// computed based on, e.g., the static inner tiles.
-/// Supporting dynamic inner tiles will require the user to specify the
-/// missing vector sizes. This is left as a TODO.
-static LogicalResult
-vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        SmallVectorImpl<Value> &newResults) {
-  // TODO: Introduce a parent class that will handle the insertion point update.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(packOp);
-
-  Location loc = packOp.getLoc();
-  std::optional<Value> padValue = packOp.getPaddingValue()
-                                      ? std::optional(packOp.getPaddingValue())
-                                      : std::nullopt;
-
-  // If the input vector sizes are not provided, then the vector sizes are
-  // determined by the result tensor shape. In case the vector sizes aren't
-  // provided, we update the inBounds attribute instead of masking.
-  bool useInBoundsInsteadOfMasking = false;
-  if (inputVectorSizes.empty()) {
-    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
-    useInBoundsInsteadOfMasking = true;
-  }
-
-  // Create masked TransferReadOp.
-  SmallVector<int64_t> inputShape(inputVectorSizes);
-  auto innerTiles = packOp.getStaticInnerTiles();
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  auto outerDimsPerm = packOp.getOuterDimsPerm();
-  if (!outerDimsPerm.empty())
-    applyPermutationToVector(inputShape,
-                             invertPermutationVector(outerDimsPerm));
-  for (auto [idx, size] : enumerate(innerTiles))
-    inputShape[innerDimsPos[idx]] *= size;
-  auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking,
-      /*inputScalableVecSizes=*/{});
-
-  // Create ShapeCastOp.
-  SmallVector<int64_t> destShape(inputVectorSizes);
-  destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
-  auto shapeCastOp =
-      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
-
-  // Create TransposeOp.
-  auto destPermutation =
-      invertPermutationVector(getPackInverseDestPerm(packOp));
-  auto transposeOp = vector::TransposeOp::create(
-      rewriter, loc, shapeCastOp.getResult(), destPermutation);
-
-  // Create TransferWriteOp.
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), packOp.getDest());
-  newResults.push_back(write->getResult(0));
-  return success();
-}
-
 /// Given the re-associations, "collapses" the input Vector type
 ///
 /// This is similar to CollapseShapeOp::inferCollapsedType with two notable
```
```diff
@@ -1901,12 +1801,121 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
+/// Vectorize `linalg.pack` as:
+/// * xfer_read -> shape_cast -> transpose -> xfer_write
+///
+/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
+/// sizes for the xfer_write operation). This is sufficient to infer the other
+/// vector sizes required here.
+///
+/// If the vector sizes are not provided:
+/// * the vector sizes are determined from the destination tensor static shape.
+/// * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+/// %pack = tensor.pack %src
+///   inner_dims_pos = [2, 1]
+///   inner_tiles = [16, 2]
+///   into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
+/// ```
+/// is vectorized as:
+/// ```
+/// %read = vector.transfer_read %src
+///   : tensor<32x8x16xf32>, vector<32x8x16xf32>
+/// %sc = vector.shape_cast %read
+///   : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
+///   : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+/// %write = vector.transfer_write %tr into %dest
+///   : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+/// ```
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
+                        ArrayRef<int64_t> inputVectorSizes,
+                        SmallVectorImpl<Value> &newResults) {
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == packOp.getDestRank() &&
+           "Invalid number of input vector sizes!");
+  }
+
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  std::optional<Value> padValue = packOp.getPaddingValue()
+                                      ? std::optional(packOp.getPaddingValue())
+                                      : std::nullopt;
+
+  SmallVector<int64_t> destShape =
+      SmallVector<int64_t>(packOp.getDestType().getShape());
+
+  // This is just a convenience alias to clearly communicate that the input
+  // vector sizes determine the _write_ sizes.
+  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
+
+  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
+  // In addition, use the inBounds attribute instead of masking.
+  bool useInBoundsInsteadOfMasking = false;
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(destShape))
+      return rewriter.notifyMatchFailure(packOp,
+                                         "unable to infer vector sizes");
+
+    writeVectorSizes = destShape;
+    useInBoundsInsteadOfMasking = true;
+  }
+
+  // Compute pre-transpose-write-vector-type, i.e. the write vector type
+  // _before_ the transposition (i.e. before dimension permutation). This is
+  // done by inverting the permutation/transposition that's part of the Pack
+  // operation. This type is required to:
+  //  1) compute the read vector type for masked-read below, and
+  //  2) generate shape-cast Op below that expands the read vector type.
+  PackingMetadata packMetadata;
+  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
+  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
+  applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
+  auto preTransposeWriteVecType = VectorType::get(
+      preTransposeWriteVecSizes, packOp.getType().getElementType());
+
+  // Compute vector type for the _read_ operation. This is simply
+  // pre-transpose-write-vector-type with the dimensions collapsed
+  // as per the Pack operation.
+  VectorType readVecType = getCollapsedVecType(
+      preTransposeWriteVecType,
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
+
+  // Create masked TransferReadOp.
+  auto maskedRead = vector::createReadOrMaskedRead(
+      rewriter, loc, packOp.getSource(), readVecType.getShape(), padValue,
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
+
+  // Create ShapeCastOp.
+  auto shapeCastOp = vector::ShapeCastOp::create(
+      rewriter, loc, preTransposeWriteVecType, maskedRead);
+
+  // Create TransposeOp.
+  auto destPermutation = invertPermutationVector(destInvPermutation);
+  auto transposeOp = vector::TransposeOp::create(
+      rewriter, loc, shapeCastOp.getResult(), destPermutation);
+
+  // Create TransferWriteOp.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), packOp.getDest());
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize `linalg.unpack` as:
 /// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
 ///
-/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
-/// for the xfer_read operation). This is sufficient to infer the other vector
-/// sizes required here.
+/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
+/// sizes for the xfer_read operation). This is sufficient to infer the other
+/// vector sizes required here.
 ///
 /// If the vector sizes are not provided:
 /// * the vector sizes are determined from the input tensor static shape.
```
```diff
@@ -1960,7 +1969,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // In the absence of input-vector-sizes, use the _static_ input tensor shape.
   if (inputVectorSizes.empty()) {
     if (ShapedType::isDynamicShape(sourceShape))
-      return failure();
+      return rewriter.notifyMatchFailure(unpackOp,
+                                         "Unable to infer vector sizes!");
 
     readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
     useInBoundsInsteadOfMasking = true;
```
```diff
@@ -2443,6 +2453,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
   Attribute cstAttr;
+  // TODO: Relax this condition
   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG() << "pad value is not constant: " << packOp;
     return failure();
```
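The precondition change above ties back to the contract shift described in the commit message: the pre-patch read-shape inference multiplied each tiled outer write size by its *static* inner tile, so a dynamic tile left it with nothing to multiply by. A minimal Python sketch of that old inference (hypothetical helper name; `DYN` stands in for `ShapedType::kDynamic`):

```python
DYN = -1  # stand-in for ShapedType::kDynamic

def old_infer_read_shape(outer_write_sizes, inner_dims_pos, inner_tiles):
    """Pre-patch inference: scale each tiled outer dim by its static inner
    tile to recover the read shape. Breaks down for dynamic tiles."""
    shape = list(outer_write_sizes)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        if tile == DYN:
            raise ValueError("cannot infer read shape from a dynamic tile")
        shape[pos] *= tile
    return shape

# Static tiles: the old contract works, e.g. [32, 4, 1] -> [32, 8, 16].
assert old_infer_read_shape([32, 4, 1], [2, 1], [16, 2]) == [32, 8, 16]

# A dynamic tile (e.g. a "scalable" tile size) defeats the inference. Under
# the new contract the caller instead passes all five write vector sizes,
# e.g. [32, 4, 1, 16, 2], which already encode the tile sizes.
```

This is why the new `vectorizeAsTensorPackOp` takes `getDestRank()` sizes rather than `getSourceRank()` sizes and never consults `getStaticInnerTiles()`.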

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 5 additions & 10 deletions
```diff
@@ -171,29 +171,24 @@ computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
 namespace mlir {
 namespace linalg {
 
-SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp) {
+SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
+                                            PackingMetadata &metadata) {
 
-  PackingMetadata pMetadata;
   int64_t packedRank = packOp.getDestType().getRank();
   ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
   SmallVector<int64_t> packInvDestPerm =
-      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, pMetadata);
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
   return packInvDestPerm;
 }
 
-SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp) {
-  PackingMetadata metadata;
-  return getUnPackInverseSrcPerm(unpackOp, metadata);
-}
-
 SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                              PackingMetadata &metadata) {
-  int64_t unpackRank = unpackOp.getSourceType().getRank();
+  int64_t packedRank = unpackOp.getSourceType().getRank();
   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
   SmallVector<int64_t> unpackInvSrcPerm =
-      computePackUnPackPerm(unpackRank, innerDimPos, outerPerm, metadata);
+      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
   return unpackInvSrcPerm;
 }
```
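Both getters now thread a `PackingMetadata` out to the caller; its reassociation groups are what the vectorizer feeds to `getCollapsedVecType`. A hypothetical Python model of those groups (illustrative only, not the MLIR helper), for the rank-3 running example:

```python
def pack_reassociations(src_rank, inner_dims_pos):
    """Toy model of the reassociation indices carried in PackingMetadata:
    in the strip-mined ordering each source dim maps to its outer dim plus,
    if tiled, the inner-tile dim placed right after it."""
    groups, i = [], 0
    for d in range(src_rank):
        n = 2 if d in inner_dims_pos else 1  # tiled dims own two packed dims
        groups.append(list(range(i, i + n)))
        i += n
    return groups

# For inner_dims_pos = [2, 1] on a rank-3 source, the strip-mined rank is 5
# and the groups collapse vector<32x4x2x1x16> back to vector<32x8x16>:
print(pack_reassociations(3, [2, 1]))  # [[0], [1, 2], [3, 4]]
```

Since pack's *dest* and unpack's *source* are both the packed side, the same metadata serves both directions, which is why the shared helper (and the renamed `packedRank` above) is rank-of-the-packed-tensor in both functions.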
199194

mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir

Lines changed: 2 additions & 0 deletions
```diff
@@ -285,6 +285,8 @@ module attributes {transform.with_named_sequence} {
 
 ///----------------------------------------------------------------------------------------
 /// Tests for linalg.pack
+///
+/// TODO: Add similar tests for linalg.unpack
 ///----------------------------------------------------------------------------------------
 
 // Note, see a similar test in:
```
