diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h index 47bcfc9bbd4f9..4fcbeff9df560 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -244,6 +244,10 @@ inferReturnTensorTypes(ArrayRef retComponents, /// Verifies that the inferred result types match the actual result types for /// the op. Precondition: op implements InferTypeOpInterface. LogicalResult verifyInferredResultTypes(Operation *op); + +/// Report a fatal error indicating that the result types could not be +/// inferred. +void reportFatalInferReturnTypesError(OperationState &state); } // namespace detail namespace OpTrait { diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index e52d0e17cda22..8cc4206dae6ed 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -247,3 +247,17 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { return result; } + +void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) { + std::string buffer; + llvm::raw_string_ostream os(buffer); + os << "Failed to infer result type(s):\n"; + os << "\"" << state.name << "\"(...) "; + os << state.attributes.getDictionary(state.location.getContext()); + os << " : ("; + llvm::interleaveComma(state.operands, os, + [&](Value val) { os << val.getType(); }); + os << ") -> ( ??? )"; + emitRemark(state.location, "location of op"); + llvm::report_fatal_error(llvm::StringRef(buffer)); +} diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index 31dd53725c59a..a03d0b40d4655 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -208,6 +208,11 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: class FOp : // CHECK: static ::llvm::LogicalResult inferReturnTypes +// DEFS: void FOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a) { +// DEFS: if (::mlir::succeeded(FOp::inferReturnTypes(odsBuilder.getContext(), +// DEFS: else +// DEFS: ::mlir::detail::reportFatalInferReturnTypesError(odsState); + def NS_GOp : NS_Op<"op_with_fixed_return_type", []> { let arguments = (ins AnyType:$a); let results = (outs I32:$b); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index ce2b6ed94c394..71fa5011a476b 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2503,7 +2503,8 @@ void OpEmitter::genSeparateArgParamBuilder() { {1}.regions, inferredReturnTypes))) {1}.addTypes(inferredReturnTypes); else - ::llvm::report_fatal_error("Failed to infer result type(s).");)", + ::mlir::detail::reportFatalInferReturnTypesError({1}); + )", opClass.getClassName(), builderOpState); return; }