@@ -213,17 +213,18 @@ def _attn_bwd_dkdv(dk, dv,  #
                    # Filled in by the wrapper.
                    start_n, start_m, num_steps,  #
                    MASK: tl.constexpr):
-    offs_m = start_m + tl.arange(0, BLOCK_M1)
     offs_n = start_n + tl.arange(0, BLOCK_N1)
-    offs_k = tl.arange(0, HEAD_DIM)
-    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
-    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
+    qT_desc = tl.make_tensor_descriptor(Q, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_M1])
+
+    do_desc = tl.make_tensor_descriptor(DO, shape=[N_CTX, HEAD_DIM], strides=[stride_tok, stride_d],
+                                        block_shape=[BLOCK_M1, HEAD_DIM])
     # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
     curr_m = start_m
     step_m = BLOCK_M1
     for blk_idx in range(num_steps):
-        qT = tl.load(qT_ptrs)
+        qT = qT_desc.load([0, start_m + blk_idx * step_m])
         # Load m before computing qk to reduce pipeline stall.
         offs_m = curr_m + tl.arange(0, BLOCK_M1)
         m = tl.load(M + offs_m)
@@ -233,7 +234,7 @@ def _attn_bwd_dkdv(dk, dv,  #
         if MASK:
             mask = (offs_m[None, :] >= offs_n[:, None])
             pT = tl.where(mask, pT, 0.0)
-        do = tl.load(do_ptrs)
+        do = do_desc.load([start_m + blk_idx * step_m, 0])
         # Compute dV.
         ppT = pT
         ppT = ppT.to(tl.float16)
@@ -247,8 +248,6 @@ def _attn_bwd_dkdv(dk, dv,  #
         dk += tl.dot(dsT, tl.trans(qT))
         # Increment pointers.
         curr_m += step_m
-        qT_ptrs += step_m * stride_tok
-        do_ptrs += step_m * stride_tok
     return dk, dv
 
 
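The change above follows one pattern throughout: a tensor descriptor is built once before the loop (logical shape, strides, and the block shape each load returns), and every load then names the block's absolute element offsets, so the per-iteration pointer increments disappear. Note how the commit reads Q already transposed by swapping shape and strides in the descriptor (shape `[HEAD_DIM, N_CTX]`, strides `[stride_d, stride_tok]`) rather than transposing after the load. Below is a minimal self-contained sketch of that pattern; the kernel, its names, and the contiguous row-major strides `[HEAD_DIM, 1]` are illustrative assumptions, not part of this commit, and running it also needs the host-side allocator registration sketched after the `_attn_bwd_dq` hunks.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows(X, Y, N_CTX, HEAD_DIM: tl.constexpr, BLOCK: tl.constexpr):
    # Build each descriptor once: logical shape, strides, and the block shape
    # returned by every load. Strides [HEAD_DIM, 1] assume a contiguous
    # row-major (N_CTX, HEAD_DIM) tensor.
    x_desc = tl.make_tensor_descriptor(X, shape=[N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1],
                                       block_shape=[BLOCK, HEAD_DIM])
    y_desc = tl.make_tensor_descriptor(Y, shape=[N_CTX, HEAD_DIM], strides=[HEAD_DIM, 1],
                                       block_shape=[BLOCK, HEAD_DIM])
    for blk_idx in range(0, tl.cdiv(N_CTX, BLOCK)):
        # Offsets are absolute element coordinates into the logical shape,
        # so no pointer arithmetic is carried across iterations.
        blk = x_desc.load([blk_idx * BLOCK, 0])
        y_desc.store([blk_idx * BLOCK, 0], blk)


# Hypothetical usage (requires triton.set_allocator to be registered first):
x = torch.randn(1024, 64, device="cuda", dtype=torch.float16)
y = torch.empty_like(x)
_copy_rows[(1, )](x, y, x.shape[0], HEAD_DIM=64, BLOCK=128)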
@@ -267,19 +266,20 @@ def _attn_bwd_dq(dq, q, K, V,  #
                  start_m, start_n, num_steps,  #
                  MASK: tl.constexpr):
     offs_m = start_m + tl.arange(0, BLOCK_M2)
-    offs_n = start_n + tl.arange(0, BLOCK_N2)
-    offs_k = tl.arange(0, HEAD_DIM)
-    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
-    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
+    kT_desc = tl.make_tensor_descriptor(K, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_N2])
+
+    vT_desc = tl.make_tensor_descriptor(V, shape=[HEAD_DIM, N_CTX], strides=[stride_d, stride_tok],
+                                        block_shape=[HEAD_DIM, BLOCK_N2])
     # D (= delta) is pre-divided by ds_scale.
     Di = tl.load(D + offs_m)
     # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
     curr_n = start_n
     step_n = BLOCK_N2
     for blk_idx in range(num_steps):
-        kT = tl.load(kT_ptrs)
-        vT = tl.load(vT_ptrs)
+        kT = kT_desc.load([0, start_n + blk_idx * step_n])
+        vT = vT_desc.load([0, start_n + blk_idx * step_n])
         qk = tl.dot(q, kT)
         p = tl.math.exp2(qk - m)
         # Autoregressive masking.
@@ -296,8 +296,6 @@ def _attn_bwd_dq(dq, q, K, V,  #
         dq += tl.dot(ds, tl.trans(kT))
         # Increment pointers.
         curr_n += step_n
-        kT_ptrs += step_n * stride_tok
-        vT_ptrs += step_n * stride_tok
     return dq
 
 
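One operational detail this diff relies on: when `tl.make_tensor_descriptor` is called inside a kernel, Triton needs scratch memory allocated from the host, so the caller must register an allocator before launching. A minimal sketch, assuming a recent Triton release that exposes `triton.set_allocator` and a CUDA device; the function name `alloc_fn` is illustrative.

import torch
import triton


def alloc_fn(size: int, alignment: int, stream):
    # Scratch buffer for device-side tensor descriptors; it must live on the
    # device the kernel launches on.
    return torch.empty(size, device="cuda", dtype=torch.int8)


# Register once, before launching any kernel that builds descriptors on device.
triton.set_allocator(alloc_fn)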