@@ -130,6 +130,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
130130 return newMask;
131131}
132132
133+ // / A wrapper function for emitting `vector.extract_strided_slice`.
133134static Value extractSubvectorFrom (RewriterBase &rewriter, Location loc,
134135 VectorType extractType, Value vector,
135136 int64_t frontOffset, int64_t subvecSize) {
@@ -142,6 +143,7 @@ static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
142143 ->getResult (0 );
143144}
144145
146+ // / A wrapper function for emitting `vector.insert_strided_slice`.
145147static Value insertSubvectorInto (RewriterBase &rewriter, Location loc,
146148 Value src, Value dest, int64_t offset) {
147149 auto offsets = rewriter.getI64ArrayAttr ({offset});
@@ -150,36 +152,14 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
150152 dest, offsets, strides);
151153}
152154
155+ // / Extracts `lengthSubvec` elements from `srcVec` into `destVec` starting at
156+ // / the offset specified by `srcOffsetVar`. Use this function when
157+ // / `srcOffsetVar` is not a constant, making it impossible to use
158+ // / vector.extract_strided_slice, as it requires constant offsets.
153159static void dynamicallyExtractElementsToVector (
154160 RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
155- Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
156- /*
157- // Create affine maps for the lower and upper bounds
158- AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
159- AffineMap upperBoundMap =
160- AffineMap::getConstantMap(loopSize, rewriter.getContext());
161-
162- auto forLoop = rewriter.create<affine::AffineForOp>(
163- loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
164- ArrayRef<Value>(destVec));
165-
166- OpBuilder builder =
167- OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
168-
169- auto iv = forLoop.getInductionVar();
170-
171- auto loopDestVec = forLoop.getRegionIterArgs()[0];
172- auto extractLoc = builder.create<arith::AddIOp>(
173- loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
174- auto extractElemOp = builder.create<vector::ExtractElementOp>(
175- loc, elemType, srcVec, extractLoc);
176- auto insertElemOp = builder.create<vector::InsertElementOp>(
177- loc, extractElemOp, loopDestVec, iv);
178- builder.create<affine::AffineYieldOp>(loc,
179- ValueRange{insertElemOp->getResult(0)});
180- return forLoop->getResult(0);
181- */
182- for (int i = 0 ; i < loopSize; ++i) {
161+ Value destVec, OpFoldResult srcOffsetVar, int64_t lengthSubvec) {
162+ for (int i = 0 ; i < lengthSubvec; ++i) {
183163 Value extractLoc;
184164 if (i == 0 ) {
185165 extractLoc = srcOffsetVar.dyn_cast <Value>();
@@ -194,15 +174,21 @@ static void dynamicallyExtractElementsToVector(
194174 }
195175}
196176
177+ // / Load `numLoadedElements` of `newElementType` from `base` at
178+ // / `linearizedIndices`, then bitcast the result into a vector of
179+ // / `oldElementType`.
197180static TypedValue<VectorType>
198181emulatedVectorLoad (ConversionPatternRewriter &rewriter, Location loc,
199- Value base, OpFoldResult linearizedIndices, int64_t numBytes,
200- int64_t scale, Type oldElememtType, Type newElementType) {
182+ Value base, OpFoldResult linearizedIndices,
183+ int64_t numLoadedElements, Type oldElememtType,
184+ Type newElementType) {
185+ auto scale = newElementType.getIntOrFloatBitWidth () /
186+ oldElememtType.getIntOrFloatBitWidth ();
201187 auto newLoad = rewriter.create <vector::LoadOp>(
202- loc, VectorType::get (numBytes , newElementType), base,
188+ loc, VectorType::get (numLoadedElements , newElementType), base,
203189 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
204190 return rewriter.create <vector::BitCastOp>(
205- loc, VectorType::get (numBytes * scale, oldElememtType), newLoad);
191+ loc, VectorType::get (numLoadedElements * scale, oldElememtType), newLoad);
206192};
207193
208194namespace {
@@ -443,7 +429,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
443429 llvm::divideCeil (maxintraDataOffset + origElements, scale);
444430 Value result =
445431 emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
446- numElements, scale, oldElementType, newElementType);
432+ numElements, oldElementType, newElementType);
447433
448434 if (foldedIntraVectorOffset) {
449435 if (isUnalignedEmulation) {
0 commit comments