@@ -305,12 +305,14 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
305305 assert (
306306 downcastType.getNumElements () * downcastType.getElementTypeBitWidth () ==
307307 upcastType.getNumElements () * upcastType.getElementTypeBitWidth () &&
308- " expected upcastType size to be twice the size of downcastType " );
309- if (trueValue.getType () != downcastType)
308+ " expected input and output number of bits to match " );
309+ if (trueValue.getType () != downcastType) {
310310 trueValue = builder.create <vector::BitCastOp>(loc, downcastType, trueValue);
311- if (falseValue.getType () != downcastType)
311+ }
312+ if (falseValue.getType () != downcastType) {
312313 falseValue =
313314 builder.create <vector::BitCastOp>(loc, downcastType, falseValue);
315+ }
314316 Value selectedType =
315317 builder.create <arith::SelectOp>(loc, mask, trueValue, falseValue);
316318 // Upcast the selected value to the new type.
@@ -454,28 +456,33 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
454456 stridedMetadata.getConstifiedMixedStrides (),
455457 getAsOpFoldResult (adaptor.getIndices ()));
456458
457- auto foldedNumFrontPadElems =
459+ std::optional< int64_t > foldedNumFrontPadElems =
458460 isUnalignedEmulation
459461 ? getConstantIntValue (linearizedInfo.intraDataOffset )
460462 : 0 ;
461463
462464 if (!foldedNumFrontPadElems) {
463- // Unimplemented case for dynamic front padding size != 0
464- return failure ( );
465+ return failure ( " subbyte store emulation: dynamic front padding size is "
466+ " not yet implemented " );
465467 }
466468
467- auto linearizedMemref = cast<MemRefValue>(adaptor.getBase ());
469+ auto memrefBase = cast<MemRefValue>(adaptor.getBase ());
468470
469- // Shortcut: conditions when subbyte store at the front is not needed:
471+ // Shortcut: conditions when subbyte emulated store at the front is not
472+ // needed:
470473 // 1. The source vector size is multiple of byte size
471474 // 2. The address of the store is aligned to the emulated width boundary
475+ //
476+ // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
477+ // need unaligned emulation because the store address is aligned and the
478+ // source is a whole byte.
472479 if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0 ) {
473480 auto numElements = origElements / numSrcElemsPerDest;
474481 auto bitCast = rewriter.create <vector::BitCastOp>(
475482 loc, VectorType::get (numElements, newElementType),
476483 op.getValueToStore ());
477484 rewriter.replaceOpWithNewOp <vector::StoreOp>(
478- op, bitCast.getResult (), linearizedMemref ,
485+ op, bitCast.getResult (), memrefBase ,
479486 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
480487 return success ();
481488 }
@@ -511,7 +518,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
511518 extractSliceIntoByte (rewriter, loc, valueToStore, 0 ,
512519 frontSubWidthStoreElem, *foldedNumFrontPadElems);
513520
514- atomicStore (rewriter, loc, linearizedMemref , currentDestIndex,
521+ atomicStore (rewriter, loc, memrefBase , currentDestIndex,
515522 cast<VectorValue>(value), frontMask.getResult ());
516523 }
517524
@@ -537,13 +544,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
537544 numNonFullWidthElements);
538545
539546 auto originType = cast<VectorType>(fullWidthStorePart.getType ());
540- auto memrefElemType = getElementTypeOrSelf (linearizedMemref .getType ());
547+ auto memrefElemType = getElementTypeOrSelf (memrefBase .getType ());
541548 auto storeType = VectorType::get (
542549 {originType.getNumElements () / numSrcElemsPerDest}, memrefElemType);
543550 auto bitCast = rewriter.create <vector::BitCastOp>(loc, storeType,
544551 fullWidthStorePart);
545- rewriter.create <vector::StoreOp>(loc, bitCast.getResult (),
546- linearizedMemref, currentDestIndex);
552+ rewriter.create <vector::StoreOp>(loc, bitCast.getResult (), memrefBase,
553+ currentDestIndex);
547554
548555 currentSourceIndex += numNonFullWidthElements;
549556 currentDestIndex = rewriter.create <arith::AddIOp>(
@@ -565,7 +572,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
565572 auto backMask = rewriter.create <arith::ConstantOp>(
566573 loc, DenseElementsAttr::get (subWidthStoreMaskType, maskValues));
567574
568- atomicStore (rewriter, loc, linearizedMemref , currentDestIndex,
575+ atomicStore (rewriter, loc, memrefBase , currentDestIndex,
569576 cast<VectorValue>(subWidthStorePart), backMask.getResult ());
570577 }
571578
0 commit comments