99#include " mlir/Dialect/AMDGPU/Transforms/Passes.h"
1010
1111#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
13+ #include " mlir/Dialect/Arith/IR/Arith.h"
14+ #include " mlir/Dialect/Arith/Utils/Utils.h"
15+ #include " mlir/Dialect/MemRef/IR/MemRef.h"
16+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17+ #include " mlir/Dialect/SCF/IR/SCF.h"
1218#include " mlir/Dialect/Vector/IR/VectorOps.h"
19+ #include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1320#include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/OpDefinition.h"
1422#include " mlir/IR/PatternMatch.h"
1523#include " mlir/IR/TypeUtilities.h"
1624#include " mlir/Pass/Pass.h"
1725#include " mlir/Support/LogicalResult.h"
18- #include " mlir/Transforms/WalkPatternRewriteDriver.h"
26+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
27+ #include " llvm/Support/MathExtras.h"
1928
2029namespace mlir ::amdgpu {
2130#define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -67,6 +76,9 @@ static LogicalResult transferPreconditions(
6776 if (!memRefType.isLastDimUnitStride ())
6877 return rewriter.notifyMatchFailure (xferOp, " != 1 stride needs VectorToSCF" );
6978
79+ if (memRefType.getElementTypeBitWidth () < 8 )
80+ return rewriter.notifyMatchFailure (xferOp, " unsupported sub-byte type" );
81+
7082 // If there is broadcasting involved then we first load the unbroadcasted
7183 // vector, and then broadcast it with `vector.broadcast`.
7284 ArrayRef<int64_t > vectorShape = xferOp.getVectorType ().getShape ();
@@ -101,13 +113,35 @@ static LogicalResult transferPreconditions(
101113 return success ();
102114}
103115
116+ static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
117+ vector::TransferReadOp readOp,
118+ bool requiresBroadcasting,
119+ VectorType unbroadcastedVectorType) {
120+ Value fill = builder.create <vector::SplatOp>(loc, unbroadcastedVectorType,
121+ readOp.getPadding ());
122+ Value load = builder.create <vector::LoadOp>(
123+ loc, unbroadcastedVectorType, readOp.getSource (), readOp.getIndices ());
124+ Value res = builder.create <arith::SelectOp>(loc, unbroadcastedVectorType,
125+ readOp.getMask (), load, fill);
126+ // Insert a broadcasting op if required.
127+ if (requiresBroadcasting) {
128+ res = builder.create <vector::BroadcastOp>(loc, readOp.getVectorType (), res);
129+ }
130+ return res;
131+ }
132+
/// Marker attribute placed on the transfer_read op cloned into the
/// out-of-bounds fallback branch, so TransferReadLowering does not match the
/// clone a second time (which would recurse forever under the greedy driver).
static constexpr char kTransferReadNeedsMask[] =
    "amdgpu.buffer_transfer_read_needs_mask";
