 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -157,13 +158,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -174,9 +172,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
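+// For example (illustrative only, not from the test suite): extracting 3
+// elements at offset 1 from a vector<8xi2> now derives the result type from
+// the source vector itself:
+//
+//   %0 = vector.extract_strided_slice %src
+//        {offsets = [1], sizes = [3], strides = [1]}
+//        : vector<8xi2> to vector<3xi2>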
 
@@ -185,12 +186,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -257,6 +256,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                            newLoad);
 }
 
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = cast<VectorType>(value.getType());
+  auto memrefElemType = cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
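+// For example (illustrative only), with i4 emulated by i8 (scale = 2), a
+// non-atomic store of %v : vector<4xi4> at byte index %i lowers to:
+//
+//   %b = vector.bitcast %v : vector<4xi4> to vector<2xi8>
+//   vector.store %b, %memref[%i] : memref<?xi8>, vector<2xi8>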
+
+/// Atomically store a sub-byte-sized vector to memory, merging it with the
+/// byte already in memory under the given mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector<1xi8>, then vector<1xi8> -> vector<scale x sub-byte type>.
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  // Merge the incoming lanes with the current contents, then cast the result
+  // back to a single byte and yield it.
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
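+// For example (illustrative only), with i4 emulated by i8, an atomicStore of
+// %val : vector<2xi4> under mask %mask at byte index %i expands to roughly:
+//
+//   %res = memref.generic_atomic_rmw %memref[%i] : memref<?xi8> {
+//   ^bb0(%old: i8):
+//     %v1   = vector.from_elements %old : vector<1xi8>
+//     %orig = vector.bitcast %v1 : vector<1xi8> to vector<2xi4>
+//     %sel  = arith.select %mask, %val, %orig : vector<2xi1>, vector<2xi4>
+//     %new  = vector.bitcast %sel : vector<2xi4> to vector<1xi8>
+//     %byte = vector.extract %new[0] : i8 from vector<1xi8>
+//     memref.atomic_yield %byte : i8
+//   }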
+
+/// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
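+// For example (illustrative only), with %src : vector<5xi2> (scale = 4),
+// extractSliceIntoByte(%src, /*sliceOffset=*/3, /*sliceNumElements=*/2,
+// /*byteOffset=*/0) emits roughly:
+//
+//   %cst = arith.constant dense<0> : vector<4xi2>
+//   %e   = vector.extract_strided_slice %src
+//          {offsets = [3], sizes = [2], strides = [1]}
+//          : vector<5xi2> to vector<2xi2>
+//   %r   = vector.insert_strided_slice %e, %cst
+//          {offsets = [0], strides = [1]}
+//          : vector<2xi2> into vector<4xi2>
+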
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -277,7 +333,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -301,30 +358,124 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // Unimplemented case: dynamic front padding size.
+      return failure();
+    }
+
+    // Conditions under which no atomic RMW stores are needed:
+    // 1. The source vector size is a multiple of the byte size, and
+    // 2. the address of the store is aligned to the emulated byte boundary.
+    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
+      auto numElements = origElements / scale;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Atomic store for the first, partially-covered byte.
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        // The entire source vector fits inside the front byte.
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        // Mark the trailing `frontAtomicStoreElem` lanes of the byte.
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    frontAtomicStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
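+    // For example (illustrative only): storing vector<7xi2> (scale = 4) at
+    // intra-byte offset 1 gives frontAtomicStoreElem = 3, so source elements
+    // [0, 3) are merged into byte lanes [1, 4) under the mask
+    // dense<[false, true, true, true]> : vector<4xi1>.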
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Non-atomic store for the full bytes in the middle: bitcast them to
+    //    the container type and store them directly.
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. Atomic store for the trailing, partially-covered byte.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Mask covering only the leading `remainingElements` lanes of the byte.
+      auto maskValues = llvm::SmallVector<bool>(scale, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
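+    // Putting the three steps together (illustrative only): storing
+    // vector<7xi2> at intra-byte offset 1 emits one atomic RMW for byte
+    // lanes [1, 4), then one plain vector.store covering the next full byte
+    // (source elements [3, 7)); no trailing partial byte remains, since
+    // 1 + 7 is a multiple of scale = 4.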
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -532,9 +683,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -693,9 +843,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -778,9 +927,8 @@ struct ConvertVectorTransferRead final
                                         linearizedInfo.intraDataOffset,
                                         origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 