Skip to content

Commit 774fe01

Browse files
Mogballmeta-codesync[bot]
authored andcommitted
[Cherry-pick] [Frontend] Run the MLIR verifier after parsing (#7999) (#523)
Summary: Cherry-picked from upstream OAI repository. Original Commit: 165dd4b Original Author: Jeff Niu Original Date: 2025-08-28 12:44:28 -0700 Original commit message: ``` [Frontend] Run the MLIR verifier after parsing (#7999) The error messages generated aren't perfect, but this at least prevents the compiler from dumping a reproducer just for verifier errors. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #523 Reviewed By: minjang Differential Revision: D85686892 Pulled By: agron911 fbshipit-source-id: cba8640eae2fb69687d0410ebabddd5289a43102
1 parent bf70a34 commit 774fe01

File tree

2 files changed

+69
-57
lines changed

2 files changed

+69
-57
lines changed

python/src/ir.cc

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,23 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() {
8888
}
8989
}
9090

91+
// Function to parse a comma-separated string into a vector of C-style strings
92+
llvm::SmallVector<const char *, 3>
93+
parseCommaSeparatedValues(const std::string &input,
94+
llvm::SmallVector<std::string, 3> &storage) {
95+
llvm::SmallVector<StringRef, 3> split;
96+
llvm::SmallVector<const char *, 3> result;
97+
StringRef(input.c_str()).split(split, ',');
98+
llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
99+
// StringRefs are not always null-terminated.
100+
// The purpose for this storage pattern is to
101+
// produce a collection of C-strings that are.
102+
storage.push_back(str.str());
103+
return storage.back().c_str();
104+
});
105+
return result;
106+
}
107+
91108
// Run the pass manager under a source manager diagnostic handler, which
92109
// enables emitted MLIR diagnostics to directly reference Python source
93110
// code. This diagnostic handler supports filtering diagnostic info by
@@ -124,6 +141,43 @@ struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
124141
llvm::SourceMgr sourceMgr;
125142
};
126143

144+
TritonSourceMgrDiagnosticHandler
145+
setupTritonDiagnosticHandler(MLIRContext *context) {
146+
bool showOperations = false, showStacktraces = false, showRemarks = false,
147+
showWarnings = false;
148+
149+
if (auto enableDiagnostics =
150+
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
151+
!enableDiagnostics.empty()) {
152+
llvm::SmallVector<std::string, 3> storage;
153+
parseCommaSeparatedValues(enableDiagnostics, storage);
154+
for (auto &str : storage) {
155+
if (str == "warnings") {
156+
showWarnings = true;
157+
} else if (str == "remarks") {
158+
showRemarks = true;
159+
} else if (str == "stacktraces") {
160+
showStacktraces = true;
161+
} else if (str == "operations") {
162+
showOperations = true;
163+
}
164+
// we show errors by default, so no need to set it
165+
}
166+
}
167+
168+
DiagnosticSeverity minSeverity =
169+
showWarnings ? DiagnosticSeverity::Warning : DiagnosticSeverity::Error;
170+
minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;
171+
172+
context->printOpOnDiagnostic(showOperations);
173+
context->printStackTraceOnDiagnostic(showStacktraces);
174+
if (showStacktraces) {
175+
context->disableMultithreading();
176+
}
177+
178+
return TritonSourceMgrDiagnosticHandler(context, minSeverity);
179+
}
180+
127181
std::string locationToString(Location loc) {
128182
std::string str;
129183
llvm::raw_string_ostream os(str);
@@ -132,23 +186,6 @@ std::string locationToString(Location loc) {
132186
return str;
133187
}
134188

135-
// Function to parse a comma-separated string into a vector of C-style strings
136-
llvm::SmallVector<const char *, 3>
137-
parseCommaSeparatedValues(const std::string &input,
138-
llvm::SmallVector<std::string, 3> &storage) {
139-
llvm::SmallVector<StringRef, 3> split;
140-
llvm::SmallVector<const char *, 3> result;
141-
StringRef(input.c_str()).split(split, ',');
142-
llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
143-
// StringRefs are not always null-terminated.
144-
// The purpose for this storage pattern is to
145-
// produce a collection of C-strings that are.
146-
storage.push_back(str.str());
147-
return storage.back().c_str();
148-
});
149-
return result;
150-
}
151-
152189
void outputWarning(Location loc, const std::string &msg) {
153190
std::string locStr = locationToString(loc);
154191

@@ -694,7 +731,12 @@ void init_triton_ir(py::module &&m) {
694731
.def("walk",
695732
[](ModuleOp &self, const std::function<void(Operation *)> &fn) {
696733
self.walk(fn);
697-
});
734+
})
735+
.def("verify_with_diagnostics", [](ModuleOp &self) {
736+
TritonSourceMgrDiagnosticHandler handler =
737+
setupTritonDiagnosticHandler(self.getContext());
738+
return succeeded(verify(self.getOperation()));
739+
});
698740

699741
m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
700742
return mlir::cast<Attribute>(DenseIntElementsAttr::get(
@@ -1923,42 +1965,8 @@ void init_triton_ir(py::module &&m) {
19231965
self.enableTiming();
19241966
}
19251967

1926-
// setting up diagnostics
1927-
bool showOperations = false, showStacktraces = false,
1928-
showRemarks = false, showWarnings = false;
1929-
1930-
if (auto enableDiagnostics =
1931-
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
1932-
!enableDiagnostics.empty()) {
1933-
llvm::SmallVector<std::string, 3> storage;
1934-
parseCommaSeparatedValues(enableDiagnostics, storage);
1935-
for (auto &str : storage) {
1936-
if (str == "warnings") {
1937-
showWarnings = true;
1938-
} else if (str == "remarks") {
1939-
showRemarks = true;
1940-
} else if (str == "stacktraces") {
1941-
showStacktraces = true;
1942-
} else if (str == "operations") {
1943-
showOperations = true;
1944-
}
1945-
// we show errors by default, so no need to set it
1946-
}
1947-
}
1948-
1949-
DiagnosticSeverity minSeverity = showWarnings
1950-
? DiagnosticSeverity::Warning
1951-
: DiagnosticSeverity::Error;
1952-
minSeverity =
1953-
showRemarks ? DiagnosticSeverity::Remark : minSeverity;
1954-
1955-
TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity);
1956-
1957-
context->printOpOnDiagnostic(showOperations);
1958-
context->printStackTraceOnDiagnostic(showStacktraces);
1959-
if (showStacktraces) {
1960-
context->disableMultithreading();
1961-
}
1968+
TritonSourceMgrDiagnosticHandler diagHandler =
1969+
setupTritonDiagnosticHandler(context);
19621970
if (failed(self.run(mod.getOperation())))
19631971
throw std::runtime_error("PassManager::run failed");
19641972
},

python/triton/compiler/code_generator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,7 +1602,11 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None)
16021602
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
16031603
codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
16041604
generator.visit(fn.parse())
1605-
ret = generator.module
1605+
module = generator.module
16061606
# module takes ownership of the context
1607-
ret.context = context
1608-
return ret
1607+
module.context = context
1608+
if not module.verify_with_diagnostics():
1609+
if not fn.is_gluon():
1610+
print(module)
1611+
raise RuntimeError("error encountered during parsing")
1612+
return module

0 commit comments

Comments
 (0)