11"""
2+ Mixed precision tests for matmul (tl.dot) with cast (tl.to)
3+
24issue: https://github.com/triton-lang/triton/issues/2523
3- fused type convert and matmul, base on triton matmul, the different with matmul:
4- 1. force C's dtype=dot_out_dtype to ["float16", "float32"]
5- 2. accept A and B with dtype=["float32", "float64"]
65
6+ TODO: float8 types
77"""
8+
89import pytest
910import torch
1011
12+ import triton
1113import triton .language as tl
12- from triton import cdiv , jit
1314
14- input_dtypes = ["float32" , "float64" ]
15+ input_dtypes = ["float16" , " float32" , "float64" ]
1516out_dtypes = ["float16" , "float32" ]
1617
1718
+@triton.jit
+def matmul_kernel(A, B, C, M, N, K,  #
+                  stride_am, stride_ak,  #
+                  stride_bk, stride_bn,  #
+                  stride_cm, stride_cn,  #
+                  dot_out_dtype: tl.constexpr,  #
+                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
+                  BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
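+    # A and B are the input operands (their element types may differ); C is
+    # the output buffer whose element type is the conversion target;
+    # dot_out_dtype is the accumulator precision used by tl.dot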
+    # matrix multiplication
+    pid = tl.program_id(0)
+    grid_m = tl.cdiv(M, BLOCK_M)
+    grid_n = tl.cdiv(N, BLOCK_N)
+    # re-order program ID for better L2 performance
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
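+    # consecutive program IDs now cover a GROUP_M-tall column of output tiles,
+    # so they reuse the same B tiles while those are still resident in L2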
+    # do matrix multiplication
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
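+    # tl.multiple_of/tl.max_contiguous are compiler hints: they promise that
+    # these index ranges are aligned and contiguous, enabling vectorized loads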
+    rk = tl.arange(0, BLOCK_K)
+    # pointers
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        k_remaining = K - k * BLOCK_K
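+        # pad the K-tail with zeros so masked-off elements do not contribute
+        # to the dot product; the padding value is created in C's element type
+        # to match the cast applied below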
+        _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+        a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+        b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
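+        # the cast under test: convert both operands to the output element
+        # type before tl.dot, fusing the type conversion into the matmul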
+        a = a.to(C.dtype.element_ty)
+        b = b.to(C.dtype.element_ty)
+        acc += tl.dot(a, b, out_dtype=dot_out_dtype)
+        A += BLOCK_K * stride_ak
+        B += BLOCK_K * stride_bk
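+    # accumulation ran in dot_out_dtype; convert once at the end for the store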
+    acc = acc.to(C.dtype.element_ty)
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    tl.store(C, acc, mask=mask)
+
+
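+# The test below checks the kernel against a PyTorch reference; a sketch of
+# the expected comparison (assuming the elided setup casts first and then
+# multiplies, not the verbatim test code):
+#   out_torch = torch.matmul(a.to(out_dtype), b.to(out_dtype))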
 @pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype",
                          [(M, K, N, w, x, o)  #
                           for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)]  #
                           for w in input_dtypes  #
                           for x in input_dtypes  #
                           for o in out_dtypes])
 def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
-        pytest.skip("skip same dtype")
+        pytest.skip("skip the same input dtype")
     device = torch.cuda.current_device()
     x_dtype = getattr(torch, x_dtype)
     w_dtype = getattr(torch, w_dtype)
@@ -36,53 +84,7 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype):
 
     # launch kernel
     BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
-    grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1)
-
-    @jit
-    def matmul_kernel(A, B, C, M, N, K,  #
-                      stride_am, stride_ak,  #
-                      stride_bk, stride_bn,  #
-                      stride_cm, stride_cn,  #
-                      dot_out_dtype: tl.constexpr,  #
-                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,  #
-                      BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr):
-        # matrix multiplication
-        pid = tl.program_id(0)
-        grid_m = tl.cdiv(M, BLOCK_M)
-        grid_n = tl.cdiv(N, BLOCK_N)
-        # re-order program ID for better L2 performance
-        width = GROUP_M * grid_n
-        group_id = pid // width
-        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
-        pid_m = group_id * GROUP_M + (pid % group_size)
-        pid_n = (pid % width) // (group_size)
-        # do matrix multiplication
-        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
-        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
-        rk = tl.arange(0, BLOCK_K)
-        # pointers
-        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
-        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
-        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
-        for k in range(0, tl.cdiv(K, BLOCK_K)):
-            k_remaining = K - k * BLOCK_K
-            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
-            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
-            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
-            a = a.to(C.dtype.element_ty)
-            b = b.to(C.dtype.element_ty)
-            acc += tl.dot(a, b, out_dtype=dot_out_dtype)
-            A += BLOCK_K * stride_ak
-            B += BLOCK_K * stride_bk
-        acc = acc.to(C.dtype.element_ty)
-        # rematerialize rm and rn to save registers
-        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
-        mask = (rm < M)[:, None] & (rn < N)[None, :]
-        tl.store(C, acc, mask=mask)
+    grid = ((triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), 1)
 
     matmul_kernel[grid](
         a, b, out_triton, M, N, K,  #