135+
104136namespace {
105137
106138struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
107139 using OpRewritePattern::OpRewritePattern;
108140
109141 LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
110142 PatternRewriter &rewriter) const override {
143+ if (readOp->hasAttr (kTransferReadNeedsMask ))
144+ return failure ();
111145
112146 bool requiresBroadcasting = false ;
113147 VectorType unbroadcastedVectorType;
@@ -117,20 +151,115 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
117151 }
118152
119153 Location loc = readOp.getLoc ();
120- Value fill = rewriter.create <vector::SplatOp>(loc, unbroadcastedVectorType,
121- readOp.getPadding ());
122- Value load = rewriter.create <vector::LoadOp>(
123- loc, unbroadcastedVectorType, readOp.getSource (), readOp.getIndices ());
124- Value res = rewriter.create <arith::SelectOp>(loc, unbroadcastedVectorType,
125- readOp.getMask (), load, fill);
126-
127- // Insert a broadcasting op if required.
128- if (requiresBroadcasting) {
129- res = rewriter.create <vector::BroadcastOp>(loc, readOp.getVectorType (),
130- res);
154+ Value src = readOp.getSource ();
155+
156+ VectorType vectorType = readOp.getVectorType ();
157+ int64_t vectorSize = vectorType.getNumElements ();
158+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth ();
159+ SmallVector<OpFoldResult> indices = readOp.getIndices ();
160+
161+ auto stridedMetadata =
162+ rewriter.create <memref::ExtractStridedMetadataOp>(loc, src);
163+ SmallVector<OpFoldResult> strides =
164+ stridedMetadata.getConstifiedMixedStrides ();
165+ SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes ();
166+ OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset ();
167+ OpFoldResult linearizedIndices;
168+ std::tie (std::ignore, linearizedIndices) =
169+ memref::getLinearizedMemRefOffsetAndSize (rewriter, loc, elementBitWidth,
170+ elementBitWidth, offset, sizes,
171+ strides, indices);
172+
173+ // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
174+ // Note below doesn't give the correct result for the linearized size.
175+ // Value totalSize = getValueOrCreateConstantIndexOp(
176+ // rewriter, loc, linearizedInfo.linearizedSize);
177+ // It computes the multiplied sizes of all dimensions instead of taking
178+ // the maximum of each dimension size * stride.
179+ SmallVector<AffineExpr> productExpressions;
180+ SmallVector<Value> productResults;
181+ unsigned sourceRank = cast<ShapedType>(src.getType ()).getRank ();
182+
183+ SmallVector<AffineExpr> symbols (2 * sourceRank);
184+ SmallVector<Value> offsetValues;
185+ bindSymbolsList (rewriter.getContext (), MutableArrayRef{symbols});
186+
187+ size_t symbolIndex = 0 ;
188+ for (size_t i = 0 ; i < sourceRank; ++i) {
189+ AffineExpr strideExpr, sizeExpr;
190+ OpFoldResult stride = strides[i];
191+ OpFoldResult size = sizes[i];
192+ if (auto constantStride = getConstantIntValue (stride)) {
193+ strideExpr = rewriter.getAffineConstantExpr (*constantStride);
194+ } else {
195+ strideExpr = symbols[symbolIndex++];
196+ offsetValues.push_back (
197+ getValueOrCreateConstantIndexOp (rewriter, loc, stride));
198+ }
199+
200+ if (auto constantSize = getConstantIntValue (size)) {
201+ sizeExpr = rewriter.getAffineConstantExpr (*constantSize);
202+ } else {
203+ sizeExpr = symbols[symbolIndex++];
204+ offsetValues.push_back (
205+ getValueOrCreateConstantIndexOp (rewriter, loc, size));
206+ }
207+
208+ productExpressions.push_back (strideExpr * sizeExpr);
131209 }
132210
133- rewriter.replaceOp (readOp, res);
211+ AffineMap maxMap = AffineMap::get (
212+ /* dimCount=*/ 0 , /* symbolCount=*/ symbolIndex, productExpressions,
213+ rewriter.getContext ());
214+ Value totalSize =
215+ rewriter.create <affine::AffineMaxOp>(loc, maxMap, offsetValues);
216+
217+ // delta = bufferSize - linearizedOffset
218+ Value vectorSizeOffset =
219+ rewriter.create <arith::ConstantIndexOp>(loc, vectorSize);
220+ Value linearIndex =
221+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
222+ Value delta = rewriter.create <arith::SubIOp>(loc, totalSize, linearIndex);
223+
224+ // 1) check if delta < vectorSize
225+ Value isOutofBounds = rewriter.create <arith::CmpIOp>(
226+ loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
227+
228+ // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
229+ Value deltaBytes = rewriter.create <arith::MulIOp>(
230+ loc, delta,
231+ rewriter.create <arith::ConstantIndexOp>(loc, elementBitWidth / 8 ));
232+ Value elementsPerWord = rewriter.create <arith::ConstantIndexOp>(
233+ loc, llvm::divideCeil (32 , elementBitWidth));
234+ Value isNotWordAligned = rewriter.create <arith::CmpIOp>(
235+ loc, arith::CmpIPredicate::ne,
236+ rewriter.create <arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
237+ rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
238+
239+ // We take the fallback of transfer_read default lowering only it is both
240+ // out-of-bounds and not word aligned. The fallback ensures correct results
241+ // when loading at the boundary of the buffer since buffer load returns
242+ // inconsistent zeros for the whole word when boundary is crossed.
243+ Value ifCondition =
244+ rewriter.create <arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
245+
246+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
247+ Operation *read = builder.clone (*readOp.getOperation ());
248+ read->setAttr (kTransferReadNeedsMask , builder.getUnitAttr ());
249+ Value readResult = read->getResult (0 );
250+ builder.create <scf::YieldOp>(loc, readResult);
251+ };
252+
253+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
254+ Value res = createVectorLoadForMaskedLoad (
255+ builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
256+ rewriter.create <scf::YieldOp>(loc, res);
257+ };
258+
259+ auto ifOp =
260+ rewriter.create <scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
261+
262+ rewriter.replaceOp (readOp, ifOp);
134263
135264 return success ();
136265 }
@@ -149,6 +278,8 @@ struct AmdgpuTransferReadToLoadPass final
149278 void runOnOperation () override {
150279 RewritePatternSet patterns (&getContext ());
151280 populateAmdgpuTransferReadToLoadPatterns (patterns);
152- walkAndApplyPatterns (getOperation (), std::move (patterns));
281+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
282+ return signalPassFailure ();
283+ }
153284 }
154285};
0 commit comments