1515#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1616#include " mlir/Dialect/SCF/IR/SCF.h"
1717#include " mlir/Dialect/Vector/IR/VectorOps.h"
18+ #include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1819#include " mlir/IR/BuiltinTypes.h"
1920#include " mlir/IR/OpDefinition.h"
2021#include " mlir/IR/PatternMatch.h"
@@ -52,42 +53,24 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
5253}
5354
5455static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
55- vector::MaskedLoadOp maskedOp) {
56+ vector::MaskedLoadOp maskedOp,
57+ bool passthru) {
5658 VectorType vectorType = maskedOp.getVectorType ();
5759 Value load = builder.create <vector::LoadOp>(
5860 loc, vectorType, maskedOp.getBase (), maskedOp.getIndices ());
59- Value res = builder.create <arith::SelectOp>(
60- loc, vectorType, maskedOp.getMask (), load, maskedOp.getPassThru ());
61- return res;
61+ if (passthru)
62+ load = builder.create <arith::SelectOp>(loc, vectorType, maskedOp.getMask (),
63+ load, maskedOp.getPassThru ());
64+ return load;
6265}
6366
64- // / Check if the given value comes from a:
65- // /
66- // / arith.select %cond, TRUE/FALSE, TRUE/FALSE
67- // /
68- // / i.e the condition is either always true or it's always false.
69- // /
70- // / Returns the condition to use for scf.if (condition) { true } else { false }.
71- static FailureOr<Value> matchFullSelect (OpBuilder &b, Value val) {
72- auto selectOp = val.getDefiningOp <arith::SelectOp>();
73- if (!selectOp)
74- return failure ();
75- std::optional<int64_t > trueInt = getConstantIntValue (selectOp.getTrueValue ());
76- std::optional<int64_t > falseInt =
77- getConstantIntValue (selectOp.getFalseValue ());
78- if (!trueInt || !falseInt)
67+ // / Check if the given value comes from a broadcasted i1 condition.
68+ static FailureOr<Value> matchFullMask (OpBuilder &b, Value val) {
69+ auto broadcastOp = val.getDefiningOp <vector::BroadcastOp>();
70+ if (!broadcastOp)
7971 return failure ();
80- // getConstantIntValue returns -1 for "true" for bools.
81- if (trueInt.value () == -1 && falseInt.value () == 0 )
82- return selectOp.getCondition ();
83-
84- if (trueInt.value () == 0 && falseInt.value () == -1 ) {
85- Value cond = selectOp.getCondition ();
86- Value one = b.create <arith::ConstantIntOp>(cond.getLoc (), /* value=*/ true ,
87- /* width=*/ 1 );
88- Value inverse = b.create <arith::XOrIOp>(cond.getLoc (), cond, one);
89- return inverse;
90- }
72+ if (!isa<VectorType>(broadcastOp.getSourceType ()))
73+ return broadcastOp.getSource ();
9174 return failure ();
9275}
9376
@@ -109,11 +92,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
10992 }
11093
11194 // Check if this is either a full inbounds load or an empty, oob load. If
112- // so, take the fast path and don't generate a if condition, because we know
113- // doing the oob load is always safe.
114- if (succeeded (matchFullSelect (rewriter, maskedOp.getMask ()))) {
115- Value load =
116- createVectorLoadForMaskedLoad (rewriter, maskedOp. getLoc (), maskedOp );
95+ // so, take the fast path and don't generate an if condition, because we
96+ // know doing the oob load is always safe.
97+ if (succeeded (matchFullMask (rewriter, maskedOp.getMask ()))) {
98+ Value load = createVectorLoadForMaskedLoad (rewriter, maskedOp. getLoc (),
99+ maskedOp, /* passthru= */ true );
117100 rewriter.replaceOp (maskedOp, load);
118101 return success ();
119102 }
@@ -175,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
175158 };
176159
177160 auto elseBuilder = [&](OpBuilder &builder, Location loc) {
178- Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp);
161+ Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp,
162+ /* passthru=*/ true );
179163 rewriter.create <scf::YieldOp>(loc, res);
180164 };
181165
@@ -192,17 +176,17 @@ struct FullMaskedLoadToConditionalLoad
192176 : OpRewritePattern<vector::MaskedLoadOp> {
193177 using OpRewritePattern::OpRewritePattern;
194178
195- public:
196179 LogicalResult matchAndRewrite (vector::MaskedLoadOp loadOp,
197180 PatternRewriter &rewriter) const override {
198- FailureOr<Value> maybeCond = matchFullSelect (rewriter, loadOp.getMask ());
181+ FailureOr<Value> maybeCond = matchFullMask (rewriter, loadOp.getMask ());
199182 if (failed (maybeCond)) {
200183 return failure ();
201184 }
202185
203186 Value cond = maybeCond.value ();
204187 auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205- Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp);
188+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp,
189+ /* passthru=*/ false );
206190 rewriter.create <scf::YieldOp>(loc, res);
207191 };
208192 auto falseBuilder = [&](OpBuilder &builder, Location loc) {
@@ -219,10 +203,9 @@ struct FullMaskedStoreToConditionalStore
219203 : OpRewritePattern<vector::MaskedStoreOp> {
220204 using OpRewritePattern::OpRewritePattern;
221205
222- public:
223206 LogicalResult matchAndRewrite (vector::MaskedStoreOp storeOp,
224207 PatternRewriter &rewriter) const override {
225- FailureOr<Value> maybeCond = matchFullSelect (rewriter, storeOp.getMask ());
208+ FailureOr<Value> maybeCond = matchFullMask (rewriter, storeOp.getMask ());
226209 if (failed (maybeCond)) {
227210 return failure ();
228211 }
0 commit comments