@@ -112,6 +112,64 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
112112 Operation *xfer,
113113 RewriterBase &rewriter);
114114
115+ // / A pattern for ops that implement `MaskableOpInterface` and that _might_ be
116+ // / masked (i.e. inside `vector.mask` Op region). In particular:
117+ // / 1. Matches `SourceOp` operation, Op.
118+ // / 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
119+ // / insertion point to avoid inserting new ops into the `vector.mask` Op
120+ // / region (which only allows one Op).
121+ // / 2.2 If Op is not masked, this step is skipped.
122+ // / 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
123+ // / found in step 2.1.
124+ // /
125+ // / This wrapper frees patterns from re-implementing the logic to update the
126+ // / insertion point when a maskable Op is masked. Such patterns are still
127+ // / responsible for providing an updated ("rewritten") version of:
128+ // / a. the source Op when mask _is not_ present,
129+ // / b. the source Op and the masking Op when mask _is_ present.
130+ // / Note that the return value from `matchAndRewriteMaskableOp` depends on the
131+ // / case above.
132+ template <class SourceOp >
133+ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
134+ using OpRewritePattern<SourceOp>::OpRewritePattern;
135+
136+ private:
137+ LogicalResult matchAndRewrite (SourceOp sourceOp,
138+ PatternRewriter &rewriter) const final {
139+ auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation ());
140+ if (!maskableOp)
141+ return failure ();
142+
143+ Operation *rootOp = sourceOp;
144+
145+ // If this Op is masked, update the insertion point to avoid inserting into
146+ // the vector.mask Op region.
147+ OpBuilder::InsertionGuard guard (rewriter);
148+ MaskingOpInterface maskOp;
149+ if (maskableOp.isMasked ()) {
150+ maskOp = maskableOp.getMaskingOp ();
151+ rewriter.setInsertionPoint (maskOp);
152+ rootOp = maskOp;
153+ }
154+
155+ FailureOr<Value> newOp =
156+ matchAndRewriteMaskableOp (sourceOp, maskOp, rewriter);
157+ if (failed (newOp))
158+ return failure ();
159+
160+ rewriter.replaceOp (rootOp, *newOp);
161+ return success ();
162+ }
163+
164+ public:
165+ // Matches SourceOp that can potentially be masked with `maskingOp`. If the
166+ // latter is present, returns an updated masking op (with a replacement for
167+ // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
168+ virtual FailureOr<Value>
169+ matchAndRewriteMaskableOp (SourceOp sourceOp, MaskingOpInterface maskingOp,
170+ PatternRewriter &rewriter) const = 0 ;
171+ };
172+
115173} // namespace vector
116174
117175// / Constructs a permutation map of invariant memref indices to vector
0 commit comments