1818#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1919#include " mlir/IR/BuiltinAttributes.h"
2020#include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/OpDefinition.h"
2122#include " mlir/IR/TypeUtilities.h"
2223#include " mlir/IR/Value.h"
2324#include " mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
149150 dest, offsets, strides);
150151}
151152
153+ static void dynamicallyExtractElementsToVector (
154+ RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
155+ Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
156+ /*
157+ // Create affine maps for the lower and upper bounds
158+ AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
159+ AffineMap upperBoundMap =
160+ AffineMap::getConstantMap(loopSize, rewriter.getContext());
161+
162+ auto forLoop = rewriter.create<affine::AffineForOp>(
163+ loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
164+ ArrayRef<Value>(destVec));
165+
166+ OpBuilder builder =
167+ OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
168+
169+ auto iv = forLoop.getInductionVar();
170+
171+ auto loopDestVec = forLoop.getRegionIterArgs()[0];
172+ auto extractLoc = builder.create<arith::AddIOp>(
173+ loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
174+ auto extractElemOp = builder.create<vector::ExtractElementOp>(
175+ loc, elemType, srcVec, extractLoc);
176+ auto insertElemOp = builder.create<vector::InsertElementOp>(
177+ loc, extractElemOp, loopDestVec, iv);
178+ builder.create<affine::AffineYieldOp>(loc,
179+ ValueRange{insertElemOp->getResult(0)});
180+ return forLoop->getResult(0);
181+ */
182+ for (int i = 0 ; i < loopSize; ++i) {
183+ Value extractLoc;
184+ if (i == 0 ) {
185+ extractLoc = srcOffsetVar.dyn_cast <Value>();
186+ } else {
187+ extractLoc = rewriter.create <arith::AddIOp>(
188+ loc, rewriter.getIndexType (), srcOffsetVar.dyn_cast <Value>(),
189+ rewriter.create <arith::ConstantIndexOp>(loc, i));
190+ }
191+ auto extractOp =
192+ rewriter.create <vector::ExtractOp>(loc, srcVec, extractLoc);
193+ rewriter.create <vector::InsertOp>(loc, extractOp, destVec, i);
194+ }
195+ }
196+
197+ static TypedValue<VectorType>
198+ emulatedVectorLoad (ConversionPatternRewriter &rewriter, Location loc,
199+ Value base, OpFoldResult linearizedIndices, int64_t numBytes,
200+ int64_t scale, Type oldElememtType, Type newElementType) {
201+ auto newLoad = rewriter.create <vector::LoadOp>(
202+ loc, VectorType::get (numBytes, newElementType), base,
203+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
204+ return rewriter.create <vector::BitCastOp>(
205+ loc, VectorType::get (numBytes * scale, oldElememtType), newLoad);
206+ };
207+
152208namespace {
153209
154210// ===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
380436 ? getConstantIntValue (linearizedInfo.intraDataOffset )
381437 : 0 ;
382438
383- if (!foldedIntraVectorOffset) {
384- // unimplemented case for dynamic intra vector offset
385- return failure ();
386- }
387-
439+ // always load enough elements which can cover the original elements
440+ auto maxintraVectorOffset =
441+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
388442 auto numElements =
389- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
390- auto newLoad = rewriter.create <vector::LoadOp>(
391- loc, VectorType::get (numElements, newElementType), adaptor.getBase (),
392- getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
393-
394- Value result = rewriter.create <vector::BitCastOp>(
395- loc, VectorType::get (numElements * scale, oldElementType), newLoad);
443+ llvm::divideCeil (maxintraVectorOffset + origElements, scale);
444+ Value result =
445+ emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
446+ numElements, scale, oldElementType, newElementType);
396447
397- if (isUnalignedEmulation) {
398- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
399- *foldedIntraVectorOffset, origElements);
448+ if (foldedIntraVectorOffset) {
449+ if (isUnalignedEmulation) {
450+ result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
451+ *foldedIntraVectorOffset, origElements);
452+ }
453+ rewriter.replaceOp (op, result);
454+ } else {
455+ auto resultVector = rewriter.create <arith::ConstantOp>(
456+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
457+ dynamicallyExtractElementsToVector (
458+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
459+ linearizedInfo.intraVectorOffset , origElements);
460+ rewriter.replaceOp (op, resultVector);
400461 }
401-
402- rewriter.replaceOp (op, result);
403462 return success ();
404463 }
405464};
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
604663 ? getConstantIntValue (linearizedInfo.intraDataOffset )
605664 : 0 ;
606665
607- if (!foldedIntraVectorOffset) {
608- // unimplemented case for dynamic inra-vector offset
609- return failure ();
610- }
611-
666+ auto maxIntraVectorOffset =
667+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
612668 auto numElements =
613- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
669+ llvm::divideCeil (maxIntraVectorOffset + origElements, scale);
614670
615671 auto newRead = rewriter.create <vector::TransferReadOp>(
616672 loc, VectorType::get (numElements, newElementType), adaptor.getSource (),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
621677 loc, VectorType::get (numElements * scale, oldElementType), newRead);
622678
623679 Value result = bitCast->getResult (0 );
624- if (isUnalignedEmulation) {
625- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
626- *foldedIntraVectorOffset, origElements);
680+ if (foldedIntraVectorOffset) {
681+ if (isUnalignedEmulation) {
682+ result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
683+ *foldedIntraVectorOffset, origElements);
684+ }
685+ } else {
686+ result = rewriter.create <arith::ConstantOp>(
687+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
688+ dynamicallyExtractElementsToVector (rewriter, loc, bitCast, result,
689+ linearizedInfo.intraVectorOffset ,
690+ origElements);
627691 }
628692 rewriter.replaceOp (op, result);
629693
0 commit comments