Skip to content

Commit 38cc619

Browse files
committed
Refactor chunk_delta_h and chunk_o for improved type handling and performance
1 parent ada74c3 commit 38cc619

File tree

5 files changed

+207
-388
lines changed

5 files changed

+207
-388
lines changed

fla/ops/common/chunk_delta_h.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
8181
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
8282

8383
# calculate offset
84-
h += (boh * H + i_h) * K*V
85-
v += (bos * H + i_h) * V
86-
k += (bos * H + i_h) * K
87-
w += (bos * H + i_h) * K
84+
h += ((boh * H + i_h) * K*V).to(tl.int64)
85+
v += ((bos * H + i_h) * V).to(tl.int64)
86+
k += ((bos * H + i_h) * K).to(tl.int64)
87+
w += ((bos * H + i_h) * K).to(tl.int64)
8888
if SAVE_NEW_VALUE:
89-
v_new += (bos * H + i_h) * V
89+
v_new += ((bos * H + i_h) * V).to(tl.int64)
9090
stride_v = H*V
9191
stride_h = H*K*V
9292
stride_k = H*K
@@ -181,30 +181,18 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
181181

182182
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
183183
b_k = tl.load(p_k, boundary_check=(0, 1))
184-
if USE_GK:
185-
p_g = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1))
186-
b_k = (b_k * exp(b_gk_last1[:, None] - tl.load(p_g, boundary_check=(0, 1)))).to(b_k.dtype)
187184
b_h1 += tl.dot(b_k, b_v)
188185
if K > 64:
189186
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
190187
b_k = tl.load(p_k, boundary_check=(0, 1))
191-
if USE_GK:
192-
p_g = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1))
193-
b_k = (b_k * exp(b_gk_last2[:, None] - tl.load(p_g, boundary_check=(0, 1)))).to(b_k.dtype)
194188
b_h2 += tl.dot(b_k, b_v)
195189
if K > 128:
196190
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
197191
b_k = tl.load(p_k, boundary_check=(0, 1))
198-
if USE_GK:
199-
p_g = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1))
200-
b_k = (b_k * exp(b_gk_last3[:, None] - tl.load(p_g, boundary_check=(0, 1)))).to(b_k.dtype)
201192
b_h3 += tl.dot(b_k, b_v)
202193
if K > 192:
203194
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
204195
b_k = tl.load(p_k, boundary_check=(0, 1))
205-
if USE_GK:
206-
p_g = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1))
207-
b_k = (b_k * exp(b_gk_last4[:, None] - tl.load(p_g, boundary_check=(0, 1)))).to(b_k.dtype)
208196
b_h4 += tl.dot(b_k, b_v)
209197
# epilogue
210198
if STORE_FINAL_STATE:
@@ -223,6 +211,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
223211

224212
@triton.heuristics({
225213
'USE_G': lambda args: args['g'] is not None,
214+
'USE_GK': lambda args: args['gk'] is not None,
226215
'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
227216
'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
228217
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
@@ -244,6 +233,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
244233
k,
245234
w,
246235
g,
236+
gk,
247237
dht,
248238
dh0,
249239
do,
@@ -260,6 +250,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
260250
BT: tl.constexpr,
261251
BV: tl.constexpr,
262252
USE_G: tl.constexpr,
253+
USE_GK: tl.constexpr,
263254
USE_INITIAL_STATE: tl.constexpr,
264255
USE_FINAL_STATE_GRADIENT: tl.constexpr,
265256
IS_VARLEN: tl.constexpr
@@ -286,13 +277,16 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
286277
b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
287278

288279
# calculate offset
289-
dh += (boh * H + i_h) * K*V
290-
dv += (bos * H + i_h) * V
291-
dv2 += (bos * H + i_h) * V
292-
q += (bos * H + i_h) * K
293-
k += (bos * H + i_h) * K
294-
w += (bos * H + i_h) * K
295-
do += (bos * H + i_h) * V
280+
q += ((bos * H + i_h) * K).to(tl.int64)
281+
k += ((bos * H + i_h) * K).to(tl.int64)
282+
w += ((bos * H + i_h) * K).to(tl.int64)
283+
do += ((bos * H + i_h) * V).to(tl.int64)
284+
dv += ((bos * H + i_h) * V).to(tl.int64)
285+
dv2 += ((bos * H + i_h) * V).to(tl.int64)
286+
dh += ((boh * H + i_h) * K*V).to(tl.int64)
287+
if USE_GK:
288+
gk += ((bos * H + i_h) * K).to(tl.int64)
289+
296290
stride_v = H*V
297291
stride_h = H*K*V
298292
stride_k = H*K
@@ -327,44 +321,50 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
327321
p_dh4 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
328322
tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))
329323

