 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
 
 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -76,6 +76,9 @@ static LogicalResult transferPreconditions(
   if (!memRefType.isLastDimUnitStride())
     return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");
 
+  if (memRefType.getElementTypeBitWidth() < 8)
+    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
+
   // If there is broadcasting involved then we first load the unbroadcasted
   // vector, and then broadcast it with `vector.broadcast`.
   ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
@@ -127,14 +130,17 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return res;
 }
 
+static constexpr char kTransferReadNeedsMask[] =
+    "amdgpu.buffer_transfer_read_needs_mask";
+
 namespace {
 
 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
+    if (readOp->hasAttr(kTransferReadNeedsMask))
       return failure();
 
     bool requiresBroadcasting = false;
@@ -154,71 +160,96 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes =
+        stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset =
+        stridedMetadata.getConstifiedMixedOffset();
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
-        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, elementBitWidth, elementBitWidth,
-            stridedMetadata.getConstifiedMixedOffset(),
-            stridedMetadata.getConstifiedMixedSizes(),
-            stridedMetadata.getConstifiedMixedStrides(), indices);
-    Value linearIndex =
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
 
     // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
     // Note below doesn't give the correct result for the linearized size.
     // Value totalSize = getValueOrCreateConstantIndexOp(
     //    rewriter, loc, linearizedInfo.linearizedSize);
-    // It compute the mutiplied sizes of all dimensions instead of taking
+    // It computes the multiplied sizes of all dimensions instead of taking
    // the maximum of each dimension size * stride.
     SmallVector<AffineExpr> productExpressions;
     SmallVector<Value> productResults;
     unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
 
     SmallVector<AffineExpr> symbols(2 * sourceRank);
-    SmallVector<Value> offsetValues(2 * sourceRank);
+    SmallVector<Value> offsetValues;
     bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+
+    size_t symbolIndex = 0;
     for (size_t i = 0; i < sourceRank; ++i) {
-      unsigned offsetIdx = 2 * i;
-      productExpressions.push_back(symbols[offsetIdx] * symbols[offsetIdx + 1]);
-      offsetValues[offsetIdx] = stridedMetadata.getStrides()[i];
-      offsetValues[offsetIdx + 1] = stridedMetadata.getSizes()[i];
+      AffineExpr strideExpr, sizeExpr;
+      OpFoldResult stride = strides[i];
+      OpFoldResult size = sizes[i];
+      if (auto constantStride =
+              getConstantIntValue(stride)) {
+        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
+      } else {
+        strideExpr = symbols[symbolIndex++];
+        offsetValues.push_back(getValueOrCreateConstantIndexOp(
+            rewriter, loc, stride));
+      }
+
+      if (auto constantSize =
+              getConstantIntValue(size)) {
+        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
+      } else {
+        sizeExpr = symbols[symbolIndex++];
+        offsetValues.push_back(getValueOrCreateConstantIndexOp(
+            rewriter, loc, size));
+      }
+
+      productExpressions.push_back(strideExpr * sizeExpr);
     }
 
     AffineMap maxMap = AffineMap::get(
-        /*dimCount=*/0, /*symbolCount=*/symbols.size(), productExpressions,
+        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
         rewriter.getContext());
     Value totalSize =
         rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
 
     // delta = bufferSize - linearizedOffset
     Value vectorSizeOffset =
         rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
 
     // 1) check if delta < vectorSize
     Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
 
     // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
     Value deltaBytes = rewriter.create<arith::MulIOp>(
         loc, delta,
         rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
     Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1);
+        loc, llvm::divideCeil(32, elementBitWidth));
     Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::ne,
         rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
         rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
     // We take the fallback of transfer_read default lowering only it is both
-    // out-of-bounds and not word aligned.
+    // out-of-bounds and not word aligned. The fallback ensures correct results
+    // when loading at the boundary of the buffer since buffer load returns
+    // inconsistent zeros for the whole word when boundary is crossed.
     Value ifCondition =
         rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
 
     auto thenBuilder = [&](OpBuilder &builder, Location loc) {
       Operation *read = builder.clone(*readOp.getOperation());
-      read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
-                    builder.getUnitAttr());
+      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
       Value readResult = read->getResult(0);
       builder.create<scf::YieldOp>(loc, readResult);
     };
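For reference, the guard that the pattern assembles above out of arith ops reduces to the scalar arithmetic below. This is a standalone sketch, not part of the patch; the element width, buffer size, and starting index are hypothetical values chosen so that the masked fallback would be taken.

```cpp
// Sketch of the fallback condition: out of bounds AND not word aligned.
// All concrete values here are made up for illustration.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t elementBitWidth = 8;  // i8 elements
  int64_t vectorSize = 4;       // elements read by the transfer_read
  int64_t totalSize = 100;      // max over dims of stride[i] * size[i]
  int64_t linearIndex = 97;     // linearized element index of the read

  // delta = bufferSize - linearizedOffset
  int64_t delta = totalSize - linearIndex;

  // 1) out of bounds: fewer than vectorSize elements remain past the offset.
  bool isOutOfBounds = delta < vectorSize;

  // 2) not word aligned: delta in bytes is not a multiple of the
  //    elements-per-32-bit-word count, i.e. llvm::divideCeil(32, elementBitWidth).
  int64_t deltaBytes = delta * (elementBitWidth / 8);
  int64_t elementsPerWord = (32 + elementBitWidth - 1) / elementBitWidth;
  bool isNotWordAligned = (deltaBytes % elementsPerWord) != 0;

  // Only when both hold does the pattern clone the original transfer_read
  // (tagged with the needs-mask attribute) instead of emitting the plain load.
  bool useMaskedFallback = isOutOfBounds && isNotWordAligned;
  std::printf("delta=%lld fallback=%d\n", (long long)delta, (int)useMaskedFallback);
  return 0;
}
```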
@@ -243,7 +274,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
 void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadLowering>(patterns.getContext());
-  vector::populateVectorTransferLoweringPatterns(patterns);
 }
 
 struct AmdgpuTransferReadToLoadPass final
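One more note on the affine max built earlier in the pattern: only strides and sizes that are not compile-time constants bind a symbol and contribute an operand, which is why the map is created with /*symbolCount=*/symbolIndex rather than symbols.size(). Below is a minimal standalone sketch of that bookkeeping, with std::optional standing in for the constant case of OpFoldResult and purely hypothetical memref metadata.

```cpp
// Count how many affine symbols (and SSA operands) the rewritten loop would
// bind for a memref whose metadata is partly constant. Illustration only.
#include <cstdio>
#include <optional>
#include <vector>

int main() {
  // Hypothetical rank-2 memref: dynamic stride in dim 0, everything else static.
  std::vector<std::optional<long>> strides = {std::nullopt, 1};
  std::vector<std::optional<long>> sizes = {16, 64};

  std::size_t symbolIndex = 0;  // symbols actually bound
  std::size_t operands = 0;     // values that would be passed to affine.max
  for (std::size_t i = 0; i < strides.size(); ++i) {
    if (!strides[i]) { ++symbolIndex; ++operands; }  // dynamic stride
    if (!sizes[i])   { ++symbolIndex; ++operands; }  // dynamic size
  }

  // 2 * rank == 4 symbols are pre-allocated, but only one is used here, so
  // the AffineMap must report symbolCount == symbolIndex to stay consistent
  // with the operand list it is applied to.
  std::printf("symbols used: %zu, operands: %zu\n", symbolIndex, operands);
  return 0;
}
```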