From e686f5aa3e1119d53ec77abef401be1a0364e739 Mon Sep 17 00:00:00 2001 From: Kolya Panchenko Date: Wed, 16 Oct 2024 11:51:27 -0400 Subject: [PATCH 1/3] [mlir][ods] Verfify access to operands in inferReturnTypes The patch adds graceful handling of incorrectly constructed MLIR operation with less operands than expected. --- mlir/test/mlir-tblgen/op-result.td | 7 +++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 16 ++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 0ca570cf8cafb..51f8b0671a328 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -130,6 +130,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } +// CHECK: if (operands.size() <= 0) +// CHECK-NEXT: return ::mlir::failure(); // CHECK: ::mlir::Type odsInferredType0 = operands[0].getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; @@ -141,6 +143,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } +// CHECK: if (operands.size() <= 2) +// CHECK-NEXT: return ::mlir::failure(); +// CHECK-NOT: if (operands.size() <= 0) // CHECK: ::mlir::Type odsInferredType0 = operands[2].getType(); // CHECK: ::mlir::Type odsInferredType1 = operands[0].getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; @@ -166,6 +171,8 @@ def OpL4 : NS_Op<"two_inference_edges", [ } // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes +// CHECK: if (operands.size() <= 0) +// CHECK-NEXT: return ::mlir::failure(); // CHECK: odsInferredType0 = fromInput(operands[0].getType()) // CHECK: odsInferredType1 = infer0(odsInferredType0) // CHECK: odsInferredType2 = infer1(odsInferredType1) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index ce2b6ed94c394..c55a00cf08a7c 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3583,6 +3583,22 @@ void OpEmitter::genTypeInterfaceMethods() { fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; + // Preprocessing stage to verify all accesses to operands are valid. + int maxAccessedIndex = -1; + for (int i = 0, e = op.getNumResults(); i != e; ++i) { + const InferredResultType &infer = op.getInferredResultType(i); + if (!infer.isArg()) + continue; + auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); + if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) + maxAccessedIndex = + std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); + } + if (maxAccessedIndex != -1) { + body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n"; + body << " return ::mlir::failure();\n"; + } + // Process the type inference graph in topological order, starting from types // that are always fully-inferred: operands and results with constructible // types. The type inference graph here will always be a DAG, so this gives From f6232ef4771daf5fb26a1efe170e58502af709bf Mon Sep 17 00:00:00 2001 From: Kolya Panchenko Date: Wed, 16 Oct 2024 18:28:35 -0400 Subject: [PATCH 2/3] [NFC] added {} for 2-line if-statement --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index c55a00cf08a7c..d1767f0908671 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3590,9 +3590,10 @@ void OpEmitter::genTypeInterfaceMethods() { if (!infer.isArg()) continue; auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) + if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { maxAccessedIndex = std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); + } } if (maxAccessedIndex != -1) { body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n"; From fdfb7647ee2bb52db18b8ac31d5bb42d4c1e3b60 Mon Sep 17 00:00:00 2001 From: Kolya Panchenko Date: Wed, 16 Oct 2024 19:47:28 -0400 Subject: [PATCH 3/3] replaced auto with Operator::OperandOrAttribute --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index d1767f0908671..d466c4d47ee6f 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3589,7 +3589,8 @@ void OpEmitter::genTypeInterfaceMethods() { const InferredResultType &infer = op.getInferredResultType(i); if (!infer.isArg()) continue; - auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); + Operator::OperandOrAttribute arg = + op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { maxAccessedIndex = std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); @@ -3616,7 +3617,8 @@ void OpEmitter::genTypeInterfaceMethods() { if (infer.isArg()) { // If this is an operand, just index into operand list to access the // type. - auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); + Operator::OperandOrAttribute arg = + op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + "].getType()")