diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 0ca570cf8cafb..51f8b0671a328 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -130,6 +130,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } +// CHECK: if (operands.size() <= 0) +// CHECK-NEXT: return ::mlir::failure(); // CHECK: ::mlir::Type odsInferredType0 = operands[0].getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; @@ -141,6 +143,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } +// CHECK: if (operands.size() <= 2) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK-NOT: if (operands.size() <= 0) // CHECK: ::mlir::Type odsInferredType0 = operands[2].getType(); // CHECK: ::mlir::Type odsInferredType1 = operands[0].getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; @@ -166,6 +171,8 @@ def OpL4 : NS_Op<"two_inference_edges", [ } // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes +// CHECK: if (operands.size() <= 0) +// CHECK-NEXT: return ::mlir::failure(); // CHECK: odsInferredType0 = fromInput(operands[0].getType()) // CHECK: odsInferredType1 = infer0(odsInferredType0) // CHECK: odsInferredType2 = infer1(odsInferredType1) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index dea6fb209863c..9badb7aa163a6 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3584,6 +3584,24 @@ void OpEmitter::genTypeInterfaceMethods() { fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; + // Preprocessing stage to verify all accesses to operands are valid. + int maxAccessedIndex = -1; + for (int i = 0, e = op.getNumResults(); i != e; ++i) { + const InferredResultType &infer = op.getInferredResultType(i); + if (!infer.isArg()) + continue; + Operator::OperandOrAttribute arg = + op.getArgToOperandOrAttribute(infer.getIndex()); + if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + maxAccessedIndex = + std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); + } + } + if (maxAccessedIndex != -1) { + body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n"; + body << " return ::mlir::failure();\n"; + } + // Process the type inference graph in topological order, starting from types // that are always fully-inferred: operands and results with constructible // types. The type inference graph here will always be a DAG, so this gives @@ -3600,7 +3618,8 @@ void OpEmitter::genTypeInterfaceMethods() { if (infer.isArg()) { // If this is an operand, just index into operand list to access the // type. - auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); + Operator::OperandOrAttribute arg = + op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + "].getType()")