324+
last_idx = min((i_t + 1) * BT, T) - 1
330325
if USE_G:
331-
last_idx = min((i_t + 1) * BT, T) - 1
332326
bg_last = tl.load(g + (bos + last_idx) * H + i_h)
333327
bg_last_exp = exp(bg_last)
334328
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
335329
b_g = tl.load(p_g, boundary_check=(0,))
336330
b_g_exp = exp(b_g)
337-
else:
338-
bg_last = None
339-
last_idx = None
340-
b_g = None
341-
b_g_exp = None
342331

343332
p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
344-
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
345333
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
334+
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
346335

347336
b_do = tl.load(p_do, boundary_check=(0, 1))
348-
b_dv = tl.zeros([BT, BV], dtype=tl.float32)
349337

350338
# Update dv
351339
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
352340
b_k = tl.load(p_k, boundary_check=(0, 1))
353-
b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype))
341+
if USE_GK:
342+
o_k1 = tl.arange(0, 64)
343+
b_gk_last1 = tl.load(gk + last_idx * H*K + o_k1, mask=(o_k1 < K), other=0.)
344+
b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype))
354345

355346
if K > 64:
356347
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
357348
b_k = tl.load(p_k, boundary_check=(0, 1))
349+
if USE_GK:
350+
o_k2 = 64 + o_k1
351+
b_gk_last2 = tl.load(gk + last_idx * H*K + o_k2, mask=(o_k2 < K), other=0.)
358352
b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
359353

360354
if K > 128:
361355
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
362356
b_k = tl.load(p_k, boundary_check=(0, 1))
357+
if USE_GK:
358+
o_k3 = 128 + o_k1
359+
b_gk_last3 = tl.load(gk + last_idx * H*K + o_k3, mask=(o_k3 < K), other=0.)
363360
b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
364361

365362
if K > 192:
366363
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
367364
b_k = tl.load(p_k, boundary_check=(0, 1))
365+
if USE_GK:
366+
o_k4 = 192 + o_k1
367+
b_gk_last4 = tl.load(gk + last_idx * H*K + o_k4, mask=(o_k4 < K), other=0.)
368368
b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
369369

370370
if USE_G:
@@ -381,8 +381,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
381381
if USE_G:
382382
b_dh1 *= bg_last_exp
383383
b_q = b_q * b_g_exp[None, :]
384-
b_q = (b_q * scale).to(b_q.dtype)
385-
b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
384+
if USE_GK:
385+
b_dh1 *= exp(b_gk_last1[:, None])
386+
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
386387
if K > 64:
387388
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
388389
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
@@ -391,8 +392,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
391392
if USE_G:
392393
b_dh2 *= bg_last_exp
393394
b_q = b_q * b_g_exp[None, :]
394-
b_q = (b_q * scale).to(b_q.dtype)
395-
b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
395+
if USE_GK:
396+
b_dh2 *= exp(b_gk_last2[:, None])
397+
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
396398
if K > 128:
397399
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
398400
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
@@ -401,8 +403,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
401403
if USE_G:
402404
b_dh3 *= bg_last_exp
403405
b_q = b_q * b_g_exp[None, :]
404-
b_q = (b_q * scale).to(b_q.dtype)
405-
b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
406+
if USE_GK:
407+
b_dh3 *= exp(b_gk_last3[:, None])
408+
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
406409
if K > 192:
407410
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
408411
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
@@ -411,8 +414,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
411414
if USE_G:
412415
b_dh4 *= bg_last_exp
413416
b_q = b_q * b_g_exp[None, :]
414-
b_q = (b_q * scale).to(b_q.dtype)
415-
b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
417+
if USE_GK:
418+
b_dh4 *= exp(b_gk_last4[:, None])
419+
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))
416420

417421
if USE_INITIAL_STATE:
418422
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
@@ -481,12 +485,13 @@ def chunk_gated_delta_rule_bwd_dhu(
481485
q: torch.Tensor,
482486
k: torch.Tensor,
483487
w: torch.Tensor,
484-
g: torch.Tensor,
485-
h0: torch.Tensor,
486-
dht: Optional[torch.Tensor],
487488
do: torch.Tensor,
488489
dv: torch.Tensor,
489-
scale: float,
490+
g: Optional[torch.Tensor] = None,
491+
gk: Optional[torch.Tensor] = None,
492+
h0: Optional[torch.Tensor] = None,
493+
dht: Optional[torch.Tensor] = None,
494+
scale: Optional[float] = None,
490495
cu_seqlens: Optional[torch.LongTensor] = None,
491496
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
492497
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -511,6 +516,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
511516
k=k,
512517
w=w,
513518
g=g,
519+
gk=gk,
514520
dht=dht,
515521
dh0=dh0,
516522
do=do,

0 commit comments

Comments
 (0)