1616#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1717#include " mlir/Dialect/SCF/IR/SCF.h"
1818#include " mlir/Dialect/Vector/IR/VectorOps.h"
19+ #include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1920#include " mlir/IR/BuiltinTypes.h"
2021#include " mlir/IR/OpDefinition.h"
2122#include " mlir/IR/PatternMatch.h"
2223#include " mlir/IR/TypeUtilities.h"
2324#include " mlir/Pass/Pass.h"
2425#include " mlir/Support/LogicalResult.h"
26+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2527#include " mlir/Transforms/WalkPatternRewriteDriver.h"
2628
2729namespace mlir ::amdgpu {
@@ -132,7 +134,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
132134
133135 LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
134136 PatternRewriter &rewriter) const override {
135- if (readOp->hasAttr (" amdgpu.transformed " ))
137+ if (readOp->hasAttr (" amdgpu.buffer_transfer_read_needs_mask " ))
136138 return failure ();
137139
138140 bool requiresBroadcasting = false ;
@@ -148,7 +150,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
148150 VectorType vectorType = readOp.getVectorType ();
149151 int64_t vectorSize = vectorType.getNumElements ();
150152 int64_t elementBitWidth = vectorType.getElementTypeBitWidth ();
151- // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
152153 SmallVector<OpFoldResult> indices = readOp.getIndices ();
153154
154155 auto stridedMetadata =
@@ -161,16 +162,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
161162 stridedMetadata.getConstifiedMixedOffset (),
162163 stridedMetadata.getConstifiedMixedSizes (),
163164 stridedMetadata.getConstifiedMixedStrides (), indices);
164- // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
165165 Value linearIndex =
166166 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
167167
168- // Note below doesn't give the correct result for the linearized size.
169- // It compute the mutiplied sizes of all dimensions instead of taking
170- // the maximum of each dimension size * stride.
171168 // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
169+ // Note below doesn't give the correct result for the linearized size.
172170 // Value totalSize = getValueOrCreateConstantIndexOp(
173171 // rewriter, loc, linearizedInfo.linearizedSize);
172+ // It compute the mutiplied sizes of all dimensions instead of taking
173+ // the maximum of each dimension size * stride.
174174 SmallVector<AffineExpr> productExpressions;
175175 SmallVector<Value> productResults;
176176 unsigned sourceRank =
@@ -201,7 +201,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
201201 Value isOutofBounds = rewriter.create <arith::CmpIOp>(
202202 loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
203203
204- // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
204+ // 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
205205 Value deltaBytes = rewriter.create <arith::MulIOp>(
206206 loc, delta,
207207 rewriter.create <arith::ConstantIndexOp>(loc, elementBitWidth / 8 ));
@@ -219,7 +219,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
219219
220220 auto thenBuilder = [&](OpBuilder &builder, Location loc) {
221221 Operation *read = builder.clone (*readOp.getOperation ());
222- read->setAttr (" amdgpu.transformed" , builder.getUnitAttr ());
222+ read->setAttr (" amdgpu.buffer_transfer_read_needs_mask" ,
223+ builder.getUnitAttr ());
223224 Value readResult = read->getResult (0 );
224225 builder.create <scf::YieldOp>(loc, readResult);
225226 };
@@ -244,6 +245,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
244245void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns (
245246 RewritePatternSet &patterns) {
246247 patterns.add <TransferReadLowering>(patterns.getContext ());
248+ vector::populateVectorTransferLoweringPatterns (patterns);
247249}
248250
249251struct AmdgpuTransferReadToLoadPass final
@@ -252,6 +254,8 @@ struct AmdgpuTransferReadToLoadPass final
252254 void runOnOperation () override {
253255 RewritePatternSet patterns (&getContext ());
254256 populateAmdgpuTransferReadToLoadPatterns (patterns);
255- walkAndApplyPatterns (getOperation (), std::move (patterns));
257+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
258+ return signalPassFailure ();
259+ }
256260 }
257261};
0 commit comments