#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -143,19 +144,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
143144// / Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
144145// / emitting `vector.extract_strided_slice`.
145146static Value staticallyExtractSubvector (OpBuilder &rewriter, Location loc,
146- VectorType extractType, Value source,
147- int64_t frontOffset,
147+ Value source, int64_t frontOffset,
148148 int64_t subvecSize) {
149- auto vectorType = cast<VectorType>(source.getType ());
150- assert ((vectorType.getRank () == 1 && extractType.getRank () == 1 ) &&
151- " expected 1-D source and destination types" );
152- (void )vectorType;
149+ auto vectorType = llvm::cast<VectorType>(source.getType ());
150+ assert (vectorType.getRank () == 1 && " expected 1-D source types" );
153151 auto offsets = rewriter.getI64ArrayAttr ({frontOffset});
154152 auto sizes = rewriter.getI64ArrayAttr ({subvecSize});
155153 auto strides = rewriter.getI64ArrayAttr ({1 });
154+
155+ auto resultVectorType =
156+ VectorType::get ({subvecSize}, vectorType.getElementType ());
156157 return rewriter
157- .create <vector::ExtractStridedSliceOp>(loc, extractType , source, offsets ,
158- sizes, strides)
158+ .create <vector::ExtractStridedSliceOp>(loc, resultVectorType , source,
159+ offsets, sizes, strides)
159160 ->getResult (0 );
160161}
161162
@@ -164,12 +165,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
164165// / `vector.insert_strided_slice`.
165166static Value staticallyInsertSubvector (OpBuilder &rewriter, Location loc,
166167 Value src, Value dest, int64_t offset) {
167- auto srcType = cast<VectorType>(src.getType ());
168- auto destType = cast<VectorType>(dest.getType ());
168+ [[maybe_unused]] auto srcType = cast<VectorType>(src.getType ());
169+ [[maybe_unused]] auto destType = cast<VectorType>(dest.getType ());
169170 assert (srcType.getRank () == 1 && destType.getRank () == 1 &&
170171 " expected source and dest to be vector type" );
171- (void )srcType;
172- (void )destType;
173172 auto offsets = rewriter.getI64ArrayAttr ({offset});
174173 auto strides = rewriter.getI64ArrayAttr ({1 });
175174 return rewriter.create <vector::InsertStridedSliceOp>(loc, dest.getType (), src,
@@ -236,6 +235,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
236235 newLoad);
237236}
238237
238+ static void nonAtomicStore (ConversionPatternRewriter &rewriter, Location loc,
239+ Value memref, Value index, Value value) {
240+ auto originType = dyn_cast<VectorType>(value.getType ());
241+ auto memrefElemType = dyn_cast<MemRefType>(memref.getType ()).getElementType ();
242+ auto scale = memrefElemType.getIntOrFloatBitWidth () /
243+ originType.getElementType ().getIntOrFloatBitWidth ();
244+ auto storeType =
245+ VectorType::get ({originType.getNumElements () / scale}, memrefElemType);
246+ auto bitCast = rewriter.create <vector::BitCastOp>(loc, storeType, value);
247+ rewriter.create <vector::StoreOp>(loc, bitCast.getResult (), memref, index);
248+ }
249+
250+ // / atomically store a subbyte-sized value to memory, with a mask.
251+ static Value atomicStore (OpBuilder &rewriter, Location loc,
252+ Value emulatedMemref, Value emulatedIndex,
253+ TypedValue<VectorType> value, Value mask,
254+ int64_t scale) {
255+ auto atomicOp = rewriter.create <memref::GenericAtomicRMWOp>(
256+ loc, emulatedMemref, ValueRange{emulatedIndex});
257+ OpBuilder builder =
258+ OpBuilder::atBlockEnd (atomicOp.getBody (), rewriter.getListener ());
259+ Value origValue = atomicOp.getCurrentValue ();
260+
261+ // i8 -> vector type <1xi8> then <1xi8> -> <scale x i.>
262+ auto oneVectorType = VectorType::get ({1 }, origValue.getType ());
263+ auto fromElem = builder.create <vector::FromElementsOp>(loc, oneVectorType,
264+ ValueRange{origValue});
265+ auto vectorBitCast =
266+ builder.create <vector::BitCastOp>(loc, value.getType (), fromElem);
267+
268+ auto select =
269+ builder.create <arith::SelectOp>(loc, mask, value, vectorBitCast);
270+ auto bitcast2 = builder.create <vector::BitCastOp>(loc, oneVectorType, select);
271+ auto extract = builder.create <vector::ExtractOp>(loc, bitcast2, 0 );
272+ builder.create <memref::AtomicYieldOp>(loc, extract.getResult ());
273+ return atomicOp;
274+ }
275+
276+ // Extract a slice of a vector, and insert it into a byte vector.
277+ static Value extractSliceIntoByte (ConversionPatternRewriter &rewriter,
278+ Location loc, TypedValue<VectorType> vector,
279+ int64_t sliceOffset, int64_t sliceNumElements,
280+ int64_t byteOffset) {
281+ auto vectorElementType = vector.getType ().getElementType ();
282+ assert (8 % vectorElementType.getIntOrFloatBitWidth () == 0 &&
283+ " vector element must be a valid sub-byte type" );
284+ auto scale = 8 / vectorElementType.getIntOrFloatBitWidth ();
285+ auto emptyByteVector = rewriter.create <arith::ConstantOp>(
286+ loc, VectorType::get ({scale}, vectorElementType),
287+ rewriter.getZeroAttr (VectorType::get ({scale}, vectorElementType)));
288+ auto extracted = staticallyExtractSubvector (rewriter, loc, vector,
289+ sliceOffset, sliceNumElements);
290+ auto inserted = staticallyInsertSubvector (rewriter, loc, extracted,
291+ emptyByteVector, byteOffset);
292+ return inserted;
293+ }
294+
239295namespace {
240296
241297// ===----------------------------------------------------------------------===//
@@ -256,7 +312,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
256312
257313 auto loc = op.getLoc ();
258314 auto convertedType = cast<MemRefType>(adaptor.getBase ().getType ());
259- Type oldElementType = op.getValueToStore ().getType ().getElementType ();
315+ auto valueToStore = op.getValueToStore ();
316+ Type oldElementType = valueToStore.getType ().getElementType ();
260317 Type newElementType = convertedType.getElementType ();
261318 int srcBits = oldElementType.getIntOrFloatBitWidth ();
262319 int dstBits = newElementType.getIntOrFloatBitWidth ();
@@ -280,30 +337,121 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
280337 // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
281338 // vector<4xi8>
282339
283- auto origElements = op.getValueToStore ().getType ().getNumElements ();
284- if (origElements % scale != 0 )
285- return failure ();
340+ auto origElements = valueToStore.getType ().getNumElements ();
341+ bool isUnalignedEmulation = origElements % scale != 0 ;
286342
287343 auto stridedMetadata =
288344 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
289345
290346 OpFoldResult linearizedIndices;
291- std::tie (std::ignore, linearizedIndices) =
347+ memref::LinearizedMemRefInfo linearizedInfo;
348+ std::tie (linearizedInfo, linearizedIndices) =
292349 memref::getLinearizedMemRefOffsetAndSize (
293350 rewriter, loc, srcBits, dstBits,
294351 stridedMetadata.getConstifiedMixedOffset (),
295352 stridedMetadata.getConstifiedMixedSizes (),
296353 stridedMetadata.getConstifiedMixedStrides (),
297354 getAsOpFoldResult (adaptor.getIndices ()));
298355
299- auto numElements = origElements / scale;
300- auto bitCast = rewriter.create <vector::BitCastOp>(
301- loc, VectorType::get (numElements, newElementType),
302- op.getValueToStore ());
356+ auto foldedIntraVectorOffset =
357+ isUnalignedEmulation
358+ ? getConstantIntValue (linearizedInfo.intraDataOffset )
359+ : 0 ;
360+
361+ if (!foldedIntraVectorOffset) {
362+ // unimplemented case for dynamic front padding size
363+ return failure ();
364+ }
365+
366+ if (!isUnalignedEmulation) {
367+ auto numElements = origElements / scale;
368+ auto bitCast = rewriter.create <vector::BitCastOp>(
369+ loc, VectorType::get (numElements, newElementType),
370+ op.getValueToStore ());
371+ rewriter.replaceOpWithNewOp <vector::StoreOp>(
372+ op, bitCast.getResult (), adaptor.getBase (),
373+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
374+ return llvm::success ();
375+ }
376+
377+ Value emulatedMemref = adaptor.getBase ();
378+ // the index into the target memref we are storing to
379+ Value currentDestIndex =
380+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
381+ auto constantOne = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
382+ auto atomicMaskType = VectorType::get ({scale}, rewriter.getI1Type ());
383+ // the index into the source vector we are currently processing
384+ auto currentSourceIndex = 0 ;
385+
386+ // 1. atomic store for the first byte
387+ auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
388+ if (frontAtomicStoreElem != 0 ) {
389+ auto frontMaskValues = llvm::SmallVector<bool >(scale, false );
390+ if (*foldedIntraVectorOffset + origElements < scale) {
391+ std::fill_n (frontMaskValues.begin () + *foldedIntraVectorOffset,
392+ origElements, true );
393+ frontAtomicStoreElem = origElements;
394+ } else {
395+ std::fill_n (frontMaskValues.end () - frontAtomicStoreElem,
396+ *foldedIntraVectorOffset, true );
397+ }
398+ auto frontMask = rewriter.create <arith::ConstantOp>(
399+ loc, DenseElementsAttr::get (atomicMaskType, frontMaskValues));
400+
401+ currentSourceIndex = scale - (*foldedIntraVectorOffset);
402+ auto value = extractSliceIntoByte (
403+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0 ,
404+ frontAtomicStoreElem, *foldedIntraVectorOffset);
405+
406+ atomicStore (rewriter, loc, emulatedMemref, currentDestIndex,
407+ cast<TypedValue<VectorType>>(value), frontMask.getResult (),
408+ scale);
409+
410+ currentDestIndex = rewriter.create <arith::AddIOp>(
411+ loc, rewriter.getIndexType (), currentDestIndex, constantOne);
412+ }
413+
414+ if (currentSourceIndex >= origElements) {
415+ rewriter.eraseOp (op);
416+ return success ();
417+ }
418+
419+ // 2. non-atomic store
420+ int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
421+ int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
422+ if (nonAtomicStoreSize != 0 ) {
423+ auto nonAtomicStorePart = staticallyExtractSubvector (
424+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
425+ currentSourceIndex, numNonAtomicElements);
426+
427+ nonAtomicStore (rewriter, loc, emulatedMemref, currentDestIndex,
428+ nonAtomicStorePart);
429+
430+ currentSourceIndex += numNonAtomicElements;
431+ currentDestIndex = rewriter.create <arith::AddIOp>(
432+ loc, rewriter.getIndexType (), currentDestIndex,
433+ rewriter.create <arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
434+ }
435+
436+ // 3. atomic store for the last byte
437+ auto remainingElements = origElements - currentSourceIndex;
438+ if (remainingElements != 0 ) {
439+ auto atomicStorePart = extractSliceIntoByte (
440+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
441+ currentSourceIndex, remainingElements, 0 );
442+
443+ // back mask
444+ auto maskValues = llvm::SmallVector<bool >(scale, 0 );
445+ std::fill_n (maskValues.begin (), remainingElements, 1 );
446+ auto backMask = rewriter.create <arith::ConstantOp>(
447+ loc, DenseElementsAttr::get (atomicMaskType, maskValues));
448+
449+ atomicStore (rewriter, loc, emulatedMemref, currentDestIndex,
450+ cast<TypedValue<VectorType>>(atomicStorePart),
451+ backMask.getResult (), scale);
452+ }
303453
304- rewriter.replaceOpWithNewOp <vector::StoreOp>(
305- op, bitCast.getResult (), adaptor.getBase (),
306- getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
454+ rewriter.eraseOp (op);
307455 return success ();
308456 }
309457};
@@ -511,9 +659,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
511659 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
512660 linearizedInfo.intraDataOffset , origElements);
513661 } else if (isUnalignedEmulation) {
514- result =
515- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
516- *foldedIntraVectorOffset, origElements);
662+ result = staticallyExtractSubvector (
663+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
517664 }
518665 rewriter.replaceOp (op, result);
519666 return success ();
@@ -672,9 +819,8 @@ struct ConvertVectorMaskedLoad final
672819 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
673820 op.getPassThru (), linearizedInfo.intraDataOffset , origElements);
674821 } else if (isUnalignedEmulation) {
675- result =
676- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
677- *foldedIntraVectorOffset, origElements);
822+ result = staticallyExtractSubvector (
823+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
678824 }
679825 rewriter.replaceOp (op, result);
680826
@@ -757,9 +903,8 @@ struct ConvertVectorTransferRead final
757903 linearizedInfo.intraDataOffset ,
758904 origElements);
759905 } else if (isUnalignedEmulation) {
760- result =
761- staticallyExtractSubvector (rewriter, loc, op.getType (), result,
762- *foldedIntraVectorOffset, origElements);
906+ result = staticallyExtractSubvector (
907+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
763908 }
764909 rewriter.replaceOp (op, result);
765910
0 commit comments