Skip to content

Commit 80947a2

Browse files
authored
[tools/triton-tensor-layout] Allow parsing ttgir files with triton_nvidia_gpu ops (#4686)
If you want to dump layouts read from an MLIR file, and that file contains ops like `triton_nvidia_gpu.warp_group_dot`, this tool needs to know about the `triton_nvidia_gpu` dialect, or else it will throw an error about not finding the dialect
1 parent 1e88441 commit 80947a2

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,9 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
9595
add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
9696
target_link_libraries(triton-tensor-layout PRIVATE
9797
TritonGPUIR
98+
TritonNvidiaGPUIR
9899
${triton_libs}
100+
${conversion_libs}
101+
${dialect_libs}
102+
TritonTestAnalysis
99103
)

bin/triton-tensor-layout.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
#include "RegisterTritonDialects.h"
2+
13
#include "mlir/AsmParser/AsmParser.h"
24
#include "mlir/AsmParser/AsmParserState.h"
35
#include "mlir/IR/MLIRContext.h"
46

57
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
8+
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
69

710
#include "llvm/Support/CommandLine.h"
811
#include "llvm/Support/ErrorOr.h"
@@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
114117
return failure();
115118
}
116119

117-
auto printLambda = [&](StringRef name, Attribute attr) {
120+
auto printLambda = [&](StringRef name, mlir::Attribute attr) {
118121
ss << "Print layout attribute: #" << name << " = " << attr << "\n";
119122

120123
auto rankedTensorTy = RankedTensorType::get(
@@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context,
155158
if (layoutAttrStr.empty())
156159
return success();
157160

158-
Attribute layout = parseAttribute(layoutAttrStr, context);
161+
mlir::Attribute layout = parseAttribute(layoutAttrStr, context);
159162
if (!layout) {
160163
llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
161164
return failure();
@@ -178,8 +181,7 @@ int main(int argc, char **argv) {
178181
cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");
179182

180183
DialectRegistry registry;
181-
// Register all dialects that can print tensor layout.
182-
registry.insert<triton::gpu::TritonGPUDialect>();
184+
registerTritonDialects(registry);
183185

184186
MLIRContext ctx(registry);
185187
ctx.loadAllAvailableDialects();
@@ -189,7 +191,7 @@ int main(int argc, char **argv) {
189191
return 1;
190192
}
191193

192-
Type parsedTy = parseType(TensorStr, &ctx);
194+
mlir::Type parsedTy = parseType(TensorStr, &ctx);
193195
if (!parsedTy) {
194196
llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
195197
<< "\n";

0 commit comments

Comments
 (0)