1515#include " mlir/IR/TypeUtilities.h"
1616#include " mlir/Pass/Pass.h"
1717#include " mlir/Support/LogicalResult.h"
18- #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
18+ #include " mlir/Transforms/WalkPatternRewriteDriver .h"
1919
2020namespace mlir {
2121#define GEN_PASS_DEF_CONVERTVECTORTOAMDGPUPASS
@@ -36,17 +36,16 @@ using namespace mlir;
3636// / - The permutation map doesn't perform permutation (broadcasting is allowed).
3737// / Note: those conditions mostly come from TransferReadToVectorLoadLowering
3838// / pass.
39- static LogicalResult
40- transferPreconditions (PatternRewriter &rewriter,
41- VectorTransferOpInterface xferOp,
42- SmallVector<unsigned > &broadcastedDims,
43- VectorType &unbroadcastedVectorType) {
39+ static LogicalResult transferPreconditions (
40+ PatternRewriter &rewriter, VectorTransferOpInterface xferOp,
41+ bool &requiresBroadcasting, VectorType &unbroadcastedVectorType) {
4442 if (!xferOp.getMask ())
4543 return rewriter.notifyMatchFailure (xferOp, " Only support masked transfer" );
4644
4745 // Permutations are handled by VectorToSCF or
4846 // populateVectorTransferPermutationMapLoweringPatterns.
4947 // We let the 0-d corner case pass-through as it is supported.
48+ SmallVector<unsigned > broadcastedDims;
5049 if (!xferOp.getPermutationMap ().isMinorIdentityWithBroadcasting (
5150 &broadcastedDims))
5251 return rewriter.notifyMatchFailure (xferOp, " not minor identity + bcast" );
@@ -56,9 +55,8 @@ transferPreconditions(PatternRewriter &rewriter,
5655 return rewriter.notifyMatchFailure (xferOp, " not a memref source" );
5756
5857 Attribute addrSpace = memRefType.getMemorySpace ();
59- if (!addrSpace ||
60- llvm::dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue () !=
61- amdgpu::AddressSpace::FatRawBuffer)
58+ if (!addrSpace || dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue () !=
59+ amdgpu::AddressSpace::FatRawBuffer)
6260 return rewriter.notifyMatchFailure (xferOp, " not in buffer address space" );
6361
6462 // Non-unit strides are handled by VectorToSCF.
@@ -73,6 +71,7 @@ transferPreconditions(PatternRewriter &rewriter,
7371 unbroadcastedVectorShape[i] = 1 ;
7472 unbroadcastedVectorType = xferOp.getVectorType ().cloneWith (
7573 unbroadcastedVectorShape, xferOp.getVectorType ().getElementType ());
74+ requiresBroadcasting = !broadcastedDims.empty ();
7675
7776 // `vector.load` supports vector types as memref's elements only when the
7877 // resulting vector type is the same as the element type.
@@ -98,31 +97,31 @@ transferPreconditions(PatternRewriter &rewriter,
9897 return success ();
9998}
10099
101- struct TransferReadLowering : public OpRewritePattern <vector::TransferReadOp> {
102- using OpRewritePattern<vector::TransferReadOp> ::OpRewritePattern;
100+ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
101+ using OpRewritePattern::OpRewritePattern;
103102
104103 LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
105104 PatternRewriter &rewriter) const override {
106105
107- SmallVector< unsigned > broadcastedDims ;
106+ bool requiresBroadcasting = false ;
108107 VectorType unbroadcastedVectorType;
109- if (failed (transferPreconditions (rewriter, readOp, broadcastedDims ,
108+ if (failed (transferPreconditions (rewriter, readOp, requiresBroadcasting ,
110109 unbroadcastedVectorType))) {
111110 return failure ();
112111 }
113112
114- Value fill = rewriter.create <vector::SplatOp>(
115- readOp.getLoc (), unbroadcastedVectorType, readOp.getPadding ());
113+ Location loc = readOp.getLoc ();
114+ Value fill = rewriter.create <vector::SplatOp>(loc, unbroadcastedVectorType,
115+ readOp.getPadding ());
116116 Value load = rewriter.create <vector::LoadOp>(
117- readOp.getLoc (), unbroadcastedVectorType, readOp.getSource (),
118- readOp.getIndices ());
119- Value res = rewriter.create <arith::SelectOp>(
120- readOp.getLoc (), unbroadcastedVectorType, readOp.getMask (), load, fill);
117+ loc, unbroadcastedVectorType, readOp.getSource (), readOp.getIndices ());
118+ Value res = rewriter.create <arith::SelectOp>(loc, unbroadcastedVectorType,
119+ readOp.getMask (), load, fill);
121120
122121 // Insert a broadcasting op if required.
123- if (!broadcastedDims. empty () ) {
124- res = rewriter.create <vector::BroadcastOp>(readOp.getLoc (),
125- readOp. getVectorType (), res);
122+ if (requiresBroadcasting ) {
123+ res = rewriter.create <vector::BroadcastOp>(loc, readOp.getVectorType (),
124+ res);
126125 }
127126
128127 rewriter.replaceOp (readOp, res);
@@ -136,12 +135,11 @@ void mlir::populateVectorToAMDGPUConversionPatterns(
136135 patterns.add <TransferReadLowering>(patterns.getContext ());
137136}
138137
139- struct ConvertVectorToAMDGPUPass
138+ struct ConvertVectorToAMDGPUPass final
140139 : public impl::ConvertVectorToAMDGPUPassBase<ConvertVectorToAMDGPUPass> {
141140 void runOnOperation () override {
142141 RewritePatternSet patterns (&getContext ());
143142 populateVectorToAMDGPUConversionPatterns (patterns);
144- if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
145- return signalPassFailure ();
143+ walkAndApplyPatterns (getOperation (), std::move (patterns));
146144 }
147145};
0 commit comments