@@ -1564,13 +1564,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
15641564 return success ();
15651565}
15661566
1567- // / Given a linalg::PackOp, return the `dest` shape before any packing
1568- // / permutations.
1569- static SmallVector<int64_t > getTiledPackShape (linalg::PackOp packOp,
1570- ArrayRef<int64_t > destShape) {
1571- return applyPermutation (destShape, linalg::getPackInverseDestPerm (packOp));
1572- }
1573-
15741567// / Determines whether a mask for xfer_write is trivially "all true"
15751568// /
15761569// / Given all the inputs required to generate a mask (mask sizes and shapes),
@@ -1761,99 +1754,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
17611754 return mlir::vector::maskOperation (builder, write, maskForWrite);
17621755}
17631756
1764- // / Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1765- // / padding value and (3) input vector sizes into:
1766- // /
1767- // / masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1768- // /
1769- // / As in the following example:
1770- // / %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1771- // / into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1772- // /
1773- // / This pack would be vectorized to:
1774- // /
1775- // / %load = vector.mask %mask {
1776- // / vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1777- // / {in_bounds = [true, true, true]} :
1778- // / tensor<32x7x16xf32>, vector<32x8x16xf32>
1779- // / } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1780- // / %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1781- // / to vector<32x4x2x1x16xf32>
1782- // / %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1783- // / : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1784- // / %write = vector.transfer_write %transpose,
1785- // / %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1786- // / {in_bounds = [true, true, true, true, true]}
1787- // / : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1788- // /
1789- // / If the (3) input vector sizes are not provided, the vector sizes are
1790- // / determined by the result tensor shape and the `in_bounds`
1791- // / attribute is used instead of masking to mark out-of-bounds accesses.
1792- // /
1793- // / NOTE: The input vector sizes specify the dimensions corresponding to the
1794- // / outer dimensions of the output tensor. The remaining dimensions are
1795- // / computed based on, e.g., the static inner tiles.
1796- // / Supporting dynamic inner tiles will require the user to specify the
1797- // / missing vector sizes. This is left as a TODO.
1798- static LogicalResult
1799- vectorizeAsTensorPackOp (RewriterBase &rewriter, linalg::PackOp packOp,
1800- ArrayRef<int64_t > inputVectorSizes,
1801- SmallVectorImpl<Value> &newResults) {
1802- // TODO: Introduce a parent class that will handle the insertion point update.
1803- OpBuilder::InsertionGuard g (rewriter);
1804- rewriter.setInsertionPoint (packOp);
1805-
1806- Location loc = packOp.getLoc ();
1807- std::optional<Value> padValue = packOp.getPaddingValue ()
1808- ? std::optional (packOp.getPaddingValue ())
1809- : std::nullopt ;
1810-
1811- // If the input vector sizes are not provided, then the vector sizes are
1812- // determined by the result tensor shape. In case the vector sizes aren't
1813- // provided, we update the inBounds attribute instead of masking.
1814- bool useInBoundsInsteadOfMasking = false ;
1815- if (inputVectorSizes.empty ()) {
1816- ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1817- inputVectorSizes = resultTensorShape.take_front (packOp.getSourceRank ());
1818- useInBoundsInsteadOfMasking = true ;
1819- }
1820-
1821- // Create masked TransferReadOp.
1822- SmallVector<int64_t > inputShape (inputVectorSizes);
1823- auto innerTiles = packOp.getStaticInnerTiles ();
1824- auto innerDimsPos = packOp.getInnerDimsPos ();
1825- auto outerDimsPerm = packOp.getOuterDimsPerm ();
1826- if (!outerDimsPerm.empty ())
1827- applyPermutationToVector (inputShape,
1828- invertPermutationVector (outerDimsPerm));
1829- for (auto [idx, size] : enumerate(innerTiles))
1830- inputShape[innerDimsPos[idx]] *= size;
1831- auto maskedRead = vector::createReadOrMaskedRead (
1832- rewriter, loc, packOp.getSource (), inputShape, padValue,
1833- useInBoundsInsteadOfMasking,
1834- /* inputScalableVecSizes=*/ {});
1835-
1836- // Create ShapeCastOp.
1837- SmallVector<int64_t > destShape (inputVectorSizes);
1838- destShape.append (innerTiles.begin (), innerTiles.end ());
1839- auto tiledPackType = VectorType::get (getTiledPackShape (packOp, destShape),
1840- packOp.getDestType ().getElementType ());
1841- auto shapeCastOp =
1842- vector::ShapeCastOp::create (rewriter, loc, tiledPackType, maskedRead);
1843-
1844- // Create TransposeOp.
1845- auto destPermutation =
1846- invertPermutationVector (getPackInverseDestPerm (packOp));
1847- auto transposeOp = vector::TransposeOp::create (
1848- rewriter, loc, shapeCastOp.getResult (), destPermutation);
1849-
1850- // Create TransferWriteOp.
1851- Operation *write = createWriteOrMaskedWrite (
1852- rewriter, loc, transposeOp.getResult (), packOp.getDest ());
1853- newResults.push_back (write->getResult (0 ));
1854- return success ();
1855- }
1856-
18571757// / Given the re-associations, "collapses" the input Vector type
18581758// /
18591759// / This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1801,119 @@ static VectorType getCollapsedVecType(VectorType type,
19011801 return VectorType::get (newShape, type.getElementType (), newScalableFlags);
19021802}
19031803
1804+ // / Vectorize `linalg.pack` as:
1805+ // / * xfer_read -> shape_cast -> transpose -> xfer_write
1806+ // /
1807+ // / The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1808+ // / sizes for the xfer_write operation). This is sufficient to infer the other
1809+ // / vector sizes required here.
1810+ // /
1811+ // / If the vector sizes are not provided:
1812+ // / * the vector sizes are determined from the destination tensor static shape.
1813+ // / * the inBounds attribute is used instead of masking.
1814+ // /
1815+ // / EXAMPLE (no vector sizes):
1816+ // / ```
1817+ // / %pack = tensor.pack %src
1818+ // / inner_dims_pos = [2, 1]
1819+ // / inner_tiles = [16, 2]
1820+ // / into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1821+ // / ```
1822+ // / is vectorized as:
1823+ // / ```
1824+ // / %read = vector.transfer_read %src
1825+ // / : tensor<32x7x16xf32>, vector<32x8x16xf32>
1826+ // / %sc = vector.shape_cast %read
1827+ // / : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1828+ // / %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1829+ // / : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1830+ // / %write = vector.transfer_write %tr into %dest
1831+ // / : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1832+ // / ```
1833+ static LogicalResult
1834+ vectorizeAsTensorPackOp (RewriterBase &rewriter, linalg::PackOp packOp,
1835+ ArrayRef<int64_t > inputVectorSizes,
1836+ SmallVectorImpl<Value> &newResults) {
1837+ if (!inputVectorSizes.empty ()) {
1838+ assert (inputVectorSizes.size () == packOp.getDestRank () &&
1839+ " Invalid number of input vector sizes!" );
1840+ }
1841+
1842+ // TODO: Introduce a parent class that will handle the insertion point update.
1843+ OpBuilder::InsertionGuard g (rewriter);
1844+ rewriter.setInsertionPoint (packOp);
1845+
1846+ Location loc = packOp.getLoc ();
1847+ std::optional<Value> padValue = packOp.getPaddingValue ()
1848+ ? std::optional (packOp.getPaddingValue ())
1849+ : std::nullopt ;
1850+
1851+ SmallVector<int64_t > destShape =
1852+ SmallVector<int64_t >(packOp.getDestType ().getShape ());
1853+
1854+ // This is just a convenience alias to clearly communicate that the input
1855+ // vector sizes determine the _write_ sizes.
1856+ ArrayRef<int64_t > &writeVectorSizes = inputVectorSizes;
1857+
1858+ // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1859+ // In addition, use the inBounds attribute instead of masking.
1860+ bool useInBoundsInsteadOfMasking = false ;
1861+ if (writeVectorSizes.empty ()) {
1862+ if (ShapedType::isDynamicShape (destShape))
1863+ return rewriter.notifyMatchFailure (packOp,
1864+ " Unable to infer vector sizes!" );
1865+
1866+ writeVectorSizes = destShape;
1867+ useInBoundsInsteadOfMasking = true ;
1868+ }
1869+
1870+ // / Compute vector type for the _read_ operation. The required dims are
1871+ // determined based on the _write_ vector sizes. This is done in two
1872+ // steps:
1873+ // 1) Invert the permutation/transposition that's part of the Pack
1874+ // operation.
1875+ // 2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
1876+ PackingMetadata packMetadata;
1877+ auto destInvPermutation = getPackInverseDestPerm (packOp, packMetadata);
1878+
1879+ SmallVector<int64_t > inputVecSizesPrePerm (writeVectorSizes);
1880+ applyPermutationToVector (inputVecSizesPrePerm, destInvPermutation);
1881+
1882+ VectorType readVecType = getCollapsedVecType (
1883+ VectorType::get (inputVecSizesPrePerm, packOp.getType ().getElementType ()),
1884+ getSymbolLessAffineMaps (convertReassociationIndicesToExprs (
1885+ rewriter.getContext (), packMetadata.reassociations )));
1886+
1887+ // Create masked TransferReadOp.
1888+ auto maskedRead = vector::createReadOrMaskedRead (
1889+ rewriter, loc, packOp.getSource (), readVecType.getShape (), padValue,
1890+ useInBoundsInsteadOfMasking,
1891+ /* inputScalableVecSizes=*/ {});
1892+
1893+ // Create ShapeCastOp.
1894+ auto expandedVecType =
1895+ VectorType::get (inputVecSizesPrePerm, packOp.getType ().getElementType ());
1896+ auto shapeCastOp =
1897+ vector::ShapeCastOp::create (rewriter, loc, expandedVecType, maskedRead);
1898+
1899+ // Create TransposeOp.
1900+ auto destPermutation = invertPermutationVector (destInvPermutation);
1901+ auto transposeOp = vector::TransposeOp::create (
1902+ rewriter, loc, shapeCastOp.getResult (), destPermutation);
1903+
1904+ // Create TransferWriteOp.
1905+ Operation *write = createWriteOrMaskedWrite (
1906+ rewriter, loc, transposeOp.getResult (), packOp.getDest ());
1907+ newResults.push_back (write->getResult (0 ));
1908+ return success ();
1909+ }
1910+
19041911// / Vectorize `linalg.unpack` as:
19051912// / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
19061913// /
1907- // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1908- // / for the xfer_read operation). This is sufficient to infer the other vector
1909- // / sizes required here.
1914+ // / The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1915+ // / sizes for the xfer_read operation). This is sufficient to infer the other
1916+ // / vector sizes required here.
19101917// /
19111918// / If the vector sizes are not provided:
19121919// / * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1967,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19601967 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
19611968 if (inputVectorSizes.empty ()) {
19621969 if (ShapedType::isDynamicShape (sourceShape))
1963- return failure ();
1970+ return rewriter.notifyMatchFailure (unpackOp,
1971+ " Unable to infer vector sizes!" );
19641972
19651973 readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
19661974 useInBoundsInsteadOfMasking = true ;
@@ -2443,6 +2451,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24432451 ArrayRef<int64_t > inputVectorSizes) {
24442452 auto padValue = packOp.getPaddingValue ();
24452453 Attribute cstAttr;
2454+ // TODO: Relax this condition
24462455 if (padValue && !matchPattern (padValue, m_Constant (&cstAttr))) {
24472456 LDBG () << " pad value is not constant: " << packOp;
24482457 return failure ();
0 commit comments