diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index d83df3e415c36..ab605391faf6a 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -303,6 +303,12 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( case Kind::Operand: { assert(index < 0); auto *operand = cast(op->getArg(getArgIndex())); + if (operand->isOptional()) { + auto repl = + formatv(fmt, formatv("({0}.empty() ? Value() : *{0}.begin())", name)); + LLVM_DEBUG(dbgs() << repl << " (OptionalOperand)\n"); + return std::string(repl); + } // If this operand is variadic and this SymbolInfo doesn't have a range // index, then return the full variadic operand_range. Otherwise, return // the value itself. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 85a49e05d4c73..3e461999e2730 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1850,6 +1850,20 @@ def : Pat< (MixedVOperandOp5 $input2a, $input2b, $input1b, $attr1, ConstantStrAttr)>; +def MixedVOperandOp7 : TEST_Op<"mixed_variadic_optional_in7", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$input1, + Optional:$input2, + I32Attr:$attr1 + ); +} + +def : Pat< + (MixedVOperandOp7 $input1, $input2, ConstantAttr:$attr1), + (MixedVOperandOp6 $input1, (variadic $input2), $attr1), + [(Constraint> $input2)]>; + //===----------------------------------------------------------------------===// // Test Patterns (either) //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 60d46e676d2a3..90905280c0796 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -584,6 +584,16 @@ func.func @testMatchMultiVariadicSubSymbol(%arg0: i32, %arg1: i32, %arg2: i32, % return } +// CHECK-LABEL: @testMatchMixedVaradicOptional +func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () { + // CHECK: "test.mixed_variadic_in6"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> () + "test.mixed_variadic_optional_in7"(%arg0, %arg1, %arg2) {attr1 = 2 : i32, operandSegmentSizes = array} : (i32, i32, i32) -> () + // CHECK: test.mixed_variadic_optional_in7 + "test.mixed_variadic_optional_in7"(%arg0, %arg1) {attr1 = 2 : i32, operandSegmentSizes = array} : (i32, i32) -> () + + return +} + //===----------------------------------------------------------------------===// // Test that natives calls are only called once during rewrites. //===----------------------------------------------------------------------===//