99#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
1010#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
1111
12- #include < cassert>
13- #include < string>
14-
1512#include " mlir-c/Diagnostics.h"
1613#include " mlir-c/IR.h"
17- #include " llvm/ADT/StringRef.h"
14+ #include " llvm/Support/raw_ostream.h"
15+
16+ #include < cassert>
17+ #include < cstdint>
18+ #include < string>
1819
1920namespace mlir {
2021namespace python {
@@ -24,33 +25,45 @@ namespace python {
2425class CollectDiagnosticsToStringScope {
2526public:
2627 explicit CollectDiagnosticsToStringScope (MlirContext ctx) : context(ctx) {
27- handlerID = mlirContextAttachDiagnosticHandler (ctx, &handler, &errorMessage,
28- /* deleteUserData=*/ nullptr );
28+ handlerID =
29+ mlirContextAttachDiagnosticHandler (ctx, &handler, &messageStream,
30+ /* deleteUserData=*/ nullptr );
2931 }
3032 ~CollectDiagnosticsToStringScope () {
31- assert (errorMessage .empty () && " unchecked error message" );
33+ assert (message .empty () && " unchecked error message" );
3234 mlirContextDetachDiagnosticHandler (context, handlerID);
3335 }
3436
35- [[nodiscard]] std::string takeMessage () { return std::move (errorMessage); }
37+ [[nodiscard]] std::string takeMessage () {
38+ std::string newMessage;
39+ std::swap (message, newMessage);
40+ return newMessage;
41+ }
3642
3743private:
3844 static MlirLogicalResult handler (MlirDiagnostic diag, void *data) {
3945 auto printer = +[](MlirStringRef message, void *data) {
40- *static_cast <std::string *>(data) +=
41- llvm::StringRef (message.data , message.length );
46+ *static_cast <llvm::raw_string_ostream *>(data)
47+ << std::string_view (message.data , message.length );
4248 };
4349 MlirLocation loc = mlirDiagnosticGetLocation (diag);
44- *static_cast <std::string *>(data) += " at " ;
50+ *static_cast <llvm::raw_string_ostream *>(data) << " at " ;
4551 mlirLocationPrint (loc, printer, data);
46- *static_cast <std::string *>(data) += " : " ;
52+ *static_cast <llvm::raw_string_ostream *>(data) << " : " ;
4753 mlirDiagnosticPrint (diag, printer, data);
54+ for (intptr_t i = 0 ; i < mlirDiagnosticGetNumNotes (diag); i++) {
55+ *static_cast <llvm::raw_string_ostream *>(data) << " \n " ;
56+ MlirDiagnostic note = mlirDiagnosticGetNote (diag, i);
57+ handler (note, data);
58+ }
4859 return mlirLogicalResultSuccess ();
4960 }
5061
5162 MlirContext context;
5263 MlirDiagnosticHandlerID handlerID;
53- std::string errorMessage = " " ;
64+
65+ std::string message;
66+ llvm::raw_string_ostream messageStream{message};
5467};
5568
5669} // namespace python
0 commit comments