Skip to content

Commit 924468e

Browse files
authored
[Frontend][Diagnostics] Improve emitting diagnostic information (#5581)
### Summary This PR enhances the current implementation for emitting diagnostic remarks by introducing a unified handler in `ir.cc`. This handler manages diagnostic information more effectively and disables the emission of IRs unless explicitly requested by the user. The `MLIR_ENABLE_DIAGNOSTICS` environment variable now controls all diagnostic emission settings, accepting one or more values from `{warnings, remarks, stacktraces, operations}`, separated by commas. Detailed usage instructions are available in the README. ### Background Previously, a new default LLVM `SourceManager` was configured in `nvidia/backend/compiler.py` to support remarks, applied in both `make_ttgir` and `make_llir`. However, a custom handler already existed in `ir.cc`, and a more robust design should extend this handler rather than create a new one. ### Changes - **Unified Handler**: Inspired by LLVM upstream [[PR 117669](https://github.com/llvm/llvm-project/pull/117669)](https://github.com/llvm/llvm-project/pull/117669), this PR implements a similar custom handler that supports various severity levels. The `MLIR_ENABLE_DIAGNOSTICS` environment variable now specifies the severity level: `warnings` for warnings and errors, and `remarks` for remarks, warnings, and errors. - **IR Emission Control**: By default, the MLIR diagnostic API emits IRs, which can clutter error messages or performance remarks. This PR suppresses IR emission unless explicitly enabled by the user, improving the readability of error messages and performance remarks. Users can specify `MLIR_ENABLE_DIAGNOSTICS=remarks,operations` to include IR operations in remarks. - **Stacktraces**: Previously, setting `MLIR_ENABLE_DIAGNOSTICS=1` enabled all diagnostic information with stacktraces. Now, the `stacktraces` parameter specifically enables stacktraces. For example, `MLIR_ENABLE_DIAGNOSTICS=remarks,operations,stacktraces` enables IR operations and stacktraces, displaying all remarks, warnings, and errors. - **Testing**: Updated existing Python tests to verify that combinations of operations and stacktraces are emitted successfully. ### Future Work - With the new handler in place, there is an opportunity to further enhance the readability of existing warnings and remarks. This will be a focus in future updates. # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/python/test` for end-to-end tests - [ ] This PR does not need a test. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 1186806 commit 924468e

File tree

4 files changed

+161
-97
lines changed

4 files changed

+161
-97
lines changed

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,14 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
232232
- `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
233233
- `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
234234
- `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
235-
- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the related IR operation of diagnostics (e.g., errors and warnings).
236-
- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
235+
- `MLIR_ENABLE_DIAGNOSTICS=<comma-separated>` controls diagnostic emission in MLIR.
236+
Options are: `warnings`, `remarks`, `stacktraces`, `operations`.
237+
Use comma-separated values to customize output. For example,
238+
`MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations,
239+
while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with
240+
stacktraces. By default, only errors are shown. Setting `warnings` includes
241+
errors and warnings; `remarks` includes errors, warnings, and remarks.
242+
- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`.
237243
- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
238244
- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
239245
- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.

python/src/ir.cc

Lines changed: 91 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,42 @@ class TritonOpBuilder {
140140
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
141141
};
142142

143+
// Run the pass manager under a source manager diagnostic handler, which
144+
// enables emitted MLIR diagnostics to directly reference Python source
145+
// code. This diagnostic handler supports filtering diagnostic info by
146+
// severity levels.
147+
struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
148+
TritonSourceMgrDiagnosticHandler(MLIRContext *ctx,
149+
DiagnosticSeverity minSeverity)
150+
: SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) {
151+
setHandler([this, minSeverity](Diagnostic &diag) {
152+
auto severity = diag.getSeverity();
153+
switch (severity) {
154+
case DiagnosticSeverity::Error:
155+
break;
156+
case DiagnosticSeverity::Warning:
157+
if (minSeverity == DiagnosticSeverity::Error)
158+
return success();
159+
break;
160+
case DiagnosticSeverity::Remark:
161+
if (minSeverity == DiagnosticSeverity::Error ||
162+
minSeverity == DiagnosticSeverity::Warning)
163+
return success();
164+
break;
165+
case DiagnosticSeverity::Note:
166+
// notes are handled somewhere else.
167+
return failure();
168+
default:
169+
llvm_unreachable("Unknown diagnostic severity");
170+
}
171+
emitDiagnostic(diag);
172+
return success();
173+
});
174+
}
175+
176+
llvm::SourceMgr sourceMgr;
177+
};
178+
143179
std::string locationToString(Location loc) {
144180
std::string str;
145181
llvm::raw_string_ostream os(str);
@@ -148,6 +184,23 @@ std::string locationToString(Location loc) {
148184
return str;
149185
}
150186

187+
// Function to parse a comma-separated string into a vector of C-style strings
188+
llvm::SmallVector<const char *, 3>
189+
parseCommaSeparatedValues(const std::string &input,
190+
llvm::SmallVector<std::string, 3> &storage) {
191+
llvm::SmallVector<StringRef, 3> split;
192+
llvm::SmallVector<const char *, 3> result;
193+
StringRef(input.c_str()).split(split, ',');
194+
llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
195+
// StringRefs are not always null-terminated.
196+
// The purpose for this storage pattern is to
197+
// produce a collection of C-strings that are.
198+
storage.push_back(str.str());
199+
return storage.back().c_str();
200+
});
201+
return result;
202+
}
203+
151204
void outputWarning(Location loc, const std::string &msg) {
152205
std::string locStr = locationToString(loc);
153206

@@ -1691,27 +1744,15 @@ void init_triton_ir(py::module &&m) {
16911744
.def("enable_debug",
16921745
[](PassManager &self) {
16931746
auto *context = self.getContext();
1694-
bool haveDiagnostics =
1695-
::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
16961747
bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
16971748
std::string funcToDump;
16981749
if (!haveDump) {
16991750
funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP");
17001751
if (!funcToDump.empty())
17011752
haveDump = true;
17021753
}
1703-
if (haveDiagnostics || haveDump) {
1704-
context->disableMultithreading();
1705-
}
1706-
if (haveDiagnostics) {
1707-
context->printOpOnDiagnostic(true);
1708-
context->printStackTraceOnDiagnostic(true);
1709-
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
1710-
llvm::outs() << diag << "\n";
1711-
return success();
1712-
});
1713-
}
17141754
if (haveDump) {
1755+
context->disableMultithreading();
17151756
auto printingFlags = OpPrintingFlags();
17161757
printingFlags.elideLargeElementsAttrs(16);
17171758
printingFlags.enableDebugInfo();
@@ -1741,6 +1782,8 @@ void init_triton_ir(py::module &&m) {
17411782
// TODO: maybe dump module to file and print error for better
17421783
// diagnostics
17431784

1785+
auto *context = mod.getContext();
1786+
17441787
auto reproducerPath =
17451788
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
17461789
if (!reproducerPath.empty()) {
@@ -1752,7 +1795,7 @@ void init_triton_ir(py::module &&m) {
17521795
makeReproducer(anchorName, passes, op, reproducerPath);
17531796
// But if the pass manager crashes, attempt to generate a local
17541797
// reproducer instead.
1755-
mod.getContext()->disableMultithreading();
1798+
context->disableMultithreading();
17561799
self.enableCrashReproducerGeneration(reproducerPath,
17571800
/*genLocalReproducer=*/true);
17581801
}
@@ -1763,20 +1806,9 @@ void init_triton_ir(py::module &&m) {
17631806

17641807
if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY");
17651808
!debugOnly.empty()) {
1766-
llvm::SmallVector<StringRef, 3> split;
17671809
llvm::SmallVector<std::string, 3> storage;
1768-
llvm::SmallVector<const char *, 3> debugTypes;
1769-
1770-
StringRef(debugOnly.c_str()).split(split, ',');
1771-
llvm::transform(split, std::back_inserter(debugTypes),
1772-
[&storage](StringRef str) {
1773-
// StringRefs are not always null-terminated.
1774-
// The purpose for this storage pattern is to
1775-
// produce a collection of C-strings that are.
1776-
storage.push_back(str.str());
1777-
return storage.back().c_str();
1778-
});
1779-
1810+
llvm::SmallVector<const char *, 3> debugTypes =
1811+
parseCommaSeparatedValues(debugOnly, storage);
17801812
::llvm::DebugFlag = true;
17811813
using namespace llvm;
17821814
setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
@@ -1787,25 +1819,41 @@ void init_triton_ir(py::module &&m) {
17871819
self.enableTiming();
17881820
}
17891821

1790-
// Run the pass manager under a source manager diagnostic handler, which
1791-
// enables emitted MLIR diagnostics to directly reference Python source
1792-
// code. This diagnostic handler will only filter for errors.
1793-
struct SourceMgrErrorDiagnosticHandler
1794-
: public SourceMgrDiagnosticHandler {
1795-
SourceMgrErrorDiagnosticHandler(MLIRContext *ctx)
1796-
: SourceMgrDiagnosticHandler(sourceMgr, ctx, llvm::errs()) {
1797-
setHandler([this](Diagnostic &diag) {
1798-
if (diag.getSeverity() != DiagnosticSeverity::Error)
1799-
return failure();
1800-
emitDiagnostic(diag);
1801-
return success();
1802-
});
1822+
// setting up diagnostics
1823+
bool showOperations = false, showStacktraces = false,
1824+
showRemarks = false, showWarnings = false;
1825+
1826+
if (auto enableDiagnostics =
1827+
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
1828+
!enableDiagnostics.empty()) {
1829+
llvm::SmallVector<std::string, 3> storage;
1830+
parseCommaSeparatedValues(enableDiagnostics, storage);
1831+
for (auto &str : storage) {
1832+
if (str == "warnings") {
1833+
showWarnings = true;
1834+
} else if (str == "remarks") {
1835+
showRemarks = true;
1836+
} else if (str == "stacktraces") {
1837+
showStacktraces = true;
1838+
} else if (str == "operations") {
1839+
showOperations = true;
1840+
}
1841+
// we show errors by default, so no need to set it
18031842
}
1843+
}
18041844

1805-
llvm::SourceMgr sourceMgr;
1806-
};
1807-
SourceMgrErrorDiagnosticHandler diagHandler(mod.getContext());
1845+
DiagnosticSeverity minSeverity = showWarnings
1846+
? DiagnosticSeverity::Warning
1847+
: DiagnosticSeverity::Error;
1848+
minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;
18081849

1850+
TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity);
1851+
1852+
context->printOpOnDiagnostic(showOperations);
1853+
context->printStackTraceOnDiagnostic(showStacktraces);
1854+
if (showStacktraces) {
1855+
context->disableMultithreading();
1856+
}
18091857
if (failed(self.run(mod.getOperation())))
18101858
throw std::runtime_error("PassManager::run failed");
18111859
});

python/test/unit/test_perf_warning.py

Lines changed: 61 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,12 @@
88

99

1010
@contextmanager
11-
def enable_remark_context():
11+
def enable_diagnostics_context(value):
1212
try:
13-
os.environ["MLIR_ENABLE_REMARK"] = "1"
13+
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = value
1414
yield
1515
finally:
16-
os.environ["MLIR_ENABLE_REMARK"] = "0"
17-
18-
19-
def is_perf_warning_enabled():
20-
return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1"
16+
os.environ["MLIR_ENABLE_DIAGNOSTICS"] = ""
2117

2218

2319
def is_cuda():
@@ -74,29 +70,39 @@ def matmul_kernel(
7470
c = tl.dot(a, b)
7571
tl.store(c_block_ptr, c)
7672

77-
with enable_remark_context():
78-
triton.compile(
79-
triton.compiler.ASTSource(
80-
fn=matmul_kernel,
81-
signature={
82-
"a_ptr": "*fp32",
83-
"b_ptr": "*fp32",
84-
"c_ptr": "*fp32",
85-
"M": "i32",
86-
"N": "i32",
87-
"K": "i32",
88-
"stride_am": "i32",
89-
"stride_ak": "i32",
90-
"stride_bk": "i32",
91-
"stride_bn": "i32",
92-
"stride_cm": "i32",
93-
"stride_cn": "i32",
94-
},
95-
constexprs={},
96-
))
73+
signature = {
74+
"a_ptr": "*fp32",
75+
"b_ptr": "*fp32",
76+
"c_ptr": "*fp32",
77+
"M": "i32",
78+
"N": "i32",
79+
"K": "i32",
80+
"stride_am": "i32",
81+
"stride_ak": "i32",
82+
"stride_bk": "i32",
83+
"stride_bn": "i32",
84+
"stride_cm": "i32",
85+
"stride_cn": "i32",
86+
}
87+
with enable_diagnostics_context('remarks'):
88+
triton.compile(triton.compiler.ASTSource(
89+
fn=matmul_kernel,
90+
signature=signature,
91+
constexprs={},
92+
))
9793
captured = capfd.readouterr()
9894

99-
assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
95+
assert ("can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
96+
assert "note: see current operation:" not in captured.err
97+
98+
with enable_diagnostics_context('remarks,operations,stacktraces'):
99+
triton.compile(triton.compiler.ASTSource(
100+
fn=matmul_kernel,
101+
signature=signature,
102+
constexprs={},
103+
))
104+
captured = capfd.readouterr()
105+
assert "note: diagnostic emitted with trace:" in captured.err
100106
assert "note: see current operation:" in captured.err
101107

102108

@@ -126,25 +132,39 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)
126132
tl.store(out_ptr0 + (x4), tmp22, None)
127133

128134
XBLOCK = 1024
129-
with enable_remark_context():
135+
136+
astsource_args = {
137+
"fn": ldst_vec,
138+
"signature": {
139+
"in_ptr0": "*i64",
140+
"in_ptr1": "*i64",
141+
"in_ptr2": "*fp16",
142+
"in_ptr3": "*fp32",
143+
"out_ptr0": "*fp16",
144+
"XBLOCK": "constexpr",
145+
},
146+
"constexprs": {"XBLOCK": XBLOCK},
147+
}
148+
149+
with enable_diagnostics_context('remarks'):
130150
triton.compile(
131-
triton.compiler.ASTSource(
132-
fn=ldst_vec,
133-
signature={
134-
"in_ptr0": "*i64",
135-
"in_ptr1": "*i64",
136-
"in_ptr2": "*fp16",
137-
"in_ptr3": "*fp32",
138-
"out_ptr0": "*fp16",
139-
"XBLOCK": "constexpr",
140-
},
141-
constexprs={"XBLOCK": XBLOCK},
142-
),
151+
triton.compiler.ASTSource(**astsource_args),
143152
options={"num_warps": 1},
144153
)
145154

146155
_, err = capfd.readouterr()
147156
assert ("remark: Warning: vectorization fails" in err), "expect vectorization failure remark"
157+
assert "note: see current operation:" not in err
158+
159+
with enable_diagnostics_context('remarks,operations,stacktraces'):
160+
triton.compile(
161+
triton.compiler.ASTSource(**astsource_args),
162+
options={"num_warps": 1},
163+
)
164+
165+
_, err = capfd.readouterr()
166+
assert "note: see current operation:" in err
167+
assert "note: diagnostic emitted with trace:" in err
148168

149169

150170
def test_remark_swp_op_before_operands(capfd, fresh_triton_cache):

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,6 @@ def make_ttgir(mod, metadata, opt, capability):
238238
cluster_info.clusterDimX = opt.cluster_dims[0]
239239
cluster_info.clusterDimY = opt.cluster_dims[1]
240240
cluster_info.clusterDimZ = opt.cluster_dims[2]
241-
# Set up Diagnostic
242-
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
243-
srcMgr = llvm.source_mgr()
244-
_ = ir.source_mgr_diag(srcMgr, mod.context)
245-
mod.context.printOpOnDiagnostic(True)
246-
# TTIR -> TTGIR
247241
pm = ir.pass_manager(mod.context)
248242
pm.enable_debug()
249243
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
@@ -299,11 +293,7 @@ def make_llir(self, src, metadata, options, capability):
299293
# TritonGPU -> LLVM-IR (MLIR)
300294
pm = ir.pass_manager(mod.context)
301295
pm.enable_debug()
302-
# Set up Diagnostic
303-
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
304-
srcMgr = llvm.source_mgr()
305-
_ = ir.source_mgr_diag(srcMgr, mod.context)
306-
mod.context.printOpOnDiagnostic(True)
296+
307297
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
308298
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
309299
passes.convert.add_scf_to_cf(pm)

0 commit comments

Comments
 (0)