diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index a4f7af6dbcf1c..7f882ce0dfce4 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 0) -// CHECK-NEXT: return ::mlir::failure(); -// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType(); +// CHECK: OpL1::Adaptor adaptor +// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; def OpL2 : NS_Op<"op_with_all_types_constraint", @@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 2) -// CHECK-NEXT: return ::mlir::failure(); -// CHECK-NOT: if (operands.size() <= 0) -// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType(); -// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType(); +// CHECK: OpL2::Adaptor adaptor +// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType(); +// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; // CHECK: inferredReturnTypes[1] = odsInferredType1; @@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [ } // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes -// CHECK: if (operands.size() <= 0) -// CHECK-NEXT: return ::mlir::failure(); -// CHECK: odsInferredType0 = fromInput(operands[0].getType()) +// CHECK: OpL4::Adaptor adaptor +// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType()) // CHECK: odsInferredType1 = infer0(odsInferredType0) // CHECK: odsInferredType2 = infer1(odsInferredType1) // CHECK: inferredReturnTypes[0] = odsInferredType0 @@ -207,6 +203,18 @@ def OpL6 : NS_Op<"op_with_same_and_constraint_results", // CHECK: inferredReturnTypes[1] = odsInferredType1; // CHECK: inferredReturnTypes[2] = odsInferredType2; +def OpL7 : NS_Op<"one_variadic_and_one_normal_operand_with_infer_result_op", + [TypesMatchWith<"", "input2", "output1", "infer0($_self)">]> { + let arguments = (ins Variadic:$input1, AnyTensor:$input2); + let results = (outs AnyTensor:$output1); +} + +// CHECK-LABEL: LogicalResult OpL7::inferReturnTypes +// CHECK-NOT: } +// CHECK: OpL7::Adaptor adaptor +// CHECK: odsInferredType0 = infer0(adaptor.getInput2().getType()) +// CHECK: inferredReturnTypes[0] = odsInferredType0 + def OpM : NS_Op<"mix_diff_size_variadic_and_normal_results_op", [AttrSizedResultSegments]> { let results = (outs Variadic:$output1, AnyTensor:$output2, Optional:$output3); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b957c8ee9f8ab..f61129c234ddf 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() { // Avoid emitting "resultTypes.size() >= 0u" which is always true. if (!hasVariadicResult || numNonVariadicResults != 0) - body << " " - << "assert(resultTypes.size() " + body << " " << "assert(resultTypes.size() " << (hasVariadicResult ? ">=" : "==") << " " << numNonVariadicResults << "u && \"mismatched number of results\");\n"; @@ -3751,29 +3750,15 @@ void OpEmitter::genTypeInterfaceMethods() { fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; - // Preprocessing stage to verify all accesses to operands are valid. - int maxAccessedIndex = -1; - for (int i = 0, e = op.getNumResults(); i != e; ++i) { - const InferredResultType &infer = op.getInferredResultType(i); - if (!infer.isArg()) - continue; - Operator::OperandOrAttribute arg = - op.getArgToOperandOrAttribute(infer.getIndex()); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { - maxAccessedIndex = - std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); - } - } - if (maxAccessedIndex != -1) { - body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n"; - body << " return ::mlir::failure();\n"; - } + // Emit an adaptor to access right ranges for ods operands. + body << " " << op.getCppClassName() + << "::Adaptor adaptor(operands, attributes, properties, regions);\n"; - // Process the type inference graph in topological order, starting from types - // that are always fully-inferred: operands and results with constructible - // types. The type inference graph here will always be a DAG, so this gives - // us the correct order for generating the types. -1 is a placeholder to - // indicate the type for a result has not been generated. + // Process the type inference graph in topological order, starting from + // types that are always fully-inferred: operands and results with + // constructible types. The type inference graph here will always be a + // DAG, so this gives us the correct order for generating the types. -1 is + // a placeholder to indicate the type for a result has not been generated. SmallVector constructedIndices(op.getNumResults(), -1); int inferredTypeIdx = 0; for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) { @@ -3788,10 +3773,11 @@ void OpEmitter::genTypeInterfaceMethods() { Operator::OperandOrAttribute arg = op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { - typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + - "].getType()") - .str(); - + std::string getter = + "adaptor." + + op.getGetterName( + op.getOperand(arg.operandOrAttributeIndex()).name); + typeStr = (getter + "().getType()"); // If this is an attribute, index into the attribute dictionary. } else { auto *attr =