Skip to content

Commit c88f3c5

Browse files
authored
[mlir] Add base class type aliases for rewrites/conversions. NFC. (#158433)
This is to simplify writing rewrite/conversion patterns that usually start with: ```c++ struct MyPattern : public OpRewritePattern<MyOp> { using OpRewritePattern::OpRewritePattern; ``` and allow for: ```c++ struct MyPattern : public OpRewritePattern<MyOp> { using Base::Base; ``` similar to how we enable it for pass classes.
1 parent e299d9a commit c88f3c5

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
312312
template <typename SourceOp>
313313
struct OpRewritePattern
314314
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
315+
/// Type alias to allow derived classes to inherit constructors with
316+
/// `using Base::Base;`.
317+
using Base = OpRewritePattern;
315318

316319
/// Patterns must specify the root operation name they match against, and can
317320
/// also specify the benefit of the pattern matching and a list of generated
@@ -328,6 +331,9 @@ struct OpRewritePattern
328331
template <typename SourceOp>
329332
struct OpInterfaceRewritePattern
330333
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
334+
/// Type alias to allow derived classes to inherit constructors with
335+
/// `using Base::Base;`.
336+
using Base = OpInterfaceRewritePattern;
331337

332338
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
333339
: mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
@@ -341,6 +347,10 @@ struct OpInterfaceRewritePattern
341347
template <template <typename> class TraitType>
342348
class OpTraitRewritePattern : public RewritePattern {
343349
public:
350+
/// Type alias to allow derived classes to inherit constructors with
351+
/// `using Base::Base;`.
352+
using Base = OpTraitRewritePattern;
353+
344354
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
345355
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
346356
benefit, context) {}

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ class Value;
4040
/// registered using addConversion and addMaterialization, respectively.
4141
class TypeConverter {
4242
public:
43+
/// Type alias to allow derived classes to inherit constructors with
44+
/// `using Base::Base;`.
45+
using Base = TypeConverter;
46+
4347
virtual ~TypeConverter() = default;
4448
TypeConverter() = default;
4549
// Copy the registered conversions, but not the caches
@@ -679,6 +683,10 @@ class ConversionPattern : public RewritePattern {
679683
template <typename SourceOp>
680684
class OpConversionPattern : public ConversionPattern {
681685
public:
686+
/// Type alias to allow derived classes to inherit constructors with
687+
/// `using Base::Base;`.
688+
using Base = OpConversionPattern;
689+
682690
using OpAdaptor = typename SourceOp::Adaptor;
683691
using OneToNOpAdaptor =
684692
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
@@ -729,6 +737,10 @@ class OpConversionPattern : public ConversionPattern {
729737
template <typename SourceOp>
730738
class OpInterfaceConversionPattern : public ConversionPattern {
731739
public:
740+
/// Type alias to allow derived classes to inherit constructors with
741+
/// `using Base::Base;`.
742+
using Base = OpInterfaceConversionPattern;
743+
732744
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
733745
: ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
734746
SourceOp::getInterfaceID(), benefit, context) {}
@@ -773,6 +785,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
773785
template <template <typename> class TraitType>
774786
class OpTraitConversionPattern : public ConversionPattern {
775787
public:
788+
/// Type alias to allow derived classes to inherit constructors with
789+
/// `using Base::Base;`.
790+
using Base = OpTraitConversionPattern;
791+
776792
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
777793
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
778794
TypeID::get<TraitType>(), benefit, context) {}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1818
#include "mlir/IR/BuiltinAttributes.h"
1919
#include "mlir/IR/Matchers.h"
20+
#include "mlir/IR/PatternMatch.h"
2021
#include "mlir/IR/Visitors.h"
2122
#include "mlir/Pass/Pass.h"
2223
#include "mlir/Transforms/DialectConversion.h"
@@ -114,7 +115,8 @@ struct FoldingPattern : public RewritePattern {
114115
struct FolderInsertBeforePreviouslyFoldedConstantPattern
115116
: public OpRewritePattern<TestCastOp> {
116117
public:
117-
using OpRewritePattern<TestCastOp>::OpRewritePattern;
118+
static_assert(std::is_same_v<Base, OpRewritePattern<TestCastOp>>);
119+
using Base::Base;
118120

119121
LogicalResult matchAndRewrite(TestCastOp op,
120122
PatternRewriter &rewriter) const override {
@@ -1306,7 +1308,8 @@ class TestReplaceWithValidConsumer : public ConversionPattern {
13061308
/// b) or: drops all block arguments and replaces each with 2x the first
13071309
/// operand.
13081310
class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
1309-
using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
1311+
static_assert(std::is_same_v<Base, OpConversionPattern<ConvertBlockArgsOp>>);
1312+
using Base::Base;
13101313

13111314
LogicalResult
13121315
matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
@@ -1431,7 +1434,9 @@ class TestTypeConsumerOpPattern
14311434

14321435
namespace {
14331436
struct TestTypeConverter : public TypeConverter {
1434-
using TypeConverter::TypeConverter;
1437+
static_assert(std::is_same_v<Base, TypeConverter>);
1438+
using Base::Base;
1439+
14351440
TestTypeConverter() {
14361441
addConversion(convertType);
14371442
addSourceMaterialization(materializeCast);

0 commit comments

Comments
 (0)