@@ -88,6 +88,23 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() {
8888 }
8989}
9090
91+ // Function to parse a comma-separated string into a vector of C-style strings
92+ llvm::SmallVector<const char *, 3 >
93+ parseCommaSeparatedValues (const std::string &input,
94+ llvm::SmallVector<std::string, 3 > &storage) {
95+ llvm::SmallVector<StringRef, 3 > split;
96+ llvm::SmallVector<const char *, 3 > result;
97+ StringRef (input.c_str ()).split (split, ' ,' );
98+ llvm::transform (split, std::back_inserter (result), [&storage](StringRef str) {
99+ // StringRefs are not always null-terminated.
100+ // The purpose for this storage pattern is to
101+ // produce a collection of C-strings that are.
102+ storage.push_back (str.str ());
103+ return storage.back ().c_str ();
104+ });
105+ return result;
106+ }
107+
91108// Run the pass manager under a source manager diagnostic handler, which
92109// enables emitted MLIR diagnostics to directly reference Python source
93110// code. This diagnostic handler supports filtering diagnostic info by
@@ -124,6 +141,43 @@ struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
124141 llvm::SourceMgr sourceMgr;
125142};
126143
144+ TritonSourceMgrDiagnosticHandler
145+ setupTritonDiagnosticHandler (MLIRContext *context) {
146+ bool showOperations = false , showStacktraces = false , showRemarks = false ,
147+ showWarnings = false ;
148+
149+ if (auto enableDiagnostics =
150+ triton::tools::getStrEnv (" MLIR_ENABLE_DIAGNOSTICS" );
151+ !enableDiagnostics.empty ()) {
152+ llvm::SmallVector<std::string, 3 > storage;
153+ parseCommaSeparatedValues (enableDiagnostics, storage);
154+ for (auto &str : storage) {
155+ if (str == " warnings" ) {
156+ showWarnings = true ;
157+ } else if (str == " remarks" ) {
158+ showRemarks = true ;
159+ } else if (str == " stacktraces" ) {
160+ showStacktraces = true ;
161+ } else if (str == " operations" ) {
162+ showOperations = true ;
163+ }
164+ // we show errors by default, so no need to set it
165+ }
166+ }
167+
168+ DiagnosticSeverity minSeverity =
169+ showWarnings ? DiagnosticSeverity::Warning : DiagnosticSeverity::Error;
170+ minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;
171+
172+ context->printOpOnDiagnostic (showOperations);
173+ context->printStackTraceOnDiagnostic (showStacktraces);
174+ if (showStacktraces) {
175+ context->disableMultithreading ();
176+ }
177+
178+ return TritonSourceMgrDiagnosticHandler (context, minSeverity);
179+ }
180+
127181std::string locationToString (Location loc) {
128182 std::string str;
129183 llvm::raw_string_ostream os (str);
@@ -132,23 +186,6 @@ std::string locationToString(Location loc) {
132186 return str;
133187}
134188
135- // Function to parse a comma-separated string into a vector of C-style strings
136- llvm::SmallVector<const char *, 3 >
137- parseCommaSeparatedValues (const std::string &input,
138- llvm::SmallVector<std::string, 3 > &storage) {
139- llvm::SmallVector<StringRef, 3 > split;
140- llvm::SmallVector<const char *, 3 > result;
141- StringRef (input.c_str ()).split (split, ' ,' );
142- llvm::transform (split, std::back_inserter (result), [&storage](StringRef str) {
143- // StringRefs are not always null-terminated.
144- // The purpose for this storage pattern is to
145- // produce a collection of C-strings that are.
146- storage.push_back (str.str ());
147- return storage.back ().c_str ();
148- });
149- return result;
150- }
151-
152189void outputWarning (Location loc, const std::string &msg) {
153190 std::string locStr = locationToString (loc);
154191
@@ -694,7 +731,12 @@ void init_triton_ir(py::module &&m) {
694731 .def (" walk" ,
695732 [](ModuleOp &self, const std::function<void (Operation *)> &fn) {
696733 self.walk (fn);
697- });
734+ })
735+ .def (" verify_with_diagnostics" , [](ModuleOp &self) {
736+ TritonSourceMgrDiagnosticHandler handler =
737+ setupTritonDiagnosticHandler (self.getContext ());
738+ return succeeded (verify (self.getOperation ()));
739+ });
698740
699741 m.def (" make_attr" , [](const std::vector<int > &values, MLIRContext &context) {
700742 return mlir::cast<Attribute>(DenseIntElementsAttr::get (
@@ -1923,42 +1965,8 @@ void init_triton_ir(py::module &&m) {
19231965 self.enableTiming ();
19241966 }
19251967
1926- // setting up diagnostics
1927- bool showOperations = false , showStacktraces = false ,
1928- showRemarks = false , showWarnings = false ;
1929-
1930- if (auto enableDiagnostics =
1931- triton::tools::getStrEnv (" MLIR_ENABLE_DIAGNOSTICS" );
1932- !enableDiagnostics.empty ()) {
1933- llvm::SmallVector<std::string, 3 > storage;
1934- parseCommaSeparatedValues (enableDiagnostics, storage);
1935- for (auto &str : storage) {
1936- if (str == " warnings" ) {
1937- showWarnings = true ;
1938- } else if (str == " remarks" ) {
1939- showRemarks = true ;
1940- } else if (str == " stacktraces" ) {
1941- showStacktraces = true ;
1942- } else if (str == " operations" ) {
1943- showOperations = true ;
1944- }
1945- // we show errors by default, so no need to set it
1946- }
1947- }
1948-
1949- DiagnosticSeverity minSeverity = showWarnings
1950- ? DiagnosticSeverity::Warning
1951- : DiagnosticSeverity::Error;
1952- minSeverity =
1953- showRemarks ? DiagnosticSeverity::Remark : minSeverity;
1954-
1955- TritonSourceMgrDiagnosticHandler diagHandler (context, minSeverity);
1956-
1957- context->printOpOnDiagnostic (showOperations);
1958- context->printStackTraceOnDiagnostic (showStacktraces);
1959- if (showStacktraces) {
1960- context->disableMultithreading ();
1961- }
1968+ TritonSourceMgrDiagnosticHandler diagHandler =
1969+ setupTritonDiagnosticHandler (context);
19621970 if (failed (self.run (mod.getOperation ())))
19631971 throw std::runtime_error (" PassManager::run failed" );
19641972 },
0 commit comments