2424#include " llvm/Support/Debug.h"
2525#include " llvm/Support/raw_ostream.h"
2626#include < cstdint>
27+ #include < optional>
2728
2829using namespace mlir ;
2930
@@ -102,6 +103,23 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
102103 return newMask;
103104}
104105
106+ // /
107+ static std::optional<int64_t >
108+ getFrontPaddingSize (ConversionPatternRewriter &rewriter, Location loc,
109+ const memref::LinearizedMemRefInfo linearizedInfo,
110+ bool isUnalignedEmulation) {
111+ if (!isUnalignedEmulation)
112+ return 0 ;
113+ auto foldedFrontPaddingSize = getValueOrCreateConstantIndexOp (
114+ rewriter, loc, linearizedInfo.frontPaddingSize );
115+ // try to fold the front padding size into a constant
116+ if (auto frontPadding = dyn_cast_or_null<arith::ConstantIndexOp>(
117+ foldedFrontPaddingSize.getDefiningOp ())) {
118+ return frontPadding.value ();
119+ }
120+ return std::nullopt ;
121+ }
122+
105123namespace {
106124
107125// ===----------------------------------------------------------------------===//
@@ -142,29 +160,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
142160 // vector<4xi8>
143161
144162 auto origElements = op.getValueToStore ().getType ().getNumElements ();
145- if (origElements % scale != 0 )
146- return failure ();
163+
164+ // if the size of vector we are loading is not byte-aligned, extra handling
165+ // is needed
166+ bool isUnalignedEmulation = origElements % scale != 0 ;
147167
148168 auto stridedMetadata =
149169 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
150170
151171 OpFoldResult linearizedIndices;
152- std::tie (std::ignore, linearizedIndices) =
172+ memref::LinearizedMemRefInfo linearizedInfo;
173+ std::tie (linearizedInfo, linearizedIndices) =
153174 memref::getLinearizedMemRefOffsetAndSize (
154175 rewriter, loc, srcBits, dstBits,
155176 stridedMetadata.getConstifiedMixedOffset (),
156177 stridedMetadata.getConstifiedMixedSizes (),
157178 stridedMetadata.getConstifiedMixedStrides (),
158179 getAsOpFoldResult (adaptor.getIndices ()));
159180
160- auto numElements = origElements / scale;
161- auto bitCast = rewriter.create <vector::BitCastOp>(
162- loc, VectorType::get (numElements, newElementType),
163- op.getValueToStore ());
181+ auto foldedFrontPaddingSize = getFrontPaddingSize (
182+ rewriter, loc, linearizedInfo, isUnalignedEmulation);
164183
165- rewriter.replaceOpWithNewOp <vector::StoreOp>(
166- op, bitCast.getResult (), adaptor.getBase (),
167- getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
184+ if (!foldedFrontPaddingSize) {
185+ // unimplemented case for dynamic front padding size
186+ return failure ();
187+ }
188+
189+ auto numElements =
190+ (*foldedFrontPaddingSize + origElements + scale - 1 ) / scale;
191+ auto newVectorType = VectorType::get (numElements, newElementType);
192+
193+ if (isUnalignedEmulation) {
194+ auto insertedVectorType =
195+ VectorType::get (numElements * scale, oldElementType);
196+
197+ auto linearizedIndicesValue =
198+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
199+ auto passThru =
200+ rewriter.create <vector::LoadOp>(loc, newVectorType, adaptor.getBase (),
201+ ValueRange{linearizedIndicesValue});
202+ auto bitcastedPassThru =
203+ rewriter.create <vector::BitCastOp>(loc, insertedVectorType, passThru);
204+
205+ // just extract it and use it for the strided slice offset
206+ auto insertStridedSlice = rewriter.create <vector::InsertStridedSliceOp>(
207+ loc, insertedVectorType, op.getValueToStore (), bitcastedPassThru,
208+ rewriter.getI64ArrayAttr ({*foldedFrontPaddingSize}),
209+ rewriter.getI64ArrayAttr ({1 }));
210+ // bit cast the vector to the original type
211+ auto bitCast = rewriter.create <vector::BitCastOp>(loc, newVectorType,
212+ insertStridedSlice);
213+
214+ rewriter.replaceOpWithNewOp <vector::StoreOp>(
215+ op, bitCast.getResult (), adaptor.getBase (), linearizedIndicesValue);
216+ } else {
217+ auto bitCast = rewriter.create <vector::BitCastOp>(loc, newVectorType,
218+ op.getValueToStore ());
219+ rewriter.replaceOpWithNewOp <vector::StoreOp>(
220+ op, bitCast.getResult (), adaptor.getBase (),
221+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
222+ }
168223 return success ();
169224 }
170225};
@@ -294,35 +349,67 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
294349 // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
295350 // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
296351 //
297- // TODO: Currently, only the even number of elements loading is supported.
298- // To deal with the odd number of elements, one has to extract the
299- // subvector at the proper offset after bit-casting.
352+ // There are cases where the number of elements to load is not byte-aligned,
353+ // for example:
354+ //
355+ // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
356+ //
357+ // we will have to load extra bytes and extract the exact slice in between.
358+ //
359+ // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
360+ // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
361+ // %3 = vector.extract_strided_slice %1 {offsets = [2], sizes = [3], strides
362+ // = [1]}
363+ // : vector<8xi2> to vector<3xi2>
364+ //
365+ // TODO: Currently the extract_strided_slice's attributes must be known at
366+ // compile time as they must be constants.
300367
301368 auto origElements = op.getVectorType ().getNumElements ();
302- if (origElements % scale != 0 )
303- return failure ();
369+ bool isUnalignedEmulation = origElements % scale != 0 ;
304370
305371 auto stridedMetadata =
306372 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
307373
308374 OpFoldResult linearizedIndices;
309- std::tie (std::ignore, linearizedIndices) =
375+ memref::LinearizedMemRefInfo linearizedInfo;
376+ std::tie (linearizedInfo, linearizedIndices) =
310377 memref::getLinearizedMemRefOffsetAndSize (
311378 rewriter, loc, srcBits, dstBits,
312379 stridedMetadata.getConstifiedMixedOffset (),
313380 stridedMetadata.getConstifiedMixedSizes (),
314381 stridedMetadata.getConstifiedMixedStrides (),
315382 getAsOpFoldResult (adaptor.getIndices ()));
316383
317- auto numElements = (origElements + scale - 1 ) / scale;
384+ auto foldedFrontPaddingSize = getFrontPaddingSize (
385+ rewriter, loc, linearizedInfo, isUnalignedEmulation);
386+
387+ if (!foldedFrontPaddingSize) {
388+ // unimplemented case for dynamic front padding size
389+ return failure ();
390+ }
391+
392+ auto numElements =
393+ (*foldedFrontPaddingSize + origElements + scale - 1 ) / scale;
394+ auto loadVectorType = VectorType::get (numElements, newElementType);
318395 auto newLoad = rewriter.create <vector::LoadOp>(
319- loc, VectorType::get (numElements, newElementType) , adaptor.getBase (),
396+ loc, loadVectorType , adaptor.getBase (),
320397 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
321398
399+ auto newBitCastType = VectorType::get (numElements * scale, oldElementType);
322400 auto bitCast =
323- rewriter.create <vector::BitCastOp>(loc, op.getType (), newLoad);
324-
325- rewriter.replaceOp (op, bitCast->getResult (0 ));
401+ rewriter.create <vector::BitCastOp>(loc, newBitCastType, newLoad);
402+
403+ if (newBitCastType.getNumElements () != origElements) {
404+ auto extractStridedSlice = rewriter.create <vector::ExtractStridedSliceOp>(
405+ loc, op.getType (), bitCast,
406+ rewriter.getI64ArrayAttr ({*foldedFrontPaddingSize}),
407+ rewriter.getI64ArrayAttr ({origElements}),
408+ rewriter.getI64ArrayAttr ({1 }));
409+ rewriter.replaceOp (op, extractStridedSlice.getResult ());
410+ } else {
411+ rewriter.replaceOp (op, bitCast->getResult (0 ));
412+ }
326413 return success ();
327414 }
328415};
@@ -464,8 +551,8 @@ struct ConvertVectorTransferRead final
464551 int scale = dstBits / srcBits;
465552
466553 auto origElements = op.getVectorType ().getNumElements ();
467- if (origElements % scale != 0 )
468- return failure () ;
554+
555+ bool isUnalignedEmulation = origElements % scale != 0 ;
469556
470557 auto newPadding = rewriter.create <arith::ExtUIOp>(loc, newElementType,
471558 adaptor.getPadding ());
@@ -474,26 +561,47 @@ struct ConvertVectorTransferRead final
474561 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getSource ());
475562
476563 OpFoldResult linearizedIndices;
477- std::tie (std::ignore, linearizedIndices) =
564+ memref::LinearizedMemRefInfo linearizedInfo;
565+ std::tie (linearizedInfo, linearizedIndices) =
478566 memref::getLinearizedMemRefOffsetAndSize (
479567 rewriter, loc, srcBits, dstBits,
480568 stridedMetadata.getConstifiedMixedOffset (),
481569 stridedMetadata.getConstifiedMixedSizes (),
482570 stridedMetadata.getConstifiedMixedStrides (),
483571 getAsOpFoldResult (adaptor.getIndices ()));
484572
485- auto numElements = (origElements + scale - 1 ) / scale;
573+ auto foldedFrontPaddingSize = getFrontPaddingSize (
574+ rewriter, loc, linearizedInfo, isUnalignedEmulation);
575+
576+ if (!foldedFrontPaddingSize) {
577+ // unimplemented case for dynamic front padding size
578+ return failure ();
579+ }
580+
581+ auto numElements =
582+ (*foldedFrontPaddingSize + origElements + scale - 1 ) / scale;
486583 auto newReadType = VectorType::get (numElements, newElementType);
487584
488585 auto newRead = rewriter.create <vector::TransferReadOp>(
489586 loc, newReadType, adaptor.getSource (),
490587 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices),
491588 newPadding);
492589
590+ auto bitCastType = VectorType::get (numElements * scale, oldElementType);
493591 auto bitCast =
494- rewriter.create <vector::BitCastOp>(loc, op.getType (), newRead);
592+ rewriter.create <vector::BitCastOp>(loc, bitCastType, newRead);
593+
594+ if (isUnalignedEmulation) {
595+ // we only extract a portion of the vector.
596+ rewriter.replaceOpWithNewOp <vector::ExtractStridedSliceOp>(
597+ op, op.getType (), bitCast,
598+ rewriter.getI64ArrayAttr ({*foldedFrontPaddingSize}),
599+ rewriter.getI64ArrayAttr ({origElements}),
600+ rewriter.getI64ArrayAttr ({1 }));
601+ } else {
602+ rewriter.replaceOp (op, bitCast->getResult (0 ));
603+ }
495604
496- rewriter.replaceOp (op, bitCast->getResult (0 ));
497605 return success ();
498606 }
499607};
0 commit comments