Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions mlir/include/mlir/Interfaces/InferTypeOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,8 @@ inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
/// the op. Precondition: op implements InferTypeOpInterface.
LogicalResult verifyInferredResultTypes(Operation *op);

/// Report a fatal error indicating that the result types could not be
/// inferred.
void reportFatalInferReturnTypesError(OperationState &state);
/// Report an error indicating that the result types could not be inferred.
void emitInferReturnTypesError(OperationState &state);
} // namespace detail

namespace OpTrait {
Expand Down
19 changes: 8 additions & 11 deletions mlir/lib/Interfaces/InferTypeOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,12 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
return result;
}

void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) {
std::string buffer;
llvm::raw_string_ostream os(buffer);
os << "Failed to infer result type(s):\n"
<< "\"" << state.name << "\"(...) "
<< state.attributes.getDictionary(state.location.getContext()) << " : ("
<< llvm::interleaved(llvm::map_range(
state.operands, [](Value val) { return val.getType(); }))
<< ") -> ( ??? )";
emitRemark(state.location, "location of op");
llvm::report_fatal_error(llvm::StringRef(buffer));
void mlir::detail::emitInferReturnTypesError(OperationState &state) {
mlir::emitError(state.location)
<< "failed to infer result type(s):\n"
<< "\"" << state.name << "\"(...) "
<< state.attributes.getDictionary(state.location.getContext()) << " : ("
<< llvm::interleaved(llvm::map_range(
state.operands, [](Value val) { return val.getType(); }))
<< ") -> ( ??? )";
}
2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/op-decl-and-defs.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint",
// DEFS: void FOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a) {
// DEFS: if (::mlir::succeeded(FOp::inferReturnTypes(odsBuilder.getContext(),
// DEFS: else
// DEFS: ::mlir::detail::reportFatalInferReturnTypesError(odsState);
// DEFS: ::mlir::detail::emitInferReturnTypesError(odsState);

// DEFS: FOp FOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ::mlir::Value a) {
// DEFS: ::mlir::OperationState __state__(location, getOperationName());
Expand Down
9 changes: 5 additions & 4 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2681,7 +2681,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
{1}.regions, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
::mlir::detail::reportFatalInferReturnTypesError({1});
::mlir::detail::emitInferReturnTypesError({1});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens post this in error case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's up to the usage, similar to calling op.verify

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is the caller in this context supposed to know that there was an error?
You emit a diagnostic but don't propagate any error state do you?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, I'm following the same pattern used in <Type>::getChecked and <Type>::parse for instance, which afaict expect you to check whether a diagnostic was emitted to signal an error state. I don't think this is a great pattern necessarily, we can do something else, but it's the pattern that seems to (1) have prior use in the codebase and (2) appears to locally best fit the needs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually no: getChecked returns a null Attribute (or Type) which you can check. This is not the case here, the MLIR builder API has no provision to fail building an operation, and your patch does not do this right now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha. What is your recommendation here? Generally I would gravitate towards patterns like FailureOr but as you say the builder API doesn't have a clear place to put this.

)",
opClass.getClassName(), builderOpState);
return;
Expand Down Expand Up @@ -2967,10 +2967,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
<< "u && \"mismatched number of return types\");";
body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";

body << R"(
body << formatv(R"(
} else {
::llvm::report_fatal_error("Failed to infer result type(s).");
})";
::mlir::detail::emitInferReturnTypesError({0});
})",
builderOpState);
}

void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
Expand Down