Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 5893335

Browse files
authored
[mlir] Better Python diagnostics (#128581)
Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring.
1 parent bb8fcc9 commit 5893335

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

mlir/include/mlir/Bindings/Python/Diagnostics.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
1010
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
1111

12-
#include <cassert>
13-
#include <string>
14-
1512
#include "mlir-c/Diagnostics.h"
1613
#include "mlir-c/IR.h"
17-
#include "llvm/ADT/StringRef.h"
14+
#include "llvm/Support/raw_ostream.h"
15+
16+
#include <cassert>
17+
#include <cstdint>
18+
#include <string>
1819

1920
namespace mlir {
2021
namespace python {
@@ -24,33 +25,45 @@ namespace python {
2425
class CollectDiagnosticsToStringScope {
2526
public:
2627
explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
27-
handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
28-
/*deleteUserData=*/nullptr);
28+
handlerID =
29+
mlirContextAttachDiagnosticHandler(ctx, &handler, &messageStream,
30+
/*deleteUserData=*/nullptr);
2931
}
3032
~CollectDiagnosticsToStringScope() {
31-
assert(errorMessage.empty() && "unchecked error message");
33+
assert(message.empty() && "unchecked error message");
3234
mlirContextDetachDiagnosticHandler(context, handlerID);
3335
}
3436

35-
[[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
37+
[[nodiscard]] std::string takeMessage() {
38+
std::string newMessage;
39+
std::swap(message, newMessage);
40+
return newMessage;
41+
}
3642

3743
private:
3844
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
3945
auto printer = +[](MlirStringRef message, void *data) {
40-
*static_cast<std::string *>(data) +=
41-
llvm::StringRef(message.data, message.length);
46+
*static_cast<llvm::raw_string_ostream *>(data)
47+
<< std::string_view(message.data, message.length);
4248
};
4349
MlirLocation loc = mlirDiagnosticGetLocation(diag);
44-
*static_cast<std::string *>(data) += "at ";
50+
*static_cast<llvm::raw_string_ostream *>(data) << "at ";
4551
mlirLocationPrint(loc, printer, data);
46-
*static_cast<std::string *>(data) += ": ";
52+
*static_cast<llvm::raw_string_ostream *>(data) << ": ";
4753
mlirDiagnosticPrint(diag, printer, data);
54+
for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
55+
*static_cast<llvm::raw_string_ostream *>(data) << "\n";
56+
MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
57+
handler(note, data);
58+
}
4859
return mlirLogicalResultSuccess();
4960
}
5061

5162
MlirContext context;
5263
MlirDiagnosticHandlerID handlerID;
53-
std::string errorMessage = "";
64+
65+
std::string message;
66+
llvm::raw_string_ostream messageStream{message};
5467
};
5568

5669
} // namespace python

0 commit comments

Comments
 (0)