@@ -1506,29 +1506,67 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
}

- /// Given an input, the mixed destSizes, and the vector sizes for vectorization,
- /// create an empty destination tensor and create a TransferWriteOp from the
- /// input to the empty tensor. If the destination shape is not the same as the
- /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a
- /// mask for the write. If `useInBoundsInsteadOfMasking` is set, then update the
- /// inBounds attribute of the transfer write op instead of masking.
- static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
-                                            Value input,
-                                            SmallVector<OpFoldResult> destSizes,
-                                            ArrayRef<int64_t> inputVectorSizes,
-                                            bool useInBoundsInsteadOfMasking) {
+ /// Creates a TransferWriteOp to write `input` into a newly initialized
+ /// output tensor.
+ ///
+ /// Given:
+ ///   - an input vector to write,
+ ///   - the mixed destination sizes for the output tensor,
+ ///   - and the vector sizes used for vectorization (i.e., the leading N dims,
+ ///     for some value of N),
+ ///
+ /// this function generates the following sequence of ops:
+ ///
+ ///   %dest = tensor.empty(%destSizes)
+ ///   %res = vector.transfer_write %input into %dest
+ ///
+ /// If the leading N dimensions of the destination tensor do not match
+ /// `inputVecSizesForLeadingDims` (where N =
+ /// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
+ /// correctness:
+ ///
+ ///   %dest = tensor.empty(%destSizes)
+ ///   %write = vector.transfer_write %input into %dest
+ ///   %mask = vector.create_mask(%destSizes)
+ ///   %res = vector.mask %mask { %write }
+ ///
+ /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
+ /// is used instead of masking:
+ ///
+ ///   %dest = tensor.empty(%destSizes)
+ ///   in_bounds_flags = (...)
+ ///   %res = vector.transfer_write %input into %dest
+ ///       {in_bounds = in_bounds_flags}
+ ///
+ /// NOTE: all write offsets are set to 0.
+ /// NOTE: When N < rank(input), the missing vector sizes are effectively
+ /// extracted from the trailing sizes of `destSizes`. This means those sizes
+ /// must be static. Supporting dynamic sizes will require the user to specify
+ /// the remaining vector sizes. This is left as a TODO.
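+ /// E.g., for an input of type vector<4x8x16x2xf32> with
+ /// `inputVecSizesForLeadingDims` = [4, 8] (so N = 2), the trailing vector
+ /// sizes 16 and 2 are taken from the last two entries of `destSizes`, which
+ /// must therefore be static (illustrative values).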
+ static Operation *
+ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
+                          SmallVector<OpFoldResult> destSizes,
+                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                          bool useInBoundsInsteadOfMasking = false) {

auto inputType = cast<VectorType>(input.getType());
+ assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+        "Rank mismatch!");
+
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                             inputType.getElementType());
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto destShape = cast<ShapedType>(dest.getType()).getShape();
SmallVector<bool> inBoundsVal(rank, true);
if (useInBoundsInsteadOfMasking) {
+   // In this case, assume that all the required vector sizes have been
+   // provided.
+   assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+          "Insufficient number of input vector sizes!");
  // Update the inBounds attribute.
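+   // E.g., with destShape = [8, ?] and inputVecSizesForLeadingDims = [8, 4]
+   // (illustrative values), dim 0 stays in-bounds (static and matching the
+   // vector size) while dim 1 does not (dynamic size), giving
+   // inBoundsVal = [true, false].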
  for (unsigned i = 0; i < rank; i++)
-   inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
+   inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                     !ShapedType::isDynamic(destShape[i]);
}
Operation *write = builder.create<vector::TransferWriteOp>(
@@ -1538,17 +1576,20 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
    /*indices=*/SmallVector<Value>(rank, zero),
    /*inBounds=*/inBoundsVal);
assert(llvm::none_of(
-          destShape.drop_front(inputVectorSizes.size()),
+          destShape.drop_front(inputVecSizesForLeadingDims.size()),
           [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-      "Only dims aligned with inputVectorSizes may be dynamic");
+      "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
if (useInBoundsInsteadOfMasking)
  return write;
- bool needMaskForWrite = !llvm::equal(
-     inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
+ bool needMaskForWrite =
+     !llvm::equal(inputVecSizesForLeadingDims,
+                  destShape.take_front(inputVecSizesForLeadingDims.size()));
if (needMaskForWrite) {
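+   // The leading dims of the write mask come from
+   // `inputVecSizesForLeadingDims`; the trailing dims come from the (static)
+   // destination shape. E.g., leading sizes [4, 8] with destShape
+   // [?, ?, 16, 2] yield a mask of shape 4x8x16x2 (illustrative values).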
  SmallVector<int64_t> writeMaskShape;
- writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
- writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
+ writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
+                       inputVecSizesForLeadingDims.end());
+ writeMaskShape.append(destShape.begin() +
+                           inputVecSizesForLeadingDims.size(),
                        destShape.end());
  auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
  Value maskForWrite =
@@ -1558,9 +1599,11 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
  return write;
}

- /// Vectorize linalg::PackOp with (1) static innerTiles (2) constant
+ /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
  /// padding value and (3) input vector sizes into:
- ///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+ ///
+ ///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
+ ///
  /// As in the following example:
  /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
  ///         into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
@@ -1582,8 +1625,14 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
  ///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
  ///
  /// If the (3) input vector sizes are not provided, the vector sizes are
- /// determined by the result tensor shape. Also, we update the inBounds
- /// attribute instead of masking.
+ /// determined by the result tensor shape and the `in_bounds`
+ /// attribute is used instead of masking to mark out-of-bounds accesses.
+ ///
+ /// NOTE: The input vector sizes specify the dimensions corresponding to the
+ /// outer dimensions of the output tensor. The remaining dimensions are
+ /// computed based on, e.g., the static inner tiles.
+ /// Supporting dynamic inner tiles will require the user to specify the
+ /// missing vector sizes. This is left as a TODO.
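+ /// E.g., for the pack op in the example above, the input vector sizes would
+ /// cover the outer dims [32, 4, 1] of the tensor<32x4x1x16x2xf32> result;
+ /// the trailing sizes 16 and 2 are then taken from the static inner tiles
+ /// (illustrative values).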
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
@@ -1644,9 +1693,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
    loc, shapeCastOp.getResult(), destPermutation);

// Create TransferWriteOp.
- Operation *write = createWriteOrMaskedWrite(
-     rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
-     inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+     createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
+                              /*destSizes=*/reifiedReturnShapes[0],
+                              /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                              /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1780,8 +1831,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
Operation *write = createWriteOrMaskedWrite(
-     rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
-     writeVectorSizes, useInBoundsInsteadOfMasking);
+     rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
+     /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+     useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1810,7 +1862,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
    rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
    /*useInBoundsInsteadOfMasking=*/false);
Operation *write = createWriteOrMaskedWrite(
-     rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
+     rewriter, loc, maskedRead, reifiedReturnShapes[0],
+     /*inputVecSizesForLeadingDims=*/inputVectorSizes,
    /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();