Skip to content

Commit 799d846

Browse files
Youngzt998plotfi
andauthored
Triton debug name/type information lowering through llvm dbg metadata into cubin/hsaco (triton-lang#7633)
@makslevental @plotfi @htyu @minjang * This is a draft PR for referencing the debug info work. * This work will need to be integrated with triton-lang#7521 once it does land. * The overlap with triton-lang#7521 is that these both collect source level naming information and lower them to loc/ssa value names at the MLIR level * However, this PR also adds support for lowering name and type information to LLVM Metadata for debugging purposes with cuda-gdb or rocm-gdb Below is a showcase of what we are able to do (we use flag LLVM_EXTRACT_DI_LOCAL_VARIABLES to turn on this option): ![output (1)](https://github.com/user-attachments/assets/6154b712-9cb7-4aac-ba3f-babf926a28cf) # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.) --------- Co-authored-by: Puyan Lotfi <[email protected]>
1 parent 0a4a711 commit 799d846

File tree

16 files changed

+559
-11
lines changed

16 files changed

+559
-11
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
244244
- `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
245245
- `TRITON_DISABLE_LINE_INFO=1` removes all line information from the module.
246246
- `PTXAS_OPTIONS` passes additional command-line options to the PTX assembler `ptxas` (only on NVIDIA).
247+
- `LLVM_EXTRACT_DI_LOCAL_VARIABLES=1` emits full debug info, allowing values to be evaluated in GPU debuggers (e.g. cuda-gdb, rocm-gdb).
247248

248249
> [!NOTE]
249250
> Some of these environment variables don't have a knob in `knobs.py`-- those are only relevant to the C++ layer(s), hence they don't exist in the python layer.

bin/RegisterTritonDialects.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
4040
#include "mlir/InitAllPasses.h"
4141

42+
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
43+
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
44+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
45+
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
46+
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
47+
4248
namespace mlir {
4349
namespace test {
4450
void registerTestAliasPass();
@@ -82,13 +88,20 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8288
mlir::registerLLVMDIScope();
8389
mlir::LLVM::registerInlinerInterface(registry);
8490
mlir::NVVM::registerInlinerInterface(registry);
91+
mlir::registerLLVMDILocalVariable();
8592

8693
// TritonAMDGPUToLLVM passes
8794
mlir::triton::registerAllocateAMDGPUSharedMemory();
8895
mlir::triton::registerConvertTritonAMDGPUToLLVM();
8996
mlir::triton::registerConvertBuiltinFuncToLLVM();
9097
mlir::triton::registerOptimizeAMDLDSUsage();
9198

99+
mlir::ub::registerConvertUBToLLVMInterface(registry);
100+
mlir::registerConvertNVVMToLLVMInterface(registry);
101+
mlir::registerConvertMathToLLVMInterface(registry);
102+
mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
103+
mlir::arith::registerConvertArithToLLVMInterface(registry);
104+
92105
// TritonAMDGPUTransforms passes
93106
mlir::registerTritonAMDGPUAccelerateMatmul();
94107
mlir::registerTritonAMDGPUOptimizeEpilogue();

include/triton/Target/LLVMIR/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@ def LLVMDIScope: Pass<"enable-line-info", "mlir::ModuleOp"> {
1010
}];
1111
}
1212

13+
def LLVMDILocalVariable: Pass<"extract-variable-info", "mlir::ModuleOp"> {
14+
let summary = "Pull out source variable info from Location to DILocalVariable";
15+
let description = [{
16+
This pass pulls out each source variable's debug info from the LLVM IR dialect's Location
17+
into LLVM's DILocalVariable and fuses it into the previous Location so that it can be passed to LLVM IR later in debugging mode.
18+
}];
19+
}
20+
1321
#endif

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
2323
"LLVM_IR_ENABLE_DUMP",
2424
"LLVM_ENABLE_TIMING",
2525
"LLVM_PASS_PLUGIN_PATH",
26+
"LLVM_EXTRACT_DI_LOCAL_VARIABLES",
2627
"MLIR_ENABLE_DIAGNOSTICS",
2728
"MLIR_ENABLE_DUMP",
2829
"MLIR_DUMP_PATH",

lib/Target/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_triton_library(TritonLLVMIR
22
LLVMDIScope.cpp
3+
LLVMDILocalVariable.cpp
34
LLVMIRBreakPhiStruct.cpp
45

56
DEPENDS
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
3+
#include "mlir/IR/BuiltinAttributes.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "mlir/Support/LLVM.h"
6+
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
7+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
8+
#include "mlir/Target/LLVMIR/Export.h"
9+
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
10+
#include "triton/Target/LLVMIR/Passes.h"
11+
#include "llvm/BinaryFormat/Dwarf.h"
12+
#include "llvm/Support/Debug.h"
13+
#include "llvm/Support/Path.h"
14+
15+
// #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
16+
//===----------------------------------------------------------------------===//
17+
// This file implements a pass that inserts a debug-value intrinsic carrying a
// DILocalVariable after every LLVM operation whose location is a NameLoc.
18+
//===----------------------------------------------------------------------===//
19+
20+
namespace mlir {
21+
22+
#define DEBUG_TYPE "name-preservation"
23+
24+
#define GEN_PASS_DEF_LLVMDILOCALVARIABLE
25+
#include "triton/Target/LLVMIR/Passes.h.inc"
26+
27+
struct LLVMDILocalVariablePass
28+
: public impl::LLVMDILocalVariableBase<LLVMDILocalVariablePass> {
29+
30+
void fuseDILocalVariable(Operation *op) {
31+
if (op->getNumResults() == 0) {
32+
return;
33+
}
34+
35+
MLIRContext *context = op->getContext();
36+
OpBuilder builder(context);
37+
Location loc = op->getLoc();
38+
39+
// if the location is a NameLoc, a.k.a it defines a value, then insert a
40+
// dbg-value intrinsic after the op
41+
if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
42+
Location childLoc = nameLoc.getChildLoc();
43+
StringAttr nameAttr = nameLoc.getName();
44+
45+
// also see reference of operation construction from
46+
// mlir/lib/Target/LLVMIR/ModuleImport.cpp which translated llvm::Module
47+
// into mlir::LLVM::Operation
48+
49+
// TODO: Those instantiation using defult is necessary for first viable
50+
// result, but no meaning for now
51+
LLVM::DIFileAttr diFileAttr =
52+
LLVM::DIFileAttr::get(context, "<unknown>", "<unknown>");
53+
54+
// Extracting type info into DITypeAttr
55+
mlir::Type resultType = op->getResult(0).getType();
56+
if (isa<LLVM::LLVMVoidType>(resultType)) {
57+
// we cannot allow void type to be noted as data type, otherwise trigger
58+
// later assertion fault
59+
return;
60+
}
61+
LLVM::DITypeAttr diTypeAttr = convertType(context, resultType);
62+
LLVM::DIFlags diFlags = LLVM::DIFlags::Zero;
63+
64+
// LLVM Dialect to LLVM translation requires DILocalScope when
65+
// DILocalVariable is present
66+
LLVM::DILocalScopeAttr diLocalScopeAttr =
67+
dyn_cast<LLVM::DILocalScopeAttr>(diSubprogramAttr);
68+
69+
// DILocalVariable of LLVM Dialect, which will be translated to LLVM IR's
70+
// llvm::DILocalVariable
71+
LLVM::DILocalVariableAttr diLocalVarAttr;
72+
73+
// TODO: current parameter only for first viable result for now
74+
diLocalVarAttr = LLVM::DILocalVariableAttr::get(
75+
context, diLocalScopeAttr, nameAttr, diFileAttr, 0, 0, 0, diTypeAttr,
76+
diFlags);
77+
78+
LLVM::DIExpressionAttr diExprAttr = LLVM::DIExpressionAttr::get(context);
79+
// Note: must set insertion point before calling create since it will
80+
// automatically insert the op
81+
builder.setInsertionPointAfter(op);
82+
// a subclass of mlir::Value, which is the value defined by this operation
83+
OpResult opResult = op->getResult(0);
84+
// create and insert this call-dbg-value intrinsic after the op
85+
Operation *dbgOp = builder.create<LLVM::DbgValueOp>(
86+
childLoc, opResult, diLocalVarAttr, diExprAttr);
87+
}
88+
}
89+
90+
auto calcBitWidth(mlir::Type type) -> std::optional<unsigned> {
91+
if (type.isIntOrFloat()) {
92+
return type.getIntOrFloatBitWidth();
93+
} else if (mlir::isa<mlir::VectorType>(type)) {
94+
auto vectorType = dyn_cast<mlir::VectorType>(type);
95+
llvm::ArrayRef<int64_t> shape = vectorType.getShape();
96+
mlir::Type elementType = vectorType.getElementType();
97+
llvm::ArrayRef<bool> scalableDims = vectorType.getScalableDims();
98+
unsigned size = 1;
99+
for (auto i : shape) {
100+
size *= i;
101+
}
102+
103+
if (auto elementTypeSize = calcBitWidth(elementType);
104+
elementTypeSize.has_value()) {
105+
return size * elementTypeSize.value();
106+
}
107+
}
108+
109+
return std::nullopt;
110+
}
111+
112+
// Note: mlir does not provided any built-in conversion from mlir::Type to
113+
// mlir::LLVM::DITypeAttr
114+
LLVM::DITypeAttr convertType(MLIRContext *context, mlir::Type type) {
115+
if (type.isInteger(1)) {
116+
return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type,
117+
mlir::StringAttr::get(context, "bool"),
118+
type.getIntOrFloatBitWidth(),
119+
llvm::dwarf::DW_ATE_boolean);
120+
}
121+
if (type.isInteger()) {
122+
return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type,
123+
mlir::StringAttr::get(context, "int"),
124+
type.getIntOrFloatBitWidth(),
125+
llvm::dwarf::DW_ATE_signed);
126+
} else if (type.isF16()) {
127+
return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type,
128+
mlir::StringAttr::get(context, "half"),
129+
type.getIntOrFloatBitWidth(),
130+
llvm::dwarf::DW_ATE_float);
131+
} else if (type.isF32()) {
132+
return LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type,
133+
mlir::StringAttr::get(context, "float"),
134+
type.getIntOrFloatBitWidth(),
135+
llvm::dwarf::DW_ATE_float);
136+
} else if (type.isF64()) {
137+
return LLVM::DIBasicTypeAttr::get(
138+
context, llvm::dwarf::DW_TAG_base_type,
139+
mlir::StringAttr::get(context, "double"),
140+
type.getIntOrFloatBitWidth(), llvm::dwarf::DW_ATE_float);
141+
} else if (mlir::isa<mlir::VectorType>(type)) {
142+
if (auto vectorTypeSize = calcBitWidth(type);
143+
vectorTypeSize.has_value()) {
144+
return LLVM::DIBasicTypeAttr::get(
145+
context, llvm::dwarf::DW_TAG_base_type,
146+
mlir::StringAttr::get(context, "vector"), vectorTypeSize.value(),
147+
llvm::dwarf::DW_ATE_float);
148+
} else {
149+
// TODO: falling back to unknown_type, perhaps theres a better way to
150+
// handle when element type size is not determined
151+
}
152+
}
153+
154+
return LLVM::DIBasicTypeAttr::get(
155+
context, llvm::dwarf::DW_TAG_base_type,
156+
mlir::StringAttr::get(context, "unknown_type"), 0,
157+
llvm::dwarf::DW_ATE_signed);
158+
}
159+
160+
/// Attempt to extract a filename for the given loc.
161+
FileLineColLoc extractFileLoc(Location loc) {
162+
if (auto fileLoc = dyn_cast<FileLineColLoc>(loc))
163+
return fileLoc;
164+
if (auto nameLoc = dyn_cast<NameLoc>(loc))
165+
return extractFileLoc(nameLoc.getChildLoc());
166+
if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc))
167+
return extractFileLoc(opaqueLoc.getFallbackLocation());
168+
if (auto fusedLoc = dyn_cast<FusedLoc>(loc))
169+
return extractFileLoc(fusedLoc.getLocations().front());
170+
if (auto callerLoc = dyn_cast<CallSiteLoc>(loc))
171+
return extractFileLoc(callerLoc.getCaller());
172+
StringAttr unknownFile =
173+
mlir::StringAttr::get(loc.getContext(), "<unknown>");
174+
return mlir::FileLineColLoc::get(unknownFile, 0, 0);
175+
}
176+
177+
// Follow the same logic as LLVMDIScopePass to construct a subprogram scope
178+
LLVM::DISubprogramAttr getDISubprogramAttr(LLVM::LLVMFuncOp funcOp) {
179+
Location loc = funcOp.getLoc();
180+
if (auto fusedSubprogramAttr =
181+
loc->findInstanceOf<mlir::FusedLocWith<LLVM::DISubprogramAttr>>())
182+
return fusedSubprogramAttr.getMetadata();
183+
184+
MLIRContext *context = &getContext();
185+
186+
// To find a DICompileUnitAttr attached to a parent (the module for
187+
// example), otherwise create a default one.
188+
LLVM::DICompileUnitAttr compileUnitAttr;
189+
if (ModuleOp module = funcOp->getParentOfType<ModuleOp>()) {
190+
auto fusedCompileUnitAttr =
191+
module->getLoc()
192+
->findInstanceOf<mlir::FusedLocWith<LLVM::DICompileUnitAttr>>();
193+
if (fusedCompileUnitAttr)
194+
compileUnitAttr = fusedCompileUnitAttr.getMetadata();
195+
}
196+
197+
// Filename, line and colmun to associate to the function.
198+
LLVM::DIFileAttr fileAttr;
199+
int64_t line = 1, col = 1;
200+
FileLineColLoc fileLoc = extractFileLoc(loc);
201+
if (!fileLoc && compileUnitAttr) {
202+
fileAttr = compileUnitAttr.getFile();
203+
} else if (!fileLoc) {
204+
fileAttr = LLVM::DIFileAttr::get(context, "<unknown>", "");
205+
} else {
206+
line = fileLoc.getLine();
207+
col = fileLoc.getColumn();
208+
StringRef inputFilePath = fileLoc.getFilename().getValue();
209+
fileAttr = LLVM::DIFileAttr::get(
210+
context, llvm::sys::path::filename(inputFilePath),
211+
llvm::sys::path::parent_path(inputFilePath));
212+
}
213+
214+
auto subroutineTypeAttr =
215+
LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {});
216+
217+
DistinctAttr distinctId;
218+
auto subprogramFlags = LLVM::DISubprogramFlags::Optimized;
219+
if (!funcOp.isExternal()) {
220+
distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
221+
if (!compileUnitAttr) {
222+
compileUnitAttr = LLVM::DICompileUnitAttr::get(
223+
distinctId, llvm::dwarf::DW_LANG_C, fileAttr,
224+
StringAttr::get(context, "triton"),
225+
/*isOptimized=*/true, LLVM::DIEmissionKind::Full);
226+
}
227+
subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition;
228+
} else {
229+
compileUnitAttr = {};
230+
}
231+
232+
StringAttr funcNameAttr = funcOp.getNameAttr();
233+
// Note that scopeline is set differently from LLVM's
234+
// DIScopeForLLVMFuncOpPass. I don't find reasons why scopeline should be
235+
// the column offset
236+
auto subprogramAttr = LLVM::DISubprogramAttr::get(
237+
context, distinctId, compileUnitAttr, fileAttr, funcNameAttr,
238+
funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line,
239+
subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{},
240+
/*annotations=*/{});
241+
242+
return subprogramAttr;
243+
}
244+
245+
// construct a subprogram of an operation by using its parent function's
246+
// DISubprogramAttr construction
247+
LLVM::DISubprogramAttr getDISubprogramAttr(Operation op) {
248+
auto funcOp = op.getParentOfType<LLVM::LLVMFuncOp>();
249+
return getDISubprogramAttr(funcOp);
250+
}
251+
252+
// set it while traversing into a function
253+
LLVM::DISubprogramAttr diSubprogramAttr;
254+
255+
void runOnOperation() override {
256+
Operation *op = getOperation();
257+
258+
getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) -> void {
259+
if (isa<LLVM::LLVMFuncOp>(op)) {
260+
diSubprogramAttr = getDISubprogramAttr(cast<LLVM::LLVMFuncOp>(op));
261+
} else {
262+
fuseDILocalVariable(op);
263+
}
264+
});
265+
}
266+
};
267+
268+
} // namespace mlir

lib/Target/LLVMIR/LLVMDIScope.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Pass/Pass.h"
44
#include "mlir/Support/LLVM.h"
55
#include "triton/Target/LLVMIR/Passes.h"
6+
#include "triton/Tools/Sys/GetEnv.hpp"
67
#include "llvm/BinaryFormat/Dwarf.h"
78
#include "llvm/Support/Debug.h"
89
#include "llvm/Support/Path.h"
@@ -91,7 +92,13 @@ struct LLVMDIScopePass : public impl::LLVMDIScopeBase<LLVMDIScopePass> {
9192
compileUnitAttr = LLVM::DICompileUnitAttr::get(
9293
distinctId, llvm::dwarf::DW_LANG_C, fileAttr,
9394
StringAttr::get(context, "triton"),
94-
/*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly);
95+
/*isOptimized=*/true,
96+
triton::tools::getBoolEnv("LLVM_EXTRACT_DI_LOCAL_VARIABLES")
97+
? LLVM::DIEmissionKind::Full
98+
: LLVM::DIEmissionKind::
99+
LineTablesOnly); // DIEmissionKind::Full is required by
100+
// emitting ptx with dbg-metadata
101+
// (otherwise assertion fail)
95102
}
96103
subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition;
97104
} else {

0 commit comments

Comments
 (0)