Skip to content

Commit 038cbc5

Browse files
authored
[frontend] Remove Complex Regex for MLIR Parsing (#4924)
There were a number of complex regexes used for parsing MLIR in the Python frontend. For maintainability reasons, it is likely better to just expose the MLIR bindings to Python and use those instead. The PTX regex is left in place because we don't have an easy way to parse PTX (for now).
1 parent 38a11b8 commit 038cbc5

File tree

4 files changed

+181
-32
lines changed

4 files changed

+181
-32
lines changed

python/src/ir.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2525
#include "mlir/Transforms/LocationSnapshot.h"
2626

27+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
2728
#include "triton/Dialect/Triton/IR/Dialect.h"
2829
#include "triton/Dialect/Triton/IR/Types.h"
2930
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -491,6 +492,16 @@ void init_triton_ir(py::module &&m) {
491492
[](ModuleOp &self, FuncOp &funcOp) -> void {
492493
self.push_back(funcOp);
493494
})
495+
.def("get_entry_func_name",
496+
[](ModuleOp &self) -> std::string {
497+
for (auto &op : self.getOps()) {
498+
if (auto func = dyn_cast<FuncOp>(op)) {
499+
if (LLVM::isKernel(func))
500+
return func.getName().str();
501+
}
502+
}
503+
return "";
504+
})
494505
.def("has_function",
495506
[](ModuleOp &self, std::string &funcName) -> bool {
496507
if (self.lookupSymbol(funcName))
@@ -501,6 +512,43 @@ void init_triton_ir(py::module &&m) {
501512
[](ModuleOp &self, std::string &funcName) -> FuncOp {
502513
return self.lookupSymbol<FuncOp>(funcName);
503514
})
515+
/*
516+
* def ty_to_cpp(ty) is the consumer of this function.
517+
* If the type is a ptr it expects ty[0] == '*', else the type itself.
518+
*/
519+
520+
.def("get_function_signature",
521+
[](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
522+
std::vector<std::string> strVec;
523+
524+
auto type = func.getFunctionType();
525+
unsigned numArgs = type.getNumInputs();
526+
for (unsigned i = 0; i != numArgs; ++i) {
527+
std::string tempType;
528+
llvm::raw_string_ostream os(tempType);
529+
530+
auto ty = type.getInput(i);
531+
if (auto attributes = func.getCallableArgAttrs()) {
532+
Attribute attr = attributes[i];
533+
// Check for the presence of tt.nv_tma_desc (the attribute's value is not inspected)
534+
if (auto dAttr = dyn_cast<DictionaryAttr>(attr)) {
535+
if (dAttr.contains("tt.nv_tma_desc")) {
536+
strVec.push_back("nvTmaDesc");
537+
continue;
538+
}
539+
}
540+
}
541+
if (auto ptrType = dyn_cast<PointerType>(ty)) {
542+
auto pType = ptrType.getPointeeType();
543+
os << "*";
544+
pType.print(os);
545+
} else {
546+
ty.print(os);
547+
}
548+
strVec.push_back(tempType);
549+
}
550+
return strVec;
551+
})
504552
.def("get_int_attr",
505553
[](ModuleOp &self, std::string name) -> py::object {
506554
auto ret = self->getAttrOfType<IntegerAttr>(name);
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import tempfile
2+
import triton
3+
from triton.compiler import IRSource
4+
from triton._C.libtriton import ir
5+
6+
target = triton.runtime.driver.active.get_current_target()
7+
8+
9+
def test_mlir_attribute_parsing() -> None:
10+
'''
11+
Tests that MLIR attributes are parsed correctly from input ttir/ttgir.
12+
13+
Checks for the following:
14+
1. Name and type signature are parsed correctly
15+
2. num_warps is parsed correctly by IRSource.parse_options()
16+
3. tt.nv_tma_desc attribute is parsed correctly
17+
'''
18+
19+
sample_ttgir = r"""
20+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
21+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
22+
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
23+
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
24+
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
25+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
26+
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
27+
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
28+
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
29+
%arg3: i32 {tt.divisibility = 16 : i32},
30+
%arg4: i32 {tt.divisibility = 16 : i32},
31+
%arg5: i32 {tt.divisibility = 16 : i32},
32+
%arg6: i32 {tt.divisibility = 16 : i32},
33+
%arg7: i32 {tt.divisibility = 16 : i32},
34+
%arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
35+
%desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
36+
tt.return
37+
}
38+
}
39+
"""
40+
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
41+
f.write(sample_ttgir)
42+
f.flush()
43+
context = ir.context()
44+
src = IRSource(f.name, context)
45+
46+
# check name and type signature
47+
# should match ty_to_cpp(...)
48+
assert src.signature == \
49+
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
50+
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
51+
assert src.name == "@matmul_kernel"
52+
53+
# check num warps
54+
assert src.parse_options()['num_warps'] == 8
55+
56+
sample_ttgir_vector_add = r"""
57+
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
58+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
59+
tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
60+
%arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
61+
%arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
62+
%arg3: i32 {tt.divisibility = 16 : i32})
63+
attributes {noinline = false} {
64+
%c1024_i32 = arith.constant 1024 : i32
65+
%0 = tt.get_program_id x : i32
66+
%1 = arith.muli %0, %c1024_i32 : i32
67+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
68+
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
69+
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
70+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
71+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
72+
%7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
73+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
74+
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
75+
%10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
76+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
77+
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
78+
%13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
79+
%14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
80+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
81+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
82+
tt.return
83+
}
84+
}
85+
"""
86+
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
87+
f.write(sample_ttgir_vector_add)
88+
f.flush()
89+
context = ir.context()
90+
src = IRSource(f.name, context)
91+
92+
# now test compilation
93+
triton.compile(f.name, target=target)

python/triton/compiler/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
1+
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
22
from .errors import CompilationError
33

4-
__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
4+
__all__ = [
5+
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
6+
"LazyDict"
7+
]

python/triton/compiler/compiler.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,13 @@
2525
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
2626
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
2727
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
28-
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
2928
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
3029
prototype_pattern = {
31-
"ttir": mlir_prototype_pattern,
32-
"ttgir": mlir_prototype_pattern,
3330
"ptx": ptx_prototype_pattern,
3431
}
3532

36-
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
3733
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
3834
arg_type_pattern = {
39-
"ttir": mlir_arg_type_pattern,
40-
"ttgir": mlir_arg_type_pattern,
4135
"ptx": ptx_arg_type_pattern,
4236
}
4337

@@ -55,16 +49,6 @@ def convert_type_repr(x):
5549
return x
5650

5751

58-
def _get_num_warps_from_ir_str(src: str):
59-
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
60-
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
61-
# e.g. someone has an instruction (not module) attribute named "num-warps".
62-
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
63-
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
64-
num_warps = int(num_warps_matches[0])
65-
return num_warps
66-
67-
6852
class ASTSource:
6953

7054
def __init__(self, fn, signature, constants=None, attrs=None) -> None:
@@ -107,28 +91,41 @@ def parse_options(self):
10791

10892
class IRSource:
10993

110-
def __init__(self, path):
94+
def __init__(self, path, context):
11195
self.path = path
11296
path = Path(path)
11397
self.ext = path.suffix[1:]
11498
self.src = path.read_text()
115-
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
116-
self.name = match.group(1)
117-
signature = match.group(2)
118-
types = re.findall(arg_type_pattern[self.ext], signature)
119-
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
99+
ir.load_dialects(context)
100+
101+
# We don't have an easy-to-use PTX parser, so keep that regex for now.
102+
# TODO - replace with a proper parser
103+
if self.ext == "ptx":
104+
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
105+
self.name = match.group(1)
106+
signature = match.group(2)
107+
types = re.findall(arg_type_pattern[self.ext], signature)
108+
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
109+
else:
110+
self.module = ir.parse_mlir_module(self.path, context)
111+
fn_name = self.module.get_entry_func_name()
112+
self.name = "@" + fn_name
113+
funcOp = self.module.get_function(fn_name)
114+
func_ty = self.module.get_function_signature(funcOp)
115+
self.signature = {k: ty for k, ty in enumerate(func_ty)}
120116

121117
def hash(self):
122118
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
123119

124120
def make_ir(self, options, codegen_fns, module_map, context):
125-
module = ir.parse_mlir_module(self.path, context)
126-
module.context = context
127-
return module
121+
self.module.context = context
122+
return self.module
128123

129124
def parse_options(self):
130125
if self.ext == "ttgir":
131-
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
126+
num_warps = self.module.get_int_attr("triton_gpu.num-warps")
127+
assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
128+
return {'num_warps': num_warps}
132129
return dict()
133130

134131

@@ -225,7 +222,9 @@ def compile(src, target=None, options=None):
225222
# create backend
226223
if ir_source:
227224
assert isinstance(src, str), "source must be either AST or a filepath"
228-
src = IRSource(src)
225+
context = ir.context()
226+
src = IRSource(src, context)
227+
229228
extra_options = src.parse_options()
230229
options = backend.parse_options(dict(options or dict(), **extra_options))
231230
# create cache manager
@@ -266,9 +265,15 @@ def compile(src, target=None, options=None):
266265
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
267266
if ir_source:
268267
first_stage += 1
269-
context = ir.context()
270-
ir.load_dialects(context)
271-
backend.load_dialects(context)
268+
269+
if not isinstance(src, IRSource):
270+
context = ir.context()
271+
ir.load_dialects(context)
272+
backend.load_dialects(context)
273+
else:
274+
# For IRSource, we have already grabbed the context + called ir.load_dialects
275+
# so we just need to load the dialects for the backend.
276+
backend.load_dialects(context)
272277
codegen_fns = backend.get_codegen_implementation()
273278
module_map = backend.get_module_map()
274279
try:

0 commit comments

Comments
 (0)