66
77logger = init_logger (__name__ )
88
9+
910@triton .jit
1011def _page_io (
1112 mem_index_ptr ,
@@ -43,7 +44,7 @@ def _page_io(
4344 v_stride_layer_num = tl .cast (v_stride_layer_num , dtype = tl .int64 )
4445 k_stride_size = tl .cast (k_stride_size , dtype = tl .int64 )
4546 v_stride_size = tl .cast (v_stride_size , dtype = tl .int64 )
46-
47+
4748 tid = tl .program_id (0 )
4849 kv_head_id = tl .program_id (1 )
4950 page_head_id = page_head_start + kv_head_id
@@ -57,18 +58,86 @@ def _page_io(
5758
5859 for layer_index in tl .range (layer_num , num_stages = 3 ):
5960 if IS_WRITE :
60- k_tensor = tl .load (k_ptr + layer_index * k_stride_layer_num + mem_index * k_stride_size + kv_head_id * k_stride_head + off_dim * k_stride_dim , mask = mask )
61- v_tensor = tl .load (v_ptr + layer_index * v_stride_layer_num + mem_index * v_stride_size + kv_head_id * v_stride_head + off_dim * v_stride_dim , mask = mask )
62- tl .store (k_page_ptr + tid * k_page_stride_size + layer_index * k_page_stride_layer_num + page_head_id * k_page_stride_head + off_dim * k_page_stride_dim , k_tensor , mask = mask )
63- tl .store (v_page_ptr + tid * v_page_stride_size + layer_index * v_page_stride_layer_num + page_head_id * v_page_stride_head + off_dim * v_page_stride_dim , v_tensor , mask = mask )
61+ k_tensor = tl .load (
62+ k_ptr
63+ + layer_index * k_stride_layer_num
64+ + mem_index * k_stride_size
65+ + kv_head_id * k_stride_head
66+ + off_dim * k_stride_dim ,
67+ mask = mask ,
68+ )
69+ v_tensor = tl .load (
70+ v_ptr
71+ + layer_index * v_stride_layer_num
72+ + mem_index * v_stride_size
73+ + kv_head_id * v_stride_head
74+ + off_dim * v_stride_dim ,
75+ mask = mask ,
76+ )
77+ tl .store (
78+ k_page_ptr
79+ + tid * k_page_stride_size
80+ + layer_index * k_page_stride_layer_num
81+ + page_head_id * k_page_stride_head
82+ + off_dim * k_page_stride_dim ,
83+ k_tensor ,
84+ mask = mask ,
85+ )
86+ tl .store (
87+ v_page_ptr
88+ + tid * v_page_stride_size
89+ + layer_index * v_page_stride_layer_num
90+ + page_head_id * v_page_stride_head
91+ + off_dim * v_page_stride_dim ,
92+ v_tensor ,
93+ mask = mask ,
94+ )
6495 else :
65- k_page_tensor = tl .load (k_page_ptr + tid * k_page_stride_size + layer_index * k_page_stride_layer_num + page_head_id * k_page_stride_head + off_dim * k_page_stride_dim , mask = mask )
66- v_page_tensor = tl .load (v_page_ptr + tid * v_page_stride_size + layer_index * v_page_stride_layer_num + page_head_id * v_page_stride_head + off_dim * v_page_stride_dim , mask = mask )
67- tl .store (k_ptr + layer_index * k_stride_layer_num + mem_index * k_stride_size + kv_head_id * k_stride_head + off_dim * k_stride_dim , k_page_tensor , mask = mask )
68- tl .store (v_ptr + layer_index * v_stride_layer_num + mem_index * v_stride_size + kv_head_id * v_stride_head + off_dim * v_stride_dim , v_page_tensor , mask = mask )
96+ k_page_tensor = tl .load (
97+ k_page_ptr
98+ + tid * k_page_stride_size
99+ + layer_index * k_page_stride_layer_num
100+ + page_head_id * k_page_stride_head
101+ + off_dim * k_page_stride_dim ,
102+ mask = mask ,
103+ )
104+ v_page_tensor = tl .load (
105+ v_page_ptr
106+ + tid * v_page_stride_size
107+ + layer_index * v_page_stride_layer_num
108+ + page_head_id * v_page_stride_head
109+ + off_dim * v_page_stride_dim ,
110+ mask = mask ,
111+ )
112+ tl .store (
113+ k_ptr
114+ + layer_index * k_stride_layer_num
115+ + mem_index * k_stride_size
116+ + kv_head_id * k_stride_head
117+ + off_dim * k_stride_dim ,
118+ k_page_tensor ,
119+ mask = mask ,
120+ )
121+ tl .store (
122+ v_ptr
123+ + layer_index * v_stride_layer_num
124+ + mem_index * v_stride_size
125+ + kv_head_id * v_stride_head
126+ + off_dim * v_stride_dim ,
127+ v_page_tensor ,
128+ mask = mask ,
129+ )
69130 return
70131
71- def page_io (mem_indexes :torch .Tensor , page_tensor : torch .Tensor , kv_buffer : torch .Tensor , tp_index :int , tp_world_size :int , mode :str ):
132+
133+ def page_io (
134+ mem_indexes : torch .Tensor ,
135+ page_tensor : torch .Tensor ,
136+ kv_buffer : torch .Tensor ,
137+ tp_index : int ,
138+ tp_world_size : int ,
139+ mode : str ,
140+ ):
72141 assert mode in ["read" , "write" ]
73142 assert mem_indexes .is_contiguous ()
74143 assert page_tensor .is_contiguous ()
@@ -86,9 +155,10 @@ def page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torc
86155 v_page_tensor = page_tensor [:, :, - page_v_head_num :, :]
87156
88157 k_head_num , v_head_num = kv_head_num // 2 , kv_head_num // 2
158+ assert k_head_num == v_head_num
89159 k_buffer = kv_buffer [:, :, 0 :k_head_num , :]
90160 v_buffer = kv_buffer [:, :, k_head_num :, :]
91-
161+
92162 tp_index = tp_index // repeat_count
93163 tp_world_size = tp_world_size // repeat_count
94164
@@ -127,14 +197,13 @@ def page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer: torc
127197 layer_num = layer_num ,
128198 head_dim = head_dim ,
129199 HEAD_DIM_BLOCK = triton .next_power_of_2 (head_dim ),
130- IS_WRITE = mode == "write" ,
200+ IS_WRITE = mode == "write" ,
131201 NEED_MASK = triton .next_power_of_2 (head_dim ) != head_dim ,
132202 num_warps = 1 ,
133203 )
134204 return
135205
136206
137-
138207@triton .jit
139208def _mla_page_io (
140209 mem_index_ptr ,
@@ -157,7 +226,7 @@ def _mla_page_io(
157226 page_stride_size = tl .cast (page_stride_size , dtype = tl .int64 )
158227 kv_stride_layer_num = tl .cast (kv_stride_layer_num , dtype = tl .int64 )
159228 kv_stride_size = tl .cast (kv_stride_size , dtype = tl .int64 )
160-
229+
161230 tid = tl .program_id (0 )
162231
163232 mem_index = tl .load (mem_index_ptr + tid )
@@ -169,14 +238,45 @@ def _mla_page_io(
169238
170239 for layer_index in tl .range (layer_num , num_stages = 3 ):
171240 if IS_WRITE :
172- kv_tensor = tl .load (kv_ptr + layer_index * kv_stride_layer_num + mem_index * kv_stride_size + 0 * kv_stride_head + off_dim * kv_stride_dim , mask = mask )
173- tl .store (page_ptr + tid * page_stride_size + layer_index * page_stride_layer_num + 0 * page_stride_head + off_dim * page_stride_dim , kv_tensor , mask = mask )
241+ kv_tensor = tl .load (
242+ kv_ptr
243+ + layer_index * kv_stride_layer_num
244+ + mem_index * kv_stride_size
245+ + 0 * kv_stride_head
246+ + off_dim * kv_stride_dim ,
247+ mask = mask ,
248+ )
249+ tl .store (
250+ page_ptr
251+ + tid * page_stride_size
252+ + layer_index * page_stride_layer_num
253+ + 0 * page_stride_head
254+ + off_dim * page_stride_dim ,
255+ kv_tensor ,
256+ mask = mask ,
257+ )
174258 else :
175- page_tensor = tl .load (page_ptr + tid * page_stride_size + layer_index * page_stride_layer_num + 0 * page_stride_head + off_dim * page_stride_dim , mask = mask )
176- tl .store (kv_ptr + layer_index * kv_stride_layer_num + mem_index * kv_stride_size + 0 * kv_stride_head + off_dim * kv_stride_dim , page_tensor , mask = mask )
259+ page_tensor = tl .load (
260+ page_ptr
261+ + tid * page_stride_size
262+ + layer_index * page_stride_layer_num
263+ + 0 * page_stride_head
264+ + off_dim * page_stride_dim ,
265+ mask = mask ,
266+ )
267+ tl .store (
268+ kv_ptr
269+ + layer_index * kv_stride_layer_num
270+ + mem_index * kv_stride_size
271+ + 0 * kv_stride_head
272+ + off_dim * kv_stride_dim ,
273+ page_tensor ,
274+ mask = mask ,
275+ )
177276 return
178277
179- def mla_page_io (mem_indexes :torch .Tensor , page_tensor : torch .Tensor , kv_buffer : torch .Tensor , mode :str ):
278+
279+ def mla_page_io (mem_indexes : torch .Tensor , page_tensor : torch .Tensor , kv_buffer : torch .Tensor , mode : str ):
180280 assert mode in ["read" , "write" ]
181281 assert mem_indexes .is_contiguous ()
182282 assert page_tensor .is_contiguous ()
@@ -189,7 +289,6 @@ def mla_page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
189289 assert page_head_dim == head_dim
190290 assert page_head_num == kv_head_num == 1
191291
192-
193292 token_num = len (mem_indexes )
194293 grid = (token_num ,)
195294
@@ -208,7 +307,7 @@ def mla_page_io(mem_indexes:torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
208307 layer_num = layer_num ,
209308 head_dim = head_dim ,
210309 HEAD_DIM_BLOCK = triton .next_power_of_2 (head_dim ),
211- IS_WRITE = mode == "write" ,
310+ IS_WRITE = mode == "write" ,
212311 NEED_MASK = triton .next_power_of_2 (head_dim ) != head_dim ,
213312 num_warps = 1 ,
214313 )
0 commit comments