@@ -234,41 +234,50 @@ class Pattern {
234234// RewritePattern
235235// ===----------------------------------------------------------------------===//
236236
237- // / RewritePattern is the common base class for all DAG to DAG replacements.
238- // / There are two possible usages of this class:
239- // / * Multi-step RewritePattern with "match" and "rewrite"
240- // / - By overloading the "match" and "rewrite" functions, the user can
241- // / separate the concerns of matching and rewriting.
242- // / * Single-step RewritePattern with "matchAndRewrite"
243- // / - By overloading the "matchAndRewrite" function, the user can perform
244- // / the rewrite in the same call as the match.
245- // /
246- class RewritePattern : public Pattern {
247- public:
248- virtual ~RewritePattern () = default ;
237+ namespace detail {
238+ // / Helper class that derives from a RewritePattern class and provides separate
239+ // / `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
240+ template <typename PatternT>
241+ class SplitMatchAndRewriteImpl : public PatternT {
242+ using PatternT::PatternT;
249243
250244 // / Rewrite the IR rooted at the specified operation with the result of
251245 // / this pattern, generating any new operations with the specified
252- // / builder. If an unexpected error is encountered (an internal
253- // / compiler error), it is emitted through the normal MLIR diagnostic
254- // / hooks and the IR is left in a valid state.
255- virtual void rewrite (Operation *op, PatternRewriter &rewriter) const ;
246+ // / rewriter.
247+ virtual void rewrite (typename PatternT::OperationT op,
248+ PatternRewriter &rewriter) const = 0;
256249
257250 // / Attempt to match against code rooted at the specified operation,
258251 // / which is the same operation code as getRootKind().
259- virtual LogicalResult match (Operation * op) const ;
252+ virtual LogicalResult match (typename PatternT::OperationT op) const = 0 ;
260253
261- // / Attempt to match against code rooted at the specified operation,
262- // / which is the same operation code as getRootKind(). If successful, this
263- // / function will automatically perform the rewrite.
264- virtual LogicalResult matchAndRewrite (Operation *op,
265- PatternRewriter &rewriter) const {
254+ LogicalResult matchAndRewrite (typename PatternT::OperationT op,
255+ PatternRewriter &rewriter) const final {
266256 if (succeeded (match (op))) {
267257 rewrite (op, rewriter);
268258 return success ();
269259 }
270260 return failure ();
271261 }
262+ };
263+ } // namespace detail
264+
265+ // / RewritePattern is the common base class for all DAG to DAG replacements.
266+ // / By overloading the "matchAndRewrite" function, the user can perform the
267+ // / rewrite in the same call as the match.
268+ // /
269+ class RewritePattern : public Pattern {
270+ public:
271+ using OperationT = Operation *;
272+ using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
273+
274+ virtual ~RewritePattern () = default ;
275+
276+ // / Attempt to match against code rooted at the specified operation,
277+ // / which is the same operation code as getRootKind(). If successful, this
278+ // / function will automatically perform the rewrite.
279+ virtual LogicalResult matchAndRewrite (Operation *op,
280+ PatternRewriter &rewriter) const = 0;
272281
273282 // / This method provides a convenient interface for creating and initializing
274283 // / derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
317326// / class or Interface.
318327template <typename SourceOp>
319328struct OpOrInterfaceRewritePatternBase : public RewritePattern {
329+ using OperationT = SourceOp;
320330 using RewritePattern::RewritePattern;
321331
322- // / Wrappers around the RewritePattern methods that pass the derived op type.
323- void rewrite (Operation *op, PatternRewriter &rewriter) const final {
324- rewrite (cast<SourceOp>(op), rewriter);
325- }
326- LogicalResult match (Operation *op) const final {
327- return match (cast<SourceOp>(op));
328- }
332+ // / Wrapper around the RewritePattern method that passes the derived op type.
329333 LogicalResult matchAndRewrite (Operation *op,
330334 PatternRewriter &rewriter) const final {
331335 return matchAndRewrite (cast<SourceOp>(op), rewriter);
332336 }
333337
334- // / Rewrite and Match methods that operate on the SourceOp type. These must be
335- // / overridden by the derived pattern class.
336- virtual void rewrite (SourceOp op, PatternRewriter &rewriter) const {
337- llvm_unreachable (" must override rewrite or matchAndRewrite" );
338- }
339- virtual LogicalResult match (SourceOp op) const {
340- llvm_unreachable (" must override match or matchAndRewrite" );
341- }
338+ // / Method that operates on the SourceOp type. Must be overridden by the
339+ // / derived pattern class.
342340 virtual LogicalResult matchAndRewrite (SourceOp op,
343- PatternRewriter &rewriter) const {
344- if (succeeded (match (op))) {
345- rewrite (op, rewriter);
346- return success ();
347- }
348- return failure ();
349- }
341+ PatternRewriter &rewriter) const = 0;
350342};
351343} // namespace detail
352344
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
356348template <typename SourceOp>
357349struct OpRewritePattern
358350 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
351+ using SplitMatchAndRewrite =
352+ detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
353+
359354 // / Patterns must specify the root operation name they match against, and can
360355 // / also specify the benefit of the pattern matching and a list of generated
361356 // / ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
371366template <typename SourceOp>
372367struct OpInterfaceRewritePattern
373368 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
369+ using SplitMatchAndRewrite =
370+ detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
371+
374372 OpInterfaceRewritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
375373 : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376374 Pattern::MatchInterfaceOpTypeTag (), SourceOp::getInterfaceID(),
0 commit comments