1515#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1616#include " mlir/Dialect/SCF/IR/SCF.h"
1717#include " mlir/Dialect/Vector/IR/VectorOps.h"
18+ #include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1819#include " mlir/IR/BuiltinTypes.h"
1920#include " mlir/IR/OpDefinition.h"
2021#include " mlir/IR/PatternMatch.h"
@@ -52,13 +53,25 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
5253}
5354
5455static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
55- vector::MaskedLoadOp maskedOp) {
56+ vector::MaskedLoadOp maskedOp,
57+ bool passthru) {
5658 VectorType vectorType = maskedOp.getVectorType ();
5759 Value load = builder.create <vector::LoadOp>(
5860 loc, vectorType, maskedOp.getBase (), maskedOp.getIndices ());
59- Value res = builder.create <arith::SelectOp>(
60- loc, vectorType, maskedOp.getMask (), load, maskedOp.getPassThru ());
61- return res;
61+ if (passthru)
62+ load = builder.create <arith::SelectOp>(loc, vectorType, maskedOp.getMask (),
63+ load, maskedOp.getPassThru ());
64+ return load;
65+ }
66+
67+ // / Check if the given value comes from a broadcasted i1 condition.
68+ static FailureOr<Value> matchFullMask (OpBuilder &b, Value val) {
69+ auto broadcastOp = val.getDefiningOp <vector::BroadcastOp>();
70+ if (!broadcastOp)
71+ return failure ();
72+ if (isa<VectorType>(broadcastOp.getSourceType ()))
73+ return failure ();
74+ return broadcastOp.getSource ();
6275}
6376
6477static constexpr char kMaskedloadNeedsMask [] =
@@ -78,6 +91,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
7891 return failure ();
7992 }
8093
94+ // Check if this is either a full inbounds load or an empty, oob load. If
95+ // so, take the fast path and don't generate an if condition, because we
96+ // know doing the oob load is always safe.
97+ if (succeeded (matchFullMask (rewriter, maskedOp.getMask ()))) {
98+ Value load = createVectorLoadForMaskedLoad (rewriter, maskedOp.getLoc (),
99+ maskedOp, /* passthru=*/ true );
100+ rewriter.replaceOp (maskedOp, load);
101+ return success ();
102+ }
103+
81104 Location loc = maskedOp.getLoc ();
82105 Value src = maskedOp.getBase ();
83106
@@ -135,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
135158 };
136159
137160 auto elseBuilder = [&](OpBuilder &builder, Location loc) {
138- Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp);
161+ Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp,
162+ /* passthru=*/ true );
139163 rewriter.create <scf::YieldOp>(loc, res);
140164 };
141165
@@ -148,11 +172,63 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
148172 }
149173};
150174
175+ struct FullMaskedLoadToConditionalLoad
176+ : OpRewritePattern<vector::MaskedLoadOp> {
177+ using OpRewritePattern::OpRewritePattern;
178+
179+ LogicalResult matchAndRewrite (vector::MaskedLoadOp loadOp,
180+ PatternRewriter &rewriter) const override {
181+ FailureOr<Value> maybeCond = matchFullMask (rewriter, loadOp.getMask ());
182+ if (failed (maybeCond)) {
183+ return failure ();
184+ }
185+
186+ Value cond = maybeCond.value ();
187+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
188+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp,
189+ /* passthru=*/ false );
190+ rewriter.create <scf::YieldOp>(loc, res);
191+ };
192+ auto falseBuilder = [&](OpBuilder &builder, Location loc) {
193+ rewriter.create <scf::YieldOp>(loc, loadOp.getPassThru ());
194+ };
195+ auto ifOp = rewriter.create <scf::IfOp>(loadOp.getLoc (), cond, trueBuilder,
196+ falseBuilder);
197+ rewriter.replaceOp (loadOp, ifOp);
198+ return success ();
199+ }
200+ };
201+
202+ struct FullMaskedStoreToConditionalStore
203+ : OpRewritePattern<vector::MaskedStoreOp> {
204+ using OpRewritePattern::OpRewritePattern;
205+
206+ LogicalResult matchAndRewrite (vector::MaskedStoreOp storeOp,
207+ PatternRewriter &rewriter) const override {
208+ FailureOr<Value> maybeCond = matchFullMask (rewriter, storeOp.getMask ());
209+ if (failed (maybeCond)) {
210+ return failure ();
211+ }
212+ Value cond = maybeCond.value ();
213+
214+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
215+ rewriter.create <vector::StoreOp>(loc, storeOp.getValueToStore (),
216+ storeOp.getBase (), storeOp.getIndices ());
217+ rewriter.create <scf::YieldOp>(loc);
218+ };
219+ auto ifOp = rewriter.create <scf::IfOp>(storeOp.getLoc (), cond, trueBuilder);
220+ rewriter.replaceOp (storeOp, ifOp);
221+ return success ();
222+ }
223+ };
224+
151225} // namespace
152226
153227void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns (
154228 RewritePatternSet &patterns, PatternBenefit benefit) {
155- patterns.add <MaskedLoadLowering>(patterns.getContext (), benefit);
229+ patterns.add <MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
230+ FullMaskedStoreToConditionalStore>(patterns.getContext (),
231+ benefit);
156232}
157233
158234struct AmdgpuMaskedloadToLoadPass final
0 commit comments