
Commit 71ed189

Nathancgy and yzhangcs authored
[DeltaFormer] Fixed testing ops error (#602)
* [Stick-Breaking Attention] Add Model
* Revert "[Stick-Breaking Attention] Add Model". This reverts commit db7411a.
* [deltaformer] fixed ops test error
* Test under fp16
* added varlen ops test & passed
* Update naive.py

Co-authored-by: Yu Zhang <[email protected]>
1 parent 2e73362 commit 71ed189

File tree: 3 files changed (+175, -77 lines)


fla/ops/deltaformer/naive.py

Lines changed: 65 additions & 6 deletions
@@ -36,16 +36,34 @@ def tril_softmax(scores: torch.Tensor, strict: bool = True) -> torch.Tensor:
     return probs
 
 
-def naive_deltaformer_attn(
+def naive_causal_attention_bhtd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+) -> torch.Tensor:
+    B, H, T, D = q.shape
+    qk_scale = 1.0 / math.sqrt(D)
+    scores = torch.matmul(q, k.transpose(-1, -2)) * qk_scale  # [B, H, T, T]
+    causal_mask = torch.triu(torch.ones(T, T, device=q.device), diagonal=1).bool()
+    scores = scores.masked_fill(causal_mask, float('-inf'))
+    attn_weights = torch.softmax(scores, dim=-1)  # [B, H, T, T]
+    o = torch.matmul(attn_weights, v)  # [B, H, T, D]
+
+    return o
+
+
+def naive_deltaformer_attn_head_first(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     beta: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
-    Naive reference implementation of DeltaFormer pre-attention.
+    Naive reference implementation of DeltaFormer attention for head-first format.
 
-    Computes u[i] = v[i] - beta[i] * sum_{j<i} softmax(q[i] @ k[:i]^T) @ u[:i]
+    Two-stage process:
+    1. Computes u[i] = v[i] - beta[i] * sum_{j<i} softmax(q[i] @ k[:i]^T) @ u[:i]
+    2. Applies causal attention: o = causal_attn(q, k, u)
 
     Args:
         q: [B, H, T, D]
@@ -54,7 +72,7 @@ def naive_deltaformer_attn(
         beta: [B, H, T] or None (defaults to ones)
 
     Returns:
-        u: [B, H, T, D]
+        o: [B, H, T, D]
     """
     assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q,k,v must be [B,H,T,D]"
     B, H, T, D = q.shape
@@ -83,8 +101,49 @@ def naive_deltaformer_attn(
         weighted_sum = (w.unsqueeze(-1) * u_prev).sum(dim=-2)  # [B,H,D]
         u_t = vf[:, :, t, :] - betaf[:, :, t].unsqueeze(-1) * weighted_sum
         u_list.append(u_t)
-    u = torch.stack(u_list, dim=2)
-    return u.to(orig_dtype)
+    u = torch.stack(u_list, dim=2)  # [B,H,T,D]
+
+    o = naive_causal_attention_bhtd(q, k, u.to(orig_dtype))
+    return o.to(orig_dtype)
+
+
+def naive_deltaformer_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Naive reference implementation of DeltaFormer attention for sequence-first format.
+
+    Args:
+        q: [B, T, H, D]
+        k: [B, T, H, D]
+        v: [B, T, H, D]
+        beta: [B, T, H] or None (defaults to ones)
+
+    Returns:
+        o: [B, T, H, D]
+    """
+    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "q,k,v must be [B,T,H,D]"
+    B, T, H, D = q.shape
+    assert k.shape == (B, T, H, D) and v.shape == (B, T, H, D)
+
+    q_bhtd = q.transpose(1, 2)  # [B, T, H, D] -> [B, H, T, D]
+    k_bhtd = k.transpose(1, 2)  # [B, T, H, D] -> [B, H, T, D]
+    v_bhtd = v.transpose(1, 2)  # [B, T, H, D] -> [B, H, T, D]
+
+    if beta is not None:
+        assert beta.shape == (B, T, H)
+        beta_bhtd = beta.transpose(1, 2)  # [B, T, H] -> [B, H, T]
+    else:
+        beta_bhtd = None
+
+    o_bhtd = naive_deltaformer_attn_head_first(q_bhtd, k_bhtd, v_bhtd, beta_bhtd)
+
+    o_bthd = o_bhtd.transpose(1, 2)  # [B, H, T, D] -> [B, T, H, D]
+
+    return o_bthd
 
 
 __all__ = [
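Since the new sequence-first wrapper only transposes into the head-first layout, calls the head-first path, and transposes back, the two entry points should agree exactly. A quick consistency check, sketched for illustration only (the shapes and the direct module import below are assumptions, not taken from the repo's test suite):

```python
import torch

# Both functions live in fla/ops/deltaformer/naive.py per the diff above.
from fla.ops.deltaformer.naive import (
    naive_deltaformer_attn,
    naive_deltaformer_attn_head_first,
)

B, T, H, D = 2, 64, 4, 32
q, k, v = (torch.randn(B, T, H, D) for _ in range(3))
beta = torch.rand(B, T, H)

# Sequence-first entry point: [B, T, H, D] in, [B, T, H, D] out.
o_seq = naive_deltaformer_attn(q, k, v, beta)

# Head-first path on manually transposed inputs ([B, H, T, D]).
o_head = naive_deltaformer_attn_head_first(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), beta.transpose(1, 2)
)

torch.testing.assert_close(o_seq, o_head.transpose(1, 2))
```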

fla/ops/deltaformer/parallel.py

Lines changed: 29 additions & 27 deletions
@@ -300,7 +300,7 @@ def parallel_deltaformer_bwd_kernel_u(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    q = tl.load(q_blk_ptr)
+    q = tl.load(q_blk_ptr, boundary_check=(0,))
 
     for kv_i in range(0, T, BLOCK_T):
         k_blk_ptr = tl.make_block_ptr(
@@ -311,7 +311,7 @@ def parallel_deltaformer_bwd_kernel_u(
             block_shape=(D, BLOCK_T),
             order=(0, 1),
         )
-        k = tl.load(k_blk_ptr)
+        k = tl.load(k_blk_ptr, boundary_check=(1,))
         qk = tl.dot(q, k) * fa_scale
 
         lse_blk_ptr = tl.make_block_ptr(
@@ -322,7 +322,7 @@ def parallel_deltaformer_bwd_kernel_u(
             block_shape=(BLOCK_T,),
             order=(0,),
         )
-        lse = tl.load(lse_blk_ptr)
+        lse = tl.load(lse_blk_ptr, boundary_check=(0,))
         beta_blk_ptr = tl.make_block_ptr(
             base=beta_ptr + pid_h,
             shape=(T,),
@@ -331,7 +331,7 @@ def parallel_deltaformer_bwd_kernel_u(
             block_shape=(BLOCK_T,),
             order=(0,),
         )
-        beta = tl.load(beta_blk_ptr)
+        beta = tl.load(beta_blk_ptr, boundary_check=(0,))
 
         p = tl.math.exp2(qk - lse[None, :]) * beta[None, :]
 
@@ -343,7 +343,7 @@ def parallel_deltaformer_bwd_kernel_u(
             block_shape=(BLOCK_T, D),
             order=(1, 0),
         )
-        v = tl.load(v_blk_ptr)
+        v = tl.load(v_blk_ptr, boundary_check=(0,))
         acc = tl.dot(p.to(v_ptr.dtype.element_ty), v, acc)
 
     o_blk_ptr = tl.make_block_ptr(
@@ -354,7 +354,7 @@ def parallel_deltaformer_bwd_kernel_u(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    tl.store(o_blk_ptr, acc.to(o_ptr.dtype.element_ty))
+    tl.store(o_blk_ptr, acc.to(o_ptr.dtype.element_ty), boundary_check=(0,))
 
 
 @triton.autotune(configs=_config_deltaformer(), key=['T', 'D'])
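Every kernel change in this file is the same fix: block-pointer loads and stores gain a `boundary_check` along the sequence axis, so the last, possibly ragged block no longer reads or writes past the end of the tensor when `T` is not a multiple of the block size. Out-of-range lanes are padded on load and skipped on store; the `(0,)` vs `(1,)` choice simply names which axis of the block is the sequence axis. A minimal standalone sketch of the pattern, not taken from this repo:

```python
# Illustrative 1-D copy kernel (not from fla) showing the boundary_check pattern
# applied throughout the DeltaFormer backward kernels. Requires a CUDA device.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, T, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    x_blk = tl.make_block_ptr(base=x_ptr, shape=(T,), strides=(1,),
                              offsets=(pid * BLOCK,), block_shape=(BLOCK,), order=(0,))
    y_blk = tl.make_block_ptr(base=y_ptr, shape=(T,), strides=(1,),
                              offsets=(pid * BLOCK,), block_shape=(BLOCK,), order=(0,))
    # Without boundary_check, the last block reads/writes past T when T % BLOCK != 0.
    # With it, out-of-range lanes are zero-padded on load and masked out on store.
    x = tl.load(x_blk, boundary_check=(0,))
    tl.store(y_blk, x, boundary_check=(0,))


T, BLOCK = 1000, 256            # 1000 is deliberately not a multiple of 256
x = torch.randn(T, device='cuda')
y = torch.empty_like(x)
copy_kernel[(triton.cdiv(T, BLOCK),)](x, y, T, BLOCK=BLOCK)
assert torch.equal(x, y)
```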
@@ -389,7 +389,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    k_row = tl.load(k_row_blk_ptr)
+    k_row = tl.load(k_row_blk_ptr, boundary_check=(0,))
     lse_blk_ptr = tl.make_block_ptr(
         base=lse_ptr + pid_h,
         shape=(T,),
@@ -398,7 +398,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
         block_shape=(BLOCK_C,),
         order=(0,),
     )
-    lse = tl.load(lse_blk_ptr)
+    lse = tl.load(lse_blk_ptr, boundary_check=(0,))
     grad_v_blk_ptr = tl.make_block_ptr(
         base=grad_v_ptr + pid_h * D,
         shape=(T, D),
@@ -407,7 +407,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    grad_v_row = -tl.load(grad_v_blk_ptr)
+    grad_v_row = -tl.load(grad_v_blk_ptr, boundary_check=(0,))
 
     for kv_i in range(0, (pid_c + 1) * BLOCK_C, BLOCK_T):
         k_blk_ptr = tl.make_block_ptr(
@@ -418,7 +418,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
             block_shape=(D, BLOCK_T),
             order=(0, 1),
         )
-        k = tl.load(k_blk_ptr)
+        k = tl.load(k_blk_ptr, boundary_check=(1,))
         qk = tl.dot(k_row, k) * fa_scale
         p = tl.math.exp2(qk - lse[:, None])
 
@@ -430,7 +430,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
             block_shape=(D, BLOCK_T),
             order=(0, 1),
         )
-        ut = tl.load(u_blk_ptr)
+        ut = tl.load(u_blk_ptr, boundary_check=(1,))
         dp = tl.dot(grad_v_row, ut)
         if kv_i + BLOCK_T >= pid_c * BLOCK_C:
             mask = (rowid_block[:, None] <= colid_block[None, :] + kv_i)
@@ -445,7 +445,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
         block_shape=(BLOCK_C,),
         order=(0,),
     )
-    tl.store(row_dot_block_ptr, acc)
+    tl.store(row_dot_block_ptr, acc, boundary_check=(0,))
 
 
 @triton.autotune(configs=[triton.Config({'BLOCK_C': BC}, num_stages=ns, num_warps=nw)
@@ -484,7 +484,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    k_row = tl.load(k_row_blk_ptr)
+    k_row = tl.load(k_row_blk_ptr, boundary_check=(0,))
     lse_blk_ptr = tl.make_block_ptr(
         base=lse_ptr + pid_h,
         shape=(T,),
@@ -493,7 +493,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         block_shape=(BLOCK_C,),
         order=(0,),
     )
-    lse = tl.load(lse_blk_ptr)
+    lse = tl.load(lse_blk_ptr, boundary_check=(0,))
     beta_blk_ptr = tl.make_block_ptr(
         base=beta_ptr + pid_h,
         shape=(T,),
@@ -502,7 +502,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         block_shape=(BLOCK_C,),
         order=(0,),
     )
-    beta = tl.load(beta_blk_ptr)
+    beta = tl.load(beta_blk_ptr, boundary_check=(0,))
     grad_v_blk_ptr = tl.make_block_ptr(
         base=grad_v_ptr + pid_h * D,
         shape=(T, D),
@@ -511,7 +511,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    grad_v_row = -tl.load(grad_v_blk_ptr)
+    grad_v_row = -tl.load(grad_v_blk_ptr, boundary_check=(0,))
     row_dot_blk_ptr = tl.make_block_ptr(
         base=row_dot_ptr + pid_h,
         shape=(T,),
@@ -520,7 +520,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         block_shape=(BLOCK_C,),
         order=(0,),
     )
-    row_dot_row = tl.load(row_dot_blk_ptr).to(k_ptr.dtype.element_ty)
+    row_dot_row = tl.load(row_dot_blk_ptr, boundary_check=(0,)).to(k_ptr.dtype.element_ty)
 
     for kv_i in range(0, pid_c * BLOCK_C, BLOCK_C):
         k_blk_ptr = tl.make_block_ptr(
@@ -531,7 +531,7 @@ def parallel_deltaformer_bwd_kernel_qk(
             block_shape=(D, BLOCK_C),
             order=(0, 1),
         )
-        kt = tl.load(k_blk_ptr)
+        kt = tl.load(k_blk_ptr, boundary_check=(1,))
         qk = tl.dot(k_row, kt) * fa_scale
         p = tl.math.exp2(qk - lse[:, None]) * beta[:, None]
 
@@ -557,7 +557,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         block_shape=(BLOCK_C, D),
         order=(1, 0),
     )
-    k_row_true = tl.load(k_row_blk_ptr)
+    k_row_true = tl.load(k_row_blk_ptr, boundary_check=(0,))
     qk = tl.dot(k_row, tl.trans(k_row_true, 1, 0)) * fa_scale
     p = tl.math.exp2(qk - lse[:, None]) * beta[:, None]
     u_blk_ptr = tl.make_block_ptr(
@@ -587,7 +587,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         order=(1, 0),
     )
     acc = acc * qk_scale
-    tl.store(grad_q_blk_ptr, acc.to(grad_q_ptr.dtype.element_ty))
+    tl.store(grad_q_blk_ptr, acc.to(grad_q_ptr.dtype.element_ty), boundary_check=(0,))
 
     daat = tl.trans(da, 1, 0)
     acc = tl.dot(daat.to(k_row.dtype), k_row)
@@ -602,7 +602,7 @@ def parallel_deltaformer_bwd_kernel_qk(
             block_shape=(D, BLOCK_C),
             order=(0, 1),
         )
-        kt = tl.load(k_blk_ptr)
+        kt = tl.load(k_blk_ptr, boundary_check=(1,))
         lse_blk_ptr = tl.make_block_ptr(
             base=lse_ptr + pid_h,
             shape=(T,),
@@ -611,7 +611,7 @@ def parallel_deltaformer_bwd_kernel_qk(
             block_shape=(BLOCK_C,),
             order=(0,),
         )
-        lse = tl.load(lse_blk_ptr)
+        lse = tl.load(lse_blk_ptr, boundary_check=(0,))
         beta_blk_ptr = tl.make_block_ptr(
             base=beta_ptr + pid_h,
             shape=(T,),
@@ -620,7 +620,7 @@ def parallel_deltaformer_bwd_kernel_qk(
             block_shape=(BLOCK_C,),
             order=(0,),
         )
-        beta = tl.load(beta_blk_ptr)
+        beta = tl.load(beta_blk_ptr, boundary_check=(0,))
         qk = tl.dot(k_row, kt) * fa_scale
         p = tl.math.exp2(qk - lse[None, :]) * beta[None, :]
 
@@ -632,7 +632,7 @@ def parallel_deltaformer_bwd_kernel_qk(
             block_shape=(D, BLOCK_C),
             order=(0, 1),
         )
-        grad_vt = tl.load(grad_vt_blk_ptr)
+        grad_vt = tl.load(grad_vt_blk_ptr, boundary_check=(1,))
         row_dot_blk_ptr = tl.make_block_ptr(
             base=row_dot_ptr + pid_h,
             shape=(T,),
@@ -641,7 +641,7 @@ def parallel_deltaformer_bwd_kernel_qk(
             block_shape=(BLOCK_C,),
             order=(0,),
        )
-        row_dot = tl.load(row_dot_blk_ptr).to(k_ptr.dtype.element_ty)
+        row_dot = tl.load(row_dot_blk_ptr, boundary_check=(0,)).to(k_ptr.dtype.element_ty)
         dp = tl.dot(nu, grad_vt)
         da = p * (dp - row_dot[None, :])
         k = tl.trans(kt, 1, 0)
@@ -656,7 +656,7 @@ def parallel_deltaformer_bwd_kernel_qk(
         order=(1, 0),
     )
     acc = acc * qk_scale
-    tl.store(grad_k_blk_ptr, acc.to(grad_k_ptr.dtype.element_ty))
+    tl.store(grad_k_blk_ptr, acc.to(grad_k_ptr.dtype.element_ty), boundary_check=(0,))
 
 
 class ParallelDeltaformerFunction(torch.autograd.Function):
@@ -872,6 +872,7 @@ def _forward_impl(
         w_t = w.transpose(0, 1).contiguous()
         u_chunk_view_t = u_chunk_view.transpose(0, 1).contiguous()
         invcum.forward_inplace(u_chunk_view_t, w_t)
+        u_chunk_view.copy_(u_chunk_view_t.transpose(0, 1))
 
     chunk_base += (T_max + C - 1) // C
 
@@ -932,6 +933,7 @@ def _forward_impl(
         w_t = w.transpose(0, 1).contiguous()
         u_chunk_view_t = u_chunk_view.transpose(0, 1).contiguous()
         invcum.forward_inplace(u_chunk_view_t, w_t)
+        u_chunk_view.copy_(u_chunk_view_t.transpose(0, 1))
 
     chunk_base += (L + C - 1) // C
 
@@ -953,7 +955,7 @@ def deltaformer_attn(
     B, T, H, D = k.shape
     C = min(C, T)
 
-    u = ParallelDeltaformerFunction.apply(k, k, v, beta, C, cu_seqlens)
+    u = ParallelDeltaformerFunction.apply(q, k, v, beta, C, cu_seqlens)
 
     if attention_mask is not None:
         q_padded, (k_padded, u_padded), indices_q, cu_seqlens_lens, max_seq_lens = unpad_input(q, (k, u), attention_mask, T)
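The final hunk is the functional fix: `ParallelDeltaformerFunction.apply` now receives `q` as its first argument instead of `k`, matching the naive reference's `softmax(q[i] @ k[:i]^T)` pre-attention. A parity check in the spirit of the commit message ("Test under fp16"), sketched under the assumptions that `deltaformer_attn` is exported from `fla.ops.deltaformer`, that it accepts `(q, k, v, beta)` in `[B, T, H, D]` layout, and that the tolerances below are loose enough for fp16; the repo's actual tests may differ:

```python
import torch

from fla.ops.deltaformer import deltaformer_attn             # assumed export path
from fla.ops.deltaformer.naive import naive_deltaformer_attn

B, T, H, D = 2, 128, 4, 64
device, dtype = 'cuda', torch.float16

q, k, v = (torch.randn(B, T, H, D, device=device, dtype=dtype) for _ in range(3))
beta = torch.rand(B, T, H, device=device, dtype=dtype)

# Reference in fp32, fused path in fp16.
ref = naive_deltaformer_attn(q.float(), k.float(), v.float(), beta.float())
out = deltaformer_attn(q, k, v, beta)

# Loose tolerances for fp16; the values used by the repo's tests may differ.
torch.testing.assert_close(out.float(), ref, atol=5e-3, rtol=5e-3)
```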
