Skip to content

Commit ecc9bd4

Browse files
Revert "[frontend] Remove Complex Regex for MLIR Parsing (#4924)"
This reverts commit 038cbc5.
1 parent 2cbd321 commit ecc9bd4

File tree

4 files changed

+32
-181
lines changed

4 files changed

+32
-181
lines changed

python/src/ir.cc

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2525
#include "mlir/Transforms/LocationSnapshot.h"
2626

27-
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
2827
#include "triton/Dialect/Triton/IR/Dialect.h"
2928
#include "triton/Dialect/Triton/IR/Types.h"
3029
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -503,16 +502,6 @@ void init_triton_ir(py::module &&m) {
503502
[](ModuleOp &self, FuncOp &funcOp) -> void {
504503
self.push_back(funcOp);
505504
})
506-
.def("get_entry_func_name",
507-
[](ModuleOp &self) -> std::string {
508-
for (auto &op : self.getOps()) {
509-
if (auto func = dyn_cast<FuncOp>(op)) {
510-
if (LLVM::isKernel(func))
511-
return func.getName().str();
512-
}
513-
}
514-
return "";
515-
})
516505
.def("has_function",
517506
[](ModuleOp &self, std::string &funcName) -> bool {
518507
if (self.lookupSymbol(funcName))
@@ -523,43 +512,6 @@ void init_triton_ir(py::module &&m) {
523512
[](ModuleOp &self, std::string &funcName) -> FuncOp {
524513
return self.lookupSymbol<FuncOp>(funcName);
525514
})
526-
/*
527-
* def ty_to_cpp(ty) is the consumer of this function.
528-
* If the type is a ptr it expects ty[0] == '*', else the type itself.
529-
*/
530-
531-
.def("get_function_signature",
532-
[](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
533-
std::vector<std::string> strVec;
534-
535-
auto type = func.getFunctionType();
536-
unsigned numArgs = type.getNumInputs();
537-
for (unsigned i = 0; i != numArgs; ++i) {
538-
std::string tempType;
539-
llvm::raw_string_ostream os(tempType);
540-
541-
auto ty = type.getInput(i);
542-
if (auto attributes = func.getCallableArgAttrs()) {
543-
Attribute attr = attributes[i];
544-
// Check for tt.nv_tma_desc = 1
545-
if (auto dAttr = dyn_cast<DictionaryAttr>(attr)) {
546-
if (dAttr.contains("tt.nv_tma_desc")) {
547-
strVec.push_back("nvTmaDesc");
548-
continue;
549-
}
550-
}
551-
}
552-
if (auto ptrType = dyn_cast<PointerType>(ty)) {
553-
auto pType = ptrType.getPointeeType();
554-
os << "*";
555-
pType.print(os);
556-
} else {
557-
ty.print(os);
558-
}
559-
strVec.push_back(tempType);
560-
}
561-
return strVec;
562-
})
563515
.def("get_int_attr",
564516
[](ModuleOp &self, std::string name) -> py::object {
565517
auto ret = self->getAttrOfType<IntegerAttr>(name);

python/test/unit/tools/test_irsource.py

Lines changed: 0 additions & 93 deletions
This file was deleted.

python/triton/compiler/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
1+
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
22
from .errors import CompilationError
33

4-
__all__ = [
5-
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
6-
"LazyDict"
7-
]
4+
__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]

python/triton/compiler/compiler.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,19 @@
2525
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
2626
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
2727
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
28+
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
2829
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
2930
prototype_pattern = {
31+
"ttir": mlir_prototype_pattern,
32+
"ttgir": mlir_prototype_pattern,
3033
"ptx": ptx_prototype_pattern,
3134
}
3235

36+
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
3337
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
3438
arg_type_pattern = {
39+
"ttir": mlir_arg_type_pattern,
40+
"ttgir": mlir_arg_type_pattern,
3541
"ptx": ptx_arg_type_pattern,
3642
}
3743

@@ -49,6 +55,16 @@ def convert_type_repr(x):
4955
return x
5056

5157

58+
def _get_num_warps_from_ir_str(src: str):
59+
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
60+
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
61+
# e.g. someone has an instruction (not module) attribute named "num-warps".
62+
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
63+
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
64+
num_warps = int(num_warps_matches[0])
65+
return num_warps
66+
67+
5268
class ASTSource:
5369

5470
def __init__(self, fn, signature, constants=None, attrs=None) -> None:
@@ -91,41 +107,28 @@ def parse_options(self):
91107

92108
class IRSource:
93109

94-
def __init__(self, path, context):
110+
def __init__(self, path):
95111
self.path = path
96112
path = Path(path)
97113
self.ext = path.suffix[1:]
98114
self.src = path.read_text()
99-
ir.load_dialects(context)
100-
101-
# We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
102-
# TODO - replace with a proper parser
103-
if self.ext == "ptx":
104-
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
105-
self.name = match.group(1)
106-
signature = match.group(2)
107-
types = re.findall(arg_type_pattern[self.ext], signature)
108-
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
109-
else:
110-
self.module = ir.parse_mlir_module(self.path, context)
111-
fn_name = self.module.get_entry_func_name()
112-
self.name = "@" + fn_name
113-
funcOp = self.module.get_function(fn_name)
114-
func_ty = self.module.get_function_signature(funcOp)
115-
self.signature = {k: ty for k, ty in enumerate(func_ty)}
115+
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
116+
self.name = match.group(1)
117+
signature = match.group(2)
118+
types = re.findall(arg_type_pattern[self.ext], signature)
119+
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
116120

117121
def hash(self):
118122
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
119123

120124
def make_ir(self, options, codegen_fns, module_map, context):
121-
self.module.context = context
122-
return self.module
125+
module = ir.parse_mlir_module(self.path, context)
126+
module.context = context
127+
return module
123128

124129
def parse_options(self):
125130
if self.ext == "ttgir":
126-
num_warps = self.module.get_int_attr("triton_gpu.num-warps")
127-
assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
128-
return {'num_warps': num_warps}
131+
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
129132
return dict()
130133

131134

@@ -222,9 +225,7 @@ def compile(src, target=None, options=None):
222225
# create backend
223226
if ir_source:
224227
assert isinstance(src, str), "source must be either AST or a filepath"
225-
context = ir.context()
226-
src = IRSource(src, context)
227-
228+
src = IRSource(src)
228229
extra_options = src.parse_options()
229230
options = backend.parse_options(dict(options or dict(), **extra_options))
230231
# create cache manager
@@ -265,15 +266,9 @@ def compile(src, target=None, options=None):
265266
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
266267
if ir_source:
267268
first_stage += 1
268-
269-
if not isinstance(src, IRSource):
270-
context = ir.context()
271-
ir.load_dialects(context)
272-
backend.load_dialects(context)
273-
else:
274-
# For IRSource, we have already grabbed the context + called ir.load_dialects
275-
# just need to load the dialects for the backend.
276-
backend.load_dialects(context)
269+
context = ir.context()
270+
ir.load_dialects(context)
271+
backend.load_dialects(context)
277272
codegen_fns = backend.get_codegen_implementation()
278273
module_map = backend.get_module_map()
279274
try:

0 commit comments

Comments
 (0)