@@ -1568,7 +1568,9 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
15681568// / permutations.
15691569static SmallVector<int64_t > getTiledPackShape (linalg::PackOp packOp,
15701570 ArrayRef<int64_t > destShape) {
1571- return applyPermutation (destShape, linalg::getPackInverseDestPerm (packOp));
1571+ PackingMetadata metadata;
1572+ return applyPermutation (destShape,
1573+ linalg::getPackInverseDestPerm (packOp, metadata));
15721574}
15731575
15741576// / Determines whether a mask for xfer_write is trivially "all true"
@@ -1761,99 +1763,6 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
17611763 return mlir::vector::maskOperation (builder, write, maskForWrite);
17621764}
17631765
1764- // / Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1765- // / padding value and (3) input vector sizes into:
1766- // /
1767- // / masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1768- // /
1769- // / As in the following example:
1770- // / %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
1771- // / into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1772- // /
1773- // / This pack would be vectorized to:
1774- // /
1775- // / %load = vector.mask %mask {
1776- // / vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1777- // / {in_bounds = [true, true, true]} :
1778- // / tensor<32x7x16xf32>, vector<32x8x16xf32>
1779- // / } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1780- // / %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1781- // / to vector<32x4x2x1x16xf32>
1782- // / %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1783- // / : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1784- // / %write = vector.transfer_write %transpose,
1785- // / %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1786- // / {in_bounds = [true, true, true, true, true]}
1787- // / : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1788- // /
1789- // / If the (3) input vector sizes are not provided, the vector sizes are
1790- // / determined by the result tensor shape and the `in_bounds`
1791- // / attribute is used instead of masking to mark out-of-bounds accesses.
1792- // /
1793- // / NOTE: The input vector sizes specify the dimensions corresponding to the
1794- // / outer dimensions of the output tensor. The remaining dimensions are
1795- // / computed based on, e.g., the static inner tiles.
1796- // / Supporting dynamic inner tiles will require the user to specify the
1797- // / missing vector sizes. This is left as a TODO.
1798- static LogicalResult
1799- vectorizeAsTensorPackOp (RewriterBase &rewriter, linalg::PackOp packOp,
1800- ArrayRef<int64_t > inputVectorSizes,
1801- SmallVectorImpl<Value> &newResults) {
1802- // TODO: Introduce a parent class that will handle the insertion point update.
1803- OpBuilder::InsertionGuard g (rewriter);
1804- rewriter.setInsertionPoint (packOp);
1805-
1806- Location loc = packOp.getLoc ();
1807- std::optional<Value> padValue = packOp.getPaddingValue ()
1808- ? std::optional (packOp.getPaddingValue ())
1809- : std::nullopt ;
1810-
1811- // If the input vector sizes are not provided, then the vector sizes are
1812- // determined by the result tensor shape. In case the vector sizes aren't
1813- // provided, we update the inBounds attribute instead of masking.
1814- bool useInBoundsInsteadOfMasking = false ;
1815- if (inputVectorSizes.empty ()) {
1816- ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
1817- inputVectorSizes = resultTensorShape.take_front (packOp.getSourceRank ());
1818- useInBoundsInsteadOfMasking = true ;
1819- }
1820-
1821- // Create masked TransferReadOp.
1822- SmallVector<int64_t > inputShape (inputVectorSizes);
1823- auto innerTiles = packOp.getStaticInnerTiles ();
1824- auto innerDimsPos = packOp.getInnerDimsPos ();
1825- auto outerDimsPerm = packOp.getOuterDimsPerm ();
1826- if (!outerDimsPerm.empty ())
1827- applyPermutationToVector (inputShape,
1828- invertPermutationVector (outerDimsPerm));
1829- for (auto [idx, size] : enumerate(innerTiles))
1830- inputShape[innerDimsPos[idx]] *= size;
1831- auto maskedRead = vector::createReadOrMaskedRead (
1832- rewriter, loc, packOp.getSource (), inputShape, padValue,
1833- useInBoundsInsteadOfMasking,
1834- /* inputScalableVecSizes=*/ {});
1835-
1836- // Create ShapeCastOp.
1837- SmallVector<int64_t > destShape (inputVectorSizes);
1838- destShape.append (innerTiles.begin (), innerTiles.end ());
1839- auto tiledPackType = VectorType::get (getTiledPackShape (packOp, destShape),
1840- packOp.getDestType ().getElementType ());
1841- auto shapeCastOp =
1842- vector::ShapeCastOp::create (rewriter, loc, tiledPackType, maskedRead);
1843-
1844- // Create TransposeOp.
1845- auto destPermutation =
1846- invertPermutationVector (getPackInverseDestPerm (packOp));
1847- auto transposeOp = vector::TransposeOp::create (
1848- rewriter, loc, shapeCastOp.getResult (), destPermutation);
1849-
1850- // Create TransferWriteOp.
1851- Operation *write = createWriteOrMaskedWrite (
1852- rewriter, loc, transposeOp.getResult (), packOp.getDest ());
1853- newResults.push_back (write->getResult (0 ));
1854- return success ();
1855- }
1856-
18571766// / Given the re-associations, "collapses" the input Vector type
18581767// /
18591768// / This is similar to CollapseShapeOp::inferCollapsedType with two notable
@@ -1901,12 +1810,119 @@ static VectorType getCollapsedVecType(VectorType type,
19011810 return VectorType::get (newShape, type.getElementType (), newScalableFlags);
19021811}
19031812
1813+ // / Vectorize `linalg.pack` as:
1814+ // / * xfer_read -> shape_cast -> transpose -> xfer_write
1815+ // /
1816+ // / The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1817+ // / sizes for the xfer_write operation). This is sufficient to infer the other
1818+ // / vector sizes required here.
1819+ // /
1820+ // / If the vector sizes are not provided:
1821+ // / * the vector sizes are determined from the destination tensor static shape.
1822+ // / * the inBounds attribute is used instead of masking.
1823+ // /
1824+ // / EXAMPLE (no vector sizes):
1825+ // / ```
1826+ // / %pack = tensor.pack %src
1827+ // / inner_dims_pos = [2, 1]
1828+ // / inner_tiles = [16, 2]
1829+ // / into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1830+ // / ```
1831+ // / is vectorized as:
1832+ // / ```
1833+ // / %read = vector.transfer_read %src
1834+ // / : tensor<32x8x16xf32>, vector<32x8x16xf32>
1835+ // / %sc = vector.shape_cast %read
1836+ // / : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1837+ // / %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1838+ // / : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1839+ // / %write = vector.transfer_write %tr into %dest
1840+ // / : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1841+ // / ```
1842+ static LogicalResult
1843+ vectorizeAsTensorPackOp (RewriterBase &rewriter, linalg::PackOp packOp,
1844+ ArrayRef<int64_t > inputVectorSizes,
1845+ SmallVectorImpl<Value> &newResults) {
1846+ if (!inputVectorSizes.empty ()) {
1847+ assert (inputVectorSizes.size () == packOp.getDestRank () &&
1848+ " Invalid number of input vector sizes!" );
1849+ }
1850+
1851+ // TODO: Introduce a parent class that will handle the insertion point update.
1852+ OpBuilder::InsertionGuard g (rewriter);
1853+ rewriter.setInsertionPoint (packOp);
1854+
1855+ Location loc = packOp.getLoc ();
1856+ std::optional<Value> padValue = packOp.getPaddingValue ()
1857+ ? std::optional (packOp.getPaddingValue ())
1858+ : std::nullopt ;
1859+
1860+ SmallVector<int64_t > destShape =
1861+ SmallVector<int64_t >(packOp.getDestType ().getShape ());
1862+
1863+ // This is just a convenience alias to clearly communicate that the input
1864+ // vector sizes determine the _write_ sizes.
1865+ ArrayRef<int64_t > &writeVectorSizes = inputVectorSizes;
1866+
1867+ // In the absence of input-vector-sizes, use the _static_ destination tensor
1868+ // In addition, use the inBounds attribute instead of masking.
1869+ bool useInBoundsInsteadOfMasking = false ;
1870+ if (writeVectorSizes.empty ()) {
1871+ if (ShapedType::isDynamicShape (destShape))
1872+ return rewriter.notifyMatchFailure (packOp,
1873+ " Unable to infer vector sizes!" );
1874+
1875+ writeVectorSizes = destShape;
1876+ useInBoundsInsteadOfMasking = true ;
1877+ }
1878+
1879+ // Compute vector type for the _read_ operation. The required dims are
1880+ // determined based on the _write_ vector sizes. This is done in two
1881+ // steps:
1882+ // 1) Invert the permutation/transposition that's part of the Pack
1883+ // operation.
1884+ // 2) Collapse the tiled sizes/dims to "return" to the unpacked domain.
1885+ PackingMetadata packMetadata;
1886+ auto destInvPermutation = getPackInverseDestPerm (packOp, packMetadata);
1887+
1888+ SmallVector<int64_t > inputVecSizesPrePerm (writeVectorSizes);
1889+ applyPermutationToVector (inputVecSizesPrePerm, destInvPermutation);
1890+
1891+ VectorType readVecType = getCollapsedVecType (
1892+ VectorType::get (inputVecSizesPrePerm, packOp.getType ().getElementType ()),
1893+ getSymbolLessAffineMaps (convertReassociationIndicesToExprs (
1894+ rewriter.getContext (), packMetadata.reassociations )));
1895+
1896+ // Create masked TransferReadOp.
1897+ auto maskedRead = vector::createReadOrMaskedRead (
1898+ rewriter, loc, packOp.getSource (), readVecType.getShape (), padValue,
1899+ useInBoundsInsteadOfMasking,
1900+ /* inputScalableVecSizes=*/ {});
1901+
1902+ // Create ShapeCastOp.
1903+ auto expandedVecType =
1904+ VectorType::get (inputVecSizesPrePerm, packOp.getType ().getElementType ());
1905+ auto shapeCastOp =
1906+ vector::ShapeCastOp::create (rewriter, loc, expandedVecType, maskedRead);
1907+
1908+ // Create TransposeOp.
1909+ auto destPermutation = invertPermutationVector (destInvPermutation);
1910+ auto transposeOp = vector::TransposeOp::create (
1911+ rewriter, loc, shapeCastOp.getResult (), destPermutation);
1912+
1913+ // Create TransferWriteOp.
1914+ Operation *write = createWriteOrMaskedWrite (
1915+ rewriter, loc, transposeOp.getResult (), packOp.getDest ());
1916+ newResults.push_back (write->getResult (0 ));
1917+ return success ();
1918+ }
1919+
19041920// / Vectorize `linalg.unpack` as:
19051921// / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
19061922// /
1907- // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1908- // / for the xfer_read operation). This is sufficient to infer the other vector
1909- // / sizes required here.
1923+ // / The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1924+ // / sizes for the xfer_read operation). This is sufficient to infer the other
1925+ // / vector sizes required here.
19101926// /
19111927// / If the vector sizes are not provided:
19121928// / * the vector sizes are determined from the input tensor static shape.
@@ -1960,7 +1976,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19601976 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
19611977 if (inputVectorSizes.empty ()) {
19621978 if (ShapedType::isDynamicShape (sourceShape))
1963- return failure ();
1979+ return rewriter.notifyMatchFailure (unpackOp,
1980+ " Unable to infer vector sizes!" );
19641981
19651982 readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
19661983 useInBoundsInsteadOfMasking = true ;
@@ -2443,6 +2460,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24432460 ArrayRef<int64_t > inputVectorSizes) {
24442461 auto padValue = packOp.getPaddingValue ();
24452462 Attribute cstAttr;
2463+ // TODO: Relax this condition
24462464 if (padValue && !matchPattern (padValue, m_Constant (&cstAttr))) {
24472465 LDBG () << " pad value is not constant: " << packOp;
24482466 return failure ();
0 commit comments