1- import triton
2- import triton .language as tl
31import os
2+ from contextlib import contextmanager
3+
44import pytest
55import torch
6+ import triton
7+ import triton .language as tl
8+
9+
@contextmanager
def enable_remark_context():
    """Temporarily enable MLIR performance remarks.

    Sets MLIR_ENABLE_REMARK=1 for the duration of the `with` block and
    restores the variable's previous state on exit — including removing it
    entirely if it was unset — rather than unconditionally forcing it to
    "0", which would leak a modified environment into later tests.
    """
    previous = os.environ.get("MLIR_ENABLE_REMARK")
    os.environ["MLIR_ENABLE_REMARK"] = "1"
    try:
        yield
    finally:
        # Restore, don't clobber: leave the environment exactly as found.
        if previous is None:
            os.environ.pop("MLIR_ENABLE_REMARK", None)
        else:
            os.environ["MLIR_ENABLE_REMARK"] = previous
617
718
def is_perf_warning_enabled():
    """Return True when MLIR performance remarks are enabled via the env.

    Checks the same MLIR_ENABLE_REMARK key the rest of this file sets; the
    previous lookup key carried a stray leading space (" MLIR_ENABLE_REMARK")
    and so could never match.
    """
    return os.environ.get("MLIR_ENABLE_REMARK", "0") == "1"
1021
1122
def is_cuda():
    """Return True when the active Triton driver targets the CUDA backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
1425
1526
def test_mma_remark(capfd, fresh_triton_cache):
    """Compile a dot kernel whose block shapes prevent MMA V3 and assert the
    compiler surfaces the corresponding performance remark on stderr.

    Remarks are gated on MLIR_ENABLE_REMARK, toggled here via
    enable_remark_context() so the environment is not left modified.
    """
    if is_cuda():
        capability = torch.cuda.get_device_capability()
        # The MMA V3 remark is only produced on Hopper-class hardware.
        if capability[0] < 9:
            pytest.skip("Requires sm >= 90 to run")

    @triton.jit
    def matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
    ):
        a_block_ptr = tl.make_block_ptr(
            base=a_ptr,
            shape=(M, K),
            strides=(stride_am, stride_ak),
            offsets=(0, 0),
            block_shape=(32, 128),
            order=(1, 0),
        )
        # Note the column-major order on B: this layout is what defeats MMA V3.
        b_block_ptr = tl.make_block_ptr(
            base=b_ptr,
            shape=(K, N),
            strides=(stride_bk, stride_bn),
            offsets=(0, 0),
            block_shape=(128, 32),
            order=(0, 1),
        )
        c_block_ptr = tl.make_block_ptr(
            base=c_ptr,
            shape=(M, N),
            strides=(stride_cm, stride_cn),
            offsets=(0, 0),
            block_shape=(32, 32),
            order=(1, 0),
        )
        a = tl.load(a_block_ptr)
        b = tl.load(b_block_ptr)
        c = tl.dot(a, b)
        tl.store(c_block_ptr, c)

    with enable_remark_context():
        triton.compile(
            triton.compiler.ASTSource(
                fn=matmul_kernel,
                signature={
                    "a_ptr": "*fp32",
                    "b_ptr": "*fp32",
                    "c_ptr": "*fp32",
                    "M": "i32",
                    "N": "i32",
                    "K": "i32",
                    "stride_am": "i32",
                    "stride_ak": "i32",
                    "stride_bk": "i32",
                    "stride_bn": "i32",
                    "stride_cm": "i32",
                    "stride_cn": "i32",
                },
                constants={},
            ))
    captured = capfd.readouterr()

    assert ("remark: Warning: can't use MMA V3 for the dot op" in captured.err), "expect MMA V3 remark"
    assert "note: see current operation:" in captured.err
49101
50102
51- def test_remark_vectorization (capfd ):
52- os .environ ["MLIR_ENABLE_REMARK" ] = "1"
103+ def test_remark_vectorization (capfd , fresh_triton_cache ):
53104
54105 @triton .jit
55106 def ldst_vec (in_ptr0 , in_ptr1 , in_ptr2 , in_ptr3 , out_ptr0 , XBLOCK : tl .constexpr ):
@@ -75,12 +126,52 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr)
75126 tl .store (out_ptr0 + (x4 ), tmp22 , None )
76127
77128 XBLOCK = 1024
78- triton .compile (
79- triton .compiler .ASTSource (
80- fn = ldst_vec , signature = {
81- 'in_ptr0' : '*i64' , 'in_ptr1' : '*i64' , 'in_ptr2' : '*fp16' , 'in_ptr3' : '*fp32' , 'out_ptr0' : '*fp16'
82- }, constants = {"XBLOCK" : XBLOCK }), options = {"num_warps" : 1 })
129+ with enable_remark_context ():
130+ triton .compile (
131+ triton .compiler .ASTSource (
132+ fn = ldst_vec ,
133+ signature = {
134+ "in_ptr0" : "*i64" ,
135+ "in_ptr1" : "*i64" ,
136+ "in_ptr2" : "*fp16" ,
137+ "in_ptr3" : "*fp32" ,
138+ "out_ptr0" : "*fp16" ,
139+ },
140+ constants = {"XBLOCK" : XBLOCK },
141+ ),
142+ options = {"num_warps" : 1 },
143+ )
83144
84145 _ , err = capfd .readouterr ()
85146 assert ("remark: Warning: vectorization fails" in err ), "expect vectorization failure remark"
86- os .environ ["MLIR_ENABLE_REMARK" ] = "0"
147+
148+
def test_remark_swp_op_before_operands(capfd, fresh_triton_cache):
    """Compile a pipelined loop whose load address depends on a value computed
    later in the loop body, and assert the software-pipeliner emits its
    'operation scheduled before its operands' remark on stderr.
    """

    @triton.jit
    def kernel_pipe_error(in_ptr, out_ptr):
        SIZE: tl.constexpr = 64
        in_ptrs = in_ptr + tl.arange(0, SIZE)
        val = tl.zeros((SIZE, ), dtype=tl.float32)
        k = 0
        # num_stages=3 requests pipelining; the data-dependent update of `k`
        # below is what forces a load to be scheduled before its operands.
        for i in tl.range(0, 64, num_stages=3):
            in_ptrs = in_ptr + tl.arange(0, SIZE) + SIZE * k
            val = tl.load(in_ptrs)
            out_ptrs = out_ptr + (tl.arange(0, SIZE) + i * SIZE)
            tl.store(out_ptrs, val)
            if tl.max(val) > 0:
                k += 1

    # Remarks are only emitted while MLIR_ENABLE_REMARK is set.
    with enable_remark_context():
        src = triton.compiler.ASTSource(
            fn=kernel_pipe_error,
            signature={"in_ptr": "*fp32", "out_ptr": "*fp32"},
            constants={},
        )
        triton.compile(src, options={"cluster_dims": (1, 1, 1)})

    _, err = capfd.readouterr()

    assert "operation scheduled before its operands" in err, "expect swp op remark"