@@ -1506,84 +1506,86 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
-/// Creates a TransferWriteOp to write `input` into a newly initialized
-/// output tensor.
+/// Creates an optionally masked TransferWriteOp
 ///
-/// Given:
-/// - an input vector to write,
-/// - the mixed destination sizes for the output tensor,
-/// - and the vector sizes used for vectorization (i.e., the leading N dims,
-///   for some value of N),
-///
-/// this function generates the following sequence of ops:
-///
-///   %dest = tensor.empty(%destSizes)
-///   %res = vector.transfer_write %input into %dest
+/// Generates the following operation:
+///   %res = vector.transfer_write %vectorToStore into %dest
 ///
 /// If the leading N dimensions of the destination tensor do not match
-/// `inputVecSizesForLeadingDims` (where N =
-/// rank(`inputVecSizesForLeadingDims`)), masking is applied to ensure
-/// correctness:
+/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
+/// masking is applied to ensure correctness:
 ///
-///   %dest = tensor.empty(%destSizes)
-///   %write = vector.transfer_write %input into %dest
-///   %mask = vector.create_mask(%destSizes)
-///   %res = vector.mask %mask { %write }
+///   %mask = vector.create_mask(%destShape)
+///   %res = vector.mask %mask {
+///     vector.transfer_write %vectorToStore into %dest
+///   }
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %dest = tensor.empty(%destSizes)
-///   in_bounds_flags = (...)
-///   %res = vector.transfer_write %input into %dest
-///       {in_bounds = in_bounds_flags}
+///   in_bounds_flags = (...)
+///   %res = vector.transfer_write %vectorToStore into %dest
+///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: all write offsets are set to 0.
+/// NOTE: All write offsets are set to 0.
+/// TODO: Allow specifying write offsets.
 /// NOTE: When N < rank(input), the missing vector sizes are effectively
 /// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static. Supporting dynamic sizes will require the user to specify
-/// the remaining vector sizes. This is left as a TODO.
+/// must be static.
+/// TODO: Support cases where an arbitrary dim is dynamic - this will require
+/// specifying all the vector sizes.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
-                         SmallVector<OpFoldResult> destSizes,
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
+                         Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
                          bool useInBoundsInsteadOfMasking = false) {
 
-  auto inputType = cast<VectorType>(input.getType());
-  assert(inputType.getRank() == static_cast<int64_t>(destSizes.size()) &&
+  ShapedType destType = cast<ShapedType>(dest.getType());
+  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
+             static_cast<int64_t>(destType.getRank()) &&
          "Rank mismatch!");
 
-  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
-                                               inputType.getElementType());
   int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   auto destShape = cast<ShapedType>(dest.getType()).getShape();
+
+  // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(rank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
-    assert(inputVecSizesForLeadingDims.size() == destSizes.size() &&
+    assert(inputVecSizesForLeadingDims.size() ==
+               static_cast<size_t>(destType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     for (unsigned i = 0; i < rank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
+
+  // Generate the xfer_write Op
+  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   Operation *write = builder.create<vector::TransferWriteOp>(
       loc,
-      /*vector=*/input,
+      /*vector=*/vectorToStore,
       /*source=*/dest,
       /*indices=*/SmallVector<Value>(rank, zero),
       /*inBounds=*/inBoundsVal);
   assert(llvm::none_of(
              destShape.drop_front(inputVecSizesForLeadingDims.size()),
              [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
+  // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
+
+  // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
                    destShape.take_front(inputVecSizesForLeadingDims.size()));
+
+  // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
@@ -1592,10 +1594,11 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input,
                               inputVecSizesForLeadingDims.size(),
                           destShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite =
-        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
+    Value maskForWrite = builder.create<vector::CreateMaskOp>(
+        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
+
   return write;
 }
 
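For reference, a minimal sketch of the IR the masked path produces, assuming a hypothetical `tensor<?x16xf32>` destination and `inputVecSizesForLeadingDims = [8, 16]` (value names invented for illustration). Note that the mask operands now come from `tensor::getMixedSizes` applied to the caller-provided `%dest`, not from a `destSizes` array:

  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %d0 = tensor.dim %dest, %c0 : tensor<?x16xf32>
  %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
  %res = vector.mask %mask {
    vector.transfer_write %vec, %dest[%c0, %c0]
        : vector<8x16xf32>, tensor<?x16xf32>
  } : vector<8x16xi1> -> tensor<?x16xf32>
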
@@ -1693,9 +1696,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
       loc, shapeCastOp.getResult(), destPermutation);
 
   // Create TransferWriteOp.
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0],
+      transposeOp.getResult().getType().getElementType());
   Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
-                               /*destSizes=*/reifiedReturnShapes[0],
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
                                /*inputVecSizesForLeadingDims=*/inputVectorSizes,
                                /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
@@ -1830,10 +1835,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
           unpackOp.getDestType().hasStaticShape()
               ? vectorSizes
               : shapeCastOp.getResultVectorType().getShape());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, shapeCastOp.getResult(), /*destSizes=*/reifiedRetShapes[0],
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-      useInBoundsInsteadOfMasking);
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedRetShapes[0],
+      shapeCastOp.getResult().getType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
+                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+                               useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1861,10 +1869,14 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, reifiedReturnShapes[0],
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-      /*useInBoundsInsteadOfMasking=*/false);
+
+  // Create Xfer write Op
+  Value dest = rewriter.create<tensor::EmptyOp>(
+      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
+                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
+                               /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
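The same caller-side pattern now appears in all three vectorizers: reify the result shape, materialize the destination with `tensor.empty`, and pass it to `createWriteOrMaskedWrite`. Hoisting the `tensor.empty` out of the helper means the helper no longer decides how the destination is allocated. For completeness, a sketch of the unmasked `in_bounds` path, assuming a fully static `tensor<8x16xf32>` destination whose shape matches the vector sizes (names invented for illustration):

  %c0 = arith.constant 0 : index
  %dest = tensor.empty() : tensor<8x16xf32>
  %res = vector.transfer_write %vec, %dest[%c0, %c0]
      {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>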