2525# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
2626# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
2727# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
28+ mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
2829ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
2930prototype_pattern = {
31+ "ttir" : mlir_prototype_pattern ,
32+ "ttgir" : mlir_prototype_pattern ,
3033 "ptx" : ptx_prototype_pattern ,
3134}
3235
36+ mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
3337ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
3438arg_type_pattern = {
39+ "ttir" : mlir_arg_type_pattern ,
40+ "ttgir" : mlir_arg_type_pattern ,
3541 "ptx" : ptx_arg_type_pattern ,
3642}
3743
@@ -49,6 +55,16 @@ def convert_type_repr(x):
4955 return x
5056
5157
58+ def _get_num_warps_from_ir_str (src : str ):
59+ ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
60+ # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
61+ # e.g. someone has an instruction (not module) attribute named "num-warps".
62+ num_warps_matches = re .findall (ttgir_num_warps_pattern , src )
63+ assert len (num_warps_matches ) == 1 , "Expected exactly one match for num_warps"
64+ num_warps = int (num_warps_matches [0 ])
65+ return num_warps
66+
67+
5268class ASTSource :
5369
5470 def __init__ (self , fn , signature , constants = None , attrs = None ) -> None :
@@ -91,41 +107,28 @@ def parse_options(self):
91107
92108class IRSource :
93109
94- def __init__ (self , path , context ):
110+ def __init__ (self , path ):
95111 self .path = path
96112 path = Path (path )
97113 self .ext = path .suffix [1 :]
98114 self .src = path .read_text ()
99- ir .load_dialects (context )
100-
101- # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
102- # TODO - replace with a proper parser
103- if self .ext == "ptx" :
104- match = re .search (prototype_pattern [self .ext ], self .src , re .MULTILINE )
105- self .name = match .group (1 )
106- signature = match .group (2 )
107- types = re .findall (arg_type_pattern [self .ext ], signature )
108- self .signature = {k : convert_type_repr (ty ) for k , ty in enumerate (types )}
109- else :
110- self .module = ir .parse_mlir_module (self .path , context )
111- fn_name = self .module .get_entry_func_name ()
112- self .name = "@" + fn_name
113- funcOp = self .module .get_function (fn_name )
114- func_ty = self .module .get_function_signature (funcOp )
115- self .signature = {k : ty for k , ty in enumerate (func_ty )}
115+ match = re .search (prototype_pattern [self .ext ], self .src , re .MULTILINE )
116+ self .name = match .group (1 )
117+ signature = match .group (2 )
118+ types = re .findall (arg_type_pattern [self .ext ], signature )
119+ self .signature = {k : convert_type_repr (ty ) for k , ty in enumerate (types )}
116120
117121 def hash (self ):
118122 return hashlib .sha256 (self .src .encode ("utf-8" )).hexdigest ()
119123
120124 def make_ir (self , options , codegen_fns , module_map , context ):
121- self .module .context = context
122- return self .module
125+ module = ir .parse_mlir_module (self .path , context )
126+ module .context = context
127+ return module
123128
124129 def parse_options (self ):
125130 if self .ext == "ttgir" :
126- num_warps = self .module .get_int_attr ("triton_gpu.num-warps" )
127- assert num_warps is not None , "Unable to parse triton_gpu.num-warps attribute"
128- return {'num_warps' : num_warps }
131+ return {'num_warps' : _get_num_warps_from_ir_str (self .src )}
129132 return dict ()
130133
131134
@@ -222,9 +225,7 @@ def compile(src, target=None, options=None):
222225 # create backend
223226 if ir_source :
224227 assert isinstance (src , str ), "source must be either AST or a filepath"
225- context = ir .context ()
226- src = IRSource (src , context )
227-
228+ src = IRSource (src )
228229 extra_options = src .parse_options ()
229230 options = backend .parse_options (dict (options or dict (), ** extra_options ))
230231 # create cache manager
@@ -265,15 +266,9 @@ def compile(src, target=None, options=None):
265266 # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
266267 if ir_source :
267268 first_stage += 1
268-
269- if not isinstance (src , IRSource ):
270- context = ir .context ()
271- ir .load_dialects (context )
272- backend .load_dialects (context )
273- else :
274- # For IRSource, we have already grabbed the context + called ir.load_dialects
275- # just need to load the dialects for the backend.
276- backend .load_dialects (context )
269+ context = ir .context ()
270+ ir .load_dialects (context )
271+ backend .load_dialects (context )
277272 codegen_fns = backend .get_codegen_implementation ()
278273 module_map = backend .get_module_map ()
279274 try :
0 commit comments