Skip to content

Commit e24b0d8

Browse files
Dark Knightmeta-codesync[bot]
authored andcommitted
Revert D85686892 (#530)
Summary: Pull Request resolved: #530 This diff reverts D85686892 /data/users/daohang/fbtriton/third_party/tlx/tutorials/blackwell-gemm-ws.py:116:44: error: Result has an invalid layout: Could not determine the number of warps per CTA. Operation is not in a context with `ttg.num-warps`. result = tlx.local_load(acc_tmem_subslice1) Depends on D85686892 Reviewed By: dshi7 Differential Revision: D85779224 fbshipit-source-id: 464f5efe297067606e21d3b54a41c5aade777ae1
1 parent 774fe01 commit e24b0d8

File tree

2 files changed

+57
-69
lines changed

2 files changed

+57
-69
lines changed

python/src/ir.cc

Lines changed: 54 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,6 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() {
8888
}
8989
}
9090

91-
// Function to parse a comma-separated string into a vector of C-style strings
92-
llvm::SmallVector<const char *, 3>
93-
parseCommaSeparatedValues(const std::string &input,
94-
llvm::SmallVector<std::string, 3> &storage) {
95-
llvm::SmallVector<StringRef, 3> split;
96-
llvm::SmallVector<const char *, 3> result;
97-
StringRef(input.c_str()).split(split, ',');
98-
llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
99-
// StringRefs are not always null-terminated.
100-
// The purpose for this storage pattern is to
101-
// produce a collection of C-strings that are.
102-
storage.push_back(str.str());
103-
return storage.back().c_str();
104-
});
105-
return result;
106-
}
107-
10891
// Run the pass manager under a source manager diagnostic handler, which
10992
// enables emitted MLIR diagnostics to directly reference Python source
11093
// code. This diagnostic handler supports filtering diagnostic info by
@@ -141,43 +124,6 @@ struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
141124
llvm::SourceMgr sourceMgr;
142125
};
143126

144-
TritonSourceMgrDiagnosticHandler
145-
setupTritonDiagnosticHandler(MLIRContext *context) {
146-
bool showOperations = false, showStacktraces = false, showRemarks = false,
147-
showWarnings = false;
148-
149-
if (auto enableDiagnostics =
150-
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
151-
!enableDiagnostics.empty()) {
152-
llvm::SmallVector<std::string, 3> storage;
153-
parseCommaSeparatedValues(enableDiagnostics, storage);
154-
for (auto &str : storage) {
155-
if (str == "warnings") {
156-
showWarnings = true;
157-
} else if (str == "remarks") {
158-
showRemarks = true;
159-
} else if (str == "stacktraces") {
160-
showStacktraces = true;
161-
} else if (str == "operations") {
162-
showOperations = true;
163-
}
164-
// we show errors by default, so no need to set it
165-
}
166-
}
167-
168-
DiagnosticSeverity minSeverity =
169-
showWarnings ? DiagnosticSeverity::Warning : DiagnosticSeverity::Error;
170-
minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;
171-
172-
context->printOpOnDiagnostic(showOperations);
173-
context->printStackTraceOnDiagnostic(showStacktraces);
174-
if (showStacktraces) {
175-
context->disableMultithreading();
176-
}
177-
178-
return TritonSourceMgrDiagnosticHandler(context, minSeverity);
179-
}
180-
181127
std::string locationToString(Location loc) {
182128
std::string str;
183129
llvm::raw_string_ostream os(str);
@@ -186,6 +132,23 @@ std::string locationToString(Location loc) {
186132
return str;
187133
}
188134

135+
// Function to parse a comma-separated string into a vector of C-style strings
136+
llvm::SmallVector<const char *, 3>
137+
parseCommaSeparatedValues(const std::string &input,
138+
llvm::SmallVector<std::string, 3> &storage) {
139+
llvm::SmallVector<StringRef, 3> split;
140+
llvm::SmallVector<const char *, 3> result;
141+
StringRef(input.c_str()).split(split, ',');
142+
llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
143+
// StringRefs are not always null-terminated.
144+
// The purpose for this storage pattern is to
145+
// produce a collection of C-strings that are.
146+
storage.push_back(str.str());
147+
return storage.back().c_str();
148+
});
149+
return result;
150+
}
151+
189152
void outputWarning(Location loc, const std::string &msg) {
190153
std::string locStr = locationToString(loc);
191154

@@ -731,12 +694,7 @@ void init_triton_ir(py::module &&m) {
731694
.def("walk",
732695
[](ModuleOp &self, const std::function<void(Operation *)> &fn) {
733696
self.walk(fn);
734-
})
735-
.def("verify_with_diagnostics", [](ModuleOp &self) {
736-
TritonSourceMgrDiagnosticHandler handler =
737-
setupTritonDiagnosticHandler(self.getContext());
738-
return succeeded(verify(self.getOperation()));
739-
});
697+
});
740698

741699
m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
742700
return mlir::cast<Attribute>(DenseIntElementsAttr::get(
@@ -1965,8 +1923,42 @@ void init_triton_ir(py::module &&m) {
19651923
self.enableTiming();
19661924
}
19671925

1968-
TritonSourceMgrDiagnosticHandler diagHandler =
1969-
setupTritonDiagnosticHandler(context);
1926+
// setting up diagnostics
1927+
bool showOperations = false, showStacktraces = false,
1928+
showRemarks = false, showWarnings = false;
1929+
1930+
if (auto enableDiagnostics =
1931+
triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
1932+
!enableDiagnostics.empty()) {
1933+
llvm::SmallVector<std::string, 3> storage;
1934+
parseCommaSeparatedValues(enableDiagnostics, storage);
1935+
for (auto &str : storage) {
1936+
if (str == "warnings") {
1937+
showWarnings = true;
1938+
} else if (str == "remarks") {
1939+
showRemarks = true;
1940+
} else if (str == "stacktraces") {
1941+
showStacktraces = true;
1942+
} else if (str == "operations") {
1943+
showOperations = true;
1944+
}
1945+
// we show errors by default, so no need to set it
1946+
}
1947+
}
1948+
1949+
DiagnosticSeverity minSeverity = showWarnings
1950+
? DiagnosticSeverity::Warning
1951+
: DiagnosticSeverity::Error;
1952+
minSeverity =
1953+
showRemarks ? DiagnosticSeverity::Remark : minSeverity;
1954+
1955+
TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity);
1956+
1957+
context->printOpOnDiagnostic(showOperations);
1958+
context->printStackTraceOnDiagnostic(showStacktraces);
1959+
if (showStacktraces) {
1960+
context->disableMultithreading();
1961+
}
19701962
if (failed(self.run(mod.getOperation())))
19711963
throw std::runtime_error("PassManager::run failed");
19721964
},

python/triton/compiler/code_generator.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,11 +1602,7 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None)
16021602
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
16031603
codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
16041604
generator.visit(fn.parse())
1605-
module = generator.module
1605+
ret = generator.module
16061606
# module takes ownership of the context
1607-
module.context = context
1608-
if not module.verify_with_diagnostics():
1609-
if not fn.is_gluon():
1610-
print(module)
1611-
raise RuntimeError("error encountered during parsing")
1612-
return module
1607+
ret.context = context
1608+
return ret

0 commit comments

Comments
 (0)