@@ -41,11 +41,7 @@ def run_filecheck(name, module_str, check_template):
41
41
raise ValueError (matcher .stderr .getvalue ())
42
42
43
43
44
- def run_filecheck_test (kernel_fn ):
45
- assert isinstance (kernel_fn , triton .runtime .JITFunction )
46
- check_template = inspect .getsource (kernel_fn .fn )
47
- if check_template is None :
48
- raise ValueError ("kernel function must have a docstring with FileCheck template" )
44
+ def run_parser (kernel_fn ):
49
45
sigkeys = [x .name for x in kernel_fn .params ]
50
46
sigvals = [f"arg{ i } " for i in range (len (sigkeys ))]
51
47
signature = {k : v for (k , v ) in zip (sigkeys , sigvals )}
@@ -59,7 +55,15 @@ def run_filecheck_test(kernel_fn):
59
55
options = stub_backend .parse_options (dict (** extra_options ))
60
56
codegen_fns = stub_backend .get_codegen_implementation (options )
61
57
module_map = stub_backend .get_module_map ()
62
- mlir_module = src .make_ir (options , codegen_fns , module_map , context )
58
+ return src .make_ir (options , codegen_fns , module_map , context )
59
+
60
+
61
+ def run_filecheck_test (kernel_fn ):
62
+ assert isinstance (kernel_fn , triton .runtime .JITFunction )
63
+ check_template = inspect .getsource (kernel_fn .fn )
64
+ if check_template is None :
65
+ raise ValueError ("kernel function must have a docstring with FileCheck template" )
66
+ mlir_module = run_parser (kernel_fn )
63
67
64
68
run_filecheck ("placeholder" , str (mlir_module ), check_template )
65
69
@@ -142,6 +146,17 @@ def _flatten_ir(self, handles: List[ir.value]) -> None:
142
146
self .first ._flatten_ir (handles )
143
147
self .second ._flatten_ir (handles )
144
148
149
+ @triton .jit
150
+ def get_first (self ):
151
+ return self .first
152
+
153
+ def get_second (self , _builder = None ):
154
+ return self .second
155
+
156
+ @triton .jit
157
+ def unpack (self ):
158
+ return self .get_first (), self .get_second ()
159
+
145
160
146
161
@tl .core .builtin
147
162
def pair_value_ctor (first , second , _builder = None ):
@@ -160,3 +175,19 @@ def test_assign_attribute():
160
175
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32)
161
176
pair .second = 42
162
177
anchor (pair )
178
+
179
+
180
+ @filecheck_test
181
+ @triton .jit
182
+ def test_jit_method ():
183
+ # CHECK-LABEL: test_jit_method
184
+ # CHECK: %c11_i32 = arith.constant 11 : i32
185
+ # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
186
+ scalar = 11
187
+ # CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32)
188
+ pair = pair_value_ctor (tl .arange (0 , 4 ), scalar )
189
+ a , b = pair .unpack ()
190
+ # CHECK: call @anchor{{.*}}([[V]]#0)
191
+ anchor (a )
192
+ # CHECK: call @anchor{{.*}}([[V]]#1)
193
+ anchor (b )
0 commit comments