 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -208,13 +209,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");

@@ -225,9 +223,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }

@@ -306,6 +307,73 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                      newLoad);
 }

+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
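+  // The atomic region receives the byte currently in memory as a block
+  // argument and must yield the byte to write back.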
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // Load the current byte as a vector<1xi8>, then bit-cast it to a vector of
+  // `scale` sub-byte elements (the type of `value`).
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
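+  // Lanes where the mask is set receive the new elements from `value`; the
+  // remaining lanes keep the bits that were already in memory.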
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
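+  // As in `atomicStore`, masked-on lanes take the new elements and the rest
+  // keep the loaded contents.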
+  auto select =
+      rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same function type.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
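+  // Copy the requested elements out of `vector` and place them at
+  // `byteOffset` within the zero-initialized byte-sized vector.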
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {

 //===----------------------------------------------------------------------===//
@@ -315,6 +383,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;

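+  /// `useAtomicWrites` chooses between atomic read-modify-write stores and
+  /// plain load-modify-write sequences for partial-byte writes.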
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -326,7 +398,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -335,7 +408,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;

     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -350,32 +423,154 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>

-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
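+    // Number of sub-byte elements carried as front padding in the first
+    // destination byte, i.e. the in-byte offset of the first element actually
+    // being stored. std::nullopt if it is not statically known.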
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
+      return failure();
+    }
+
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: a partial sub-byte store at the front is not needed when:
+    // 1. the source vector size is a multiple of the byte size, and
+    // 2. the store address is aligned to the emulated width boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return llvm::success();
+    }
+
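+    // The store is unaligned and/or narrower than a byte somewhere, so it is
+    // emitted in up to three parts: (1) a masked store for the unaligned
+    // leading byte, (2) full-width stores for the aligned middle, and (3) a
+    // masked store for the partial trailing byte.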
+    // The index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
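+    // Masks for partial-byte stores carry one i1 per sub-byte element of the
+    // destination byte.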
+    // The index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. Partial width store for the first byte: when the store address is not
+    // aligned to the emulated width boundary, handle the unaligned part first
+    // so that the remaining elements are aligned to the width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    frontSubWidthStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonFullWidthElements);
+
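+      // Bit-cast the aligned sub-byte elements to the memref's byte element
+      // type and write them out with a single full-width vector.store.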
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate the back mask: the first `remainingElements` lanes take the
+      // new data, the rest keep the contents already in memory.
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
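+  /// If true, partial-byte stores are emitted as memref.generic_atomic_rmw;
+  /// otherwise they are emitted as plain load/select/store sequences.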
+  const bool useAtomicWrites_;
 };

 //===----------------------------------------------------------------------===//
@@ -581,9 +776,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -742,9 +936,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);

@@ -827,9 +1020,8 @@ struct ConvertVectorTransferRead final
                                         linearizedInfo.intraDataOffset,
                                         origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);

@@ -1574,12 +1766,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {

-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose not
+  // to emit atomic operations and instead use a load-modify-write sequence
+  // for stores when it is known that there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }

 void vector::populateVectorNarrowTypeRewritePatterns(