66
77@triton .jit
88def _fwd_kernel_destindex_copy_kv (
9- K , Dest_loc ,
9+ K ,
10+ Dest_loc ,
1011 Out ,
11- stride_k_bs , stride_k_h , stride_k_d ,
12- stride_o_bs , stride_o_h , stride_o_d ,
12+ stride_k_bs ,
13+ stride_k_h ,
14+ stride_k_d ,
15+ stride_o_bs ,
16+ stride_o_h ,
17+ stride_o_d ,
1318 head_num ,
1419 BLOCK_DMODEL : tl .constexpr ,
15- BLOCK_HEAD : tl .constexpr
20+ BLOCK_HEAD : tl .constexpr ,
1621):
1722 cur_index = tl .program_id (0 )
1823 offs_h = tl .arange (0 , BLOCK_HEAD )
1924 offs_d = tl .arange (0 , BLOCK_DMODEL )
2025
21- dest_index = tl .load (Dest_loc + cur_index )
26+ dest_index = tl .load (Dest_loc + cur_index ). to ( tl . int64 )
2227
2328 k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h [:, None ] + stride_k_d * offs_d [None , :]
2429 o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h [:, None ] + stride_o_d * offs_d [None , :]
@@ -39,9 +44,15 @@ def destindex_copy_kv(K, DestLoc, Out):
3944 num_warps = 1
4045
4146 _fwd_kernel_destindex_copy_kv [grid ](
42- K , DestLoc , Out ,
43- K .stride (0 ), K .stride (1 ), K .stride (2 ),
44- Out .stride (0 ), Out .stride (1 ), Out .stride (2 ),
47+ K ,
48+ DestLoc ,
49+ Out ,
50+ K .stride (0 ),
51+ K .stride (1 ),
52+ K .stride (2 ),
53+ Out .stride (0 ),
54+ Out .stride (1 ),
55+ Out .stride (2 ),
4556 head_num ,
4657 BLOCK_DMODEL = head_dim ,
4758 BLOCK_HEAD = BLOCK_HEAD ,
@@ -53,23 +64,35 @@ def destindex_copy_kv(K, DestLoc, Out):
5364
5465@triton .jit
5566def _fwd_kernel_destindex_copy_quantize_kv (
56- K , Dest_loc , Out , Out_scale ,
57- stride_k_bs , stride_k_h , stride_k_d ,
58- stride_o_bs , stride_o_h , stride_o_d ,
59- stride_os_bs , stride_os_h , stride_os_d ,
67+ K ,
68+ Dest_loc ,
69+ Out ,
70+ Out_scale ,
71+ stride_k_bs ,
72+ stride_k_h ,
73+ stride_k_d ,
74+ stride_o_bs ,
75+ stride_o_h ,
76+ stride_o_d ,
77+ stride_os_bs ,
78+ stride_os_h ,
79+ stride_os_d ,
6080 head_num ,
6181 BLOCK_DMODEL : tl .constexpr ,
62- BLOCK_HEAD : tl .constexpr
82+ BLOCK_HEAD : tl .constexpr ,
6383):
6484 cur_index = tl .program_id (0 )
6585 offs_h = tl .arange (0 , BLOCK_HEAD )
6686 offs_d = tl .arange (0 , BLOCK_DMODEL )
6787
68- dest_index = tl .load (Dest_loc + cur_index )
69- src_data = tl .load (K + cur_index * stride_k_bs + offs_h [:, None ] * stride_k_h + stride_k_d * offs_d [None , :],
70- mask = offs_h [:, None ] < head_num , other = 0.0 )
88+ dest_index = tl .load (Dest_loc + cur_index ).to (tl .int64 )
89+ src_data = tl .load (
90+ K + cur_index * stride_k_bs + offs_h [:, None ] * stride_k_h + stride_k_d * offs_d [None , :],
91+ mask = offs_h [:, None ] < head_num ,
92+ other = 0.0 ,
93+ )
7194 abs_data = tl .abs (src_data )
72- data_scale = (tl .max (abs_data , axis = 1 ) / 127. ).to (Out_scale .dtype .element_ty )[:, None ]
95+ data_scale = (tl .max (abs_data , axis = 1 ) / 127.0 ).to (Out_scale .dtype .element_ty )[:, None ]
7396 q_src_data = (src_data / data_scale ).to (tl .int8 )
7497 o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h [:, None ] + stride_o_d * offs_d [None , :]
7598 os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h [:, None ]
@@ -88,10 +111,19 @@ def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
88111 num_warps = 1
89112
90113 _fwd_kernel_destindex_copy_quantize_kv [grid ](
91- K , DestLoc , Out , Out_scale ,
92- K .stride (0 ), K .stride (1 ), K .stride (2 ),
93- Out .stride (0 ), Out .stride (1 ), Out .stride (2 ),
94- Out_scale .stride (0 ), Out_scale .stride (1 ), Out_scale .stride (2 ),
114+ K ,
115+ DestLoc ,
116+ Out ,
117+ Out_scale ,
118+ K .stride (0 ),
119+ K .stride (1 ),
120+ K .stride (2 ),
121+ Out .stride (0 ),
122+ Out .stride (1 ),
123+ Out .stride (2 ),
124+ Out_scale .stride (0 ),
125+ Out_scale .stride (1 ),
126+ Out_scale .stride (2 ),
95127 head_num ,
96128 BLOCK_DMODEL = head_dim ,
97129 BLOCK_HEAD = BLOCK_HEAD ,
@@ -149,6 +181,6 @@ def test2():
149181 print ("cos " , cos (src .flatten ().to (torch .float32 ), (value_dest * scale_dest ).flatten ().to (torch .float32 )))
150182
151183
152- if __name__ == ' __main__' :
184+ if __name__ == " __main__" :
153185 test1 ()
154186 test2 ()
0 commit comments