@@ -61,6 +61,36 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
6161 return res;
6262}
6363
64+ // / Check if the given value comes from a:
65+ // /
66+ // / arith.select %cond, TRUE/FALSE, TRUE/FALSE
67+ // /
68+ // / i.e the condition is either always true or it's always false.
69+ // /
70+ // / Returns the condition to use for scf.if (condition) { true } else { false }.
71+ static FailureOr<Value> matchFullSelect (OpBuilder &b, Value val) {
72+ auto selectOp = val.getDefiningOp <arith::SelectOp>();
73+ if (!selectOp)
74+ return failure ();
75+ std::optional<int64_t > trueInt = getConstantIntValue (selectOp.getTrueValue ());
76+ std::optional<int64_t > falseInt =
77+ getConstantIntValue (selectOp.getFalseValue ());
78+ if (!trueInt || !falseInt)
79+ return failure ();
80+ // getConstantIntValue returns -1 for "true" for bools.
81+ if (trueInt.value () == -1 && falseInt.value () == 0 )
82+ return selectOp.getCondition ();
83+
84+ if (trueInt.value () == 0 && falseInt.value () == -1 ) {
85+ Value cond = selectOp.getCondition ();
86+ Value one = b.create <arith::ConstantIntOp>(cond.getLoc (), /* value=*/ true ,
87+ /* width=*/ 1 );
88+ Value inverse = b.create <arith::XOrIOp>(cond.getLoc (), cond, one);
89+ return inverse;
90+ }
91+ return failure ();
92+ }
93+
// NOTE(review): the use sites of this string are outside the visible chunk;
// presumably it names an attribute that marks masked loads which still need
// their mask applied -- confirm against the code that reads or sets it.
6494static constexpr char kMaskedloadNeedsMask [] =
6595 " amdgpu.buffer_maskedload_needs_mask" ;
6696
@@ -78,6 +108,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
78108 return failure ();
79109 }
80110
111+ // Check if this is either a full inbounds load or an empty, oob load. If
112+ // so, take the fast path and don't generate a if condition, because we know
113+ // doing the oob load is always safe.
114+ if (succeeded (matchFullSelect (rewriter, maskedOp.getMask ()))) {
115+ Value load =
116+ createVectorLoadForMaskedLoad (rewriter, maskedOp.getLoc (), maskedOp);
117+ rewriter.replaceOp (maskedOp, load);
118+ return success ();
119+ }
120+
81121 Location loc = maskedOp.getLoc ();
82122 Value src = maskedOp.getBase ();
83123
@@ -148,11 +188,64 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
148188 }
149189};
150190
191+ struct FullMaskedLoadToConditionalLoad
192+ : OpRewritePattern<vector::MaskedLoadOp> {
193+ using OpRewritePattern::OpRewritePattern;
194+
195+ public:
196+ LogicalResult matchAndRewrite (vector::MaskedLoadOp loadOp,
197+ PatternRewriter &rewriter) const override {
198+ FailureOr<Value> maybeCond = matchFullSelect (rewriter, loadOp.getMask ());
199+ if (failed (maybeCond)) {
200+ return failure ();
201+ }
202+
203+ Value cond = maybeCond.value ();
204+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp);
206+ rewriter.create <scf::YieldOp>(loc, res);
207+ };
208+ auto falseBuilder = [&](OpBuilder &builder, Location loc) {
209+ rewriter.create <scf::YieldOp>(loc, loadOp.getPassThru ());
210+ };
211+ auto ifOp = rewriter.create <scf::IfOp>(loadOp.getLoc (), cond, trueBuilder,
212+ falseBuilder);
213+ rewriter.replaceOp (loadOp, ifOp);
214+ return success ();
215+ }
216+ };
217+
218+ struct FullMaskedStoreToConditionalStore
219+ : OpRewritePattern<vector::MaskedStoreOp> {
220+ using OpRewritePattern::OpRewritePattern;
221+
222+ public:
223+ LogicalResult matchAndRewrite (vector::MaskedStoreOp storeOp,
224+ PatternRewriter &rewriter) const override {
225+ FailureOr<Value> maybeCond = matchFullSelect (rewriter, storeOp.getMask ());
226+ if (failed (maybeCond)) {
227+ return failure ();
228+ }
229+ Value cond = maybeCond.value ();
230+
231+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
232+ rewriter.create <vector::StoreOp>(loc, storeOp.getValueToStore (),
233+ storeOp.getBase (), storeOp.getIndices ());
234+ rewriter.create <scf::YieldOp>(loc);
235+ };
236+ auto ifOp = rewriter.create <scf::IfOp>(storeOp.getLoc (), cond, trueBuilder);
237+ rewriter.replaceOp (storeOp, ifOp);
238+ return success ();
239+ }
240+ };
241+
151242} // namespace
152243
153244void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns (
154245 RewritePatternSet &patterns, PatternBenefit benefit) {
155- patterns.add <MaskedLoadLowering>(patterns.getContext (), benefit);
246+ patterns.add <MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
247+ FullMaskedStoreToConditionalStore>(patterns.getContext (),
248+ benefit);
156249}
157250
158251struct AmdgpuMaskedloadToLoadPass final
0 commit comments