Commit a191c58

Determine the chunk size at the kernel entry (#619)
* Determine the chunk size at the kernel entry
* Fix split_size
1 parent 0d3e202 commit a191c58
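Across the files below the change is uniform: every inner wrapper previously re-derived its block size as `BT = min(chunk_size, max(16, triton.next_power_of_2(T)))`; now the wrappers take `chunk_size` at face value, so the effective chunk size is decided once, at the kernel entry point. A minimal sketch of the clamp being consolidated, assuming a hypothetical `resolve_chunk_size` helper (the repo inlines the expression rather than naming it):

```python
import triton


def resolve_chunk_size(chunk_size: int, T: int) -> int:
    # Hypothetical helper illustrating the clamp that used to live in
    # every wrapper: at most the requested chunk size, at least 16,
    # and bounded by the sequence length rounded up to a power of two.
    return min(chunk_size, max(16, triton.next_power_of_2(T)))


# With the default chunk_size=64:
print(resolve_chunk_size(64, 1000))  # 64  (next_power_of_2(1000) = 1024)
print(resolve_chunk_size(64, 10))    # 16  (short sequences clamp BT down)
```

Resolving the value once means every helper in a forward/backward pair sees the same chunk size, and it also lets `fla/ops/generalized_delta_rule/dplr/chunk.py` drop its `triton` import entirely, as that file's first hunk shows.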

19 files changed: +72 −83 lines changed

fla/ops/common/chunk_h.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -283,8 +283,8 @@ def chunk_fwd_h(
     states_in_fp32: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
-    BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
+    BS = BT if split_size is None else split_size
     assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
     # N: the actual number of sequences in the batch with either equal or variable lengths
     if cu_seqlens is None:
@@ -341,8 +341,8 @@ def chunk_bwd_dh(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
     HQ = q.shape[2]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
-    BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
+    BS = BT if split_size is None else split_size
     assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
     # N: the actual number of sequences in the batch with either equal or variable lengths
     # NG: number of groups in GQA
```
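Both hunks keep the assertion guarding the relationship between the two sizes: with `BT` and `BS` now taken verbatim, a split must cover a whole number of chunks. A standalone illustration of what the check accepts:

```python
# A split must cover a whole number of chunks.
chunk_size, split_size = 64, 256
assert split_size % chunk_size == 0  # passes: one split spans 4 chunks

# split_size = 96 would trip the assert with the message shown above:
# "The `split_size` (got 96) must be a multiple of `chunk_size` 64"
```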

fla/ops/common/chunk_h_parallel.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -418,7 +418,7 @@ def chunk_fwd_h(
     chunk_size: int = 64
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size

     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     # N: the actual number of sequences in the batch with either equal or variable lengths
@@ -491,7 +491,7 @@ def chunk_bwd_dh(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
     HQ = q.shape[2]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size

     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     # N: the actual number of sequences in the batch with either equal or variable lengths
```

fla/ops/common/chunk_o.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -494,7 +494,7 @@ def chunk_fwd_o(
     chunk_size: int = 64
 ) -> torch.Tensor:
     B, T, H, K, V = *q.shape, v.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     if scale is None:
@@ -534,7 +534,7 @@ def chunk_bwd_dv(
     chunk_size: int = 64
 ) -> torch.Tensor:
     B, T, H, K, V = *k.shape, do.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     # H100 can have larger block size
     if check_shared_mem('hopper', k.device.index):
@@ -585,7 +585,7 @@ def chunk_bwd_dv_local(
     chunk_size: int = 64
 ) -> torch.Tensor:
     B, T, H, K, V = *k.shape, do.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     # H100 can have larger block size
     if check_shared_mem('hopper', k.device.index):
@@ -638,7 +638,7 @@ def chunk_bwd_dqkwg(
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

     B, T, H, K, V = *k.shape, v.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

```
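The `check_shared_mem('hopper', k.device.index)` branches above pick larger tiles when the device has Hopper-class shared memory; `check_shared_mem` is the library's own helper and is not reproduced here. A rough sketch of that kind of capability gate, assuming a plain compute-capability check:

```python
import torch


def has_hopper_smem(device_index: int = 0) -> bool:
    # Illustrative stand-in, not fla's check_shared_mem: Hopper is
    # compute capability 9.x, with much more shared memory per SM.
    major, _ = torch.cuda.get_device_capability(device_index)
    return major >= 9


# e.g. allow a larger value-dimension tile on H100-class GPUs:
BV = 64 if torch.cuda.is_available() and has_hopper_smem() else 32
```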

fla/ops/gated_delta_product/chunk_deltaproduct_o.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -130,7 +130,7 @@ def chunk_gated_delta_product_fwd_o(
 ) -> torch.Tensor:
     assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]"
     B, T, H, K, V = *q.shape, v.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
     o = v.new_empty(B, T, H, V).fill_(-float('inf'))
```

fla/ops/generalized_delta_rule/dplr/chunk.py

Lines changed: 16 additions & 19 deletions
```diff
@@ -5,7 +5,6 @@
 from typing import Optional

 import torch
-import triton

 from fla.ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
 from fla.ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_dplr_fwd_intra
@@ -32,9 +31,7 @@ def chunk_dplr_fwd(
     cu_seqlens: Optional[torch.LongTensor] = None,
     chunk_size: int = 64
 ):
-    T = q.shape[1]
-    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
-    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens)
+    gi, ge = chunk_rwkv6_fwd_cumsum(gk, chunk_size, cu_seqlens=cu_seqlens)

     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
         q=q,
@@ -45,7 +42,7 @@ def chunk_dplr_fwd(
         ge=ge,
         scale=scale,
         cu_seqlens=cu_seqlens,
-        chunk_size=BT,
+        chunk_size=chunk_size,
     )
     del ge

@@ -57,7 +54,7 @@ def chunk_dplr_fwd(
         A_ak=A_ak,
         v=v,
         cu_seqlens=cu_seqlens,
-        chunk_size=BT
+        chunk_size=chunk_size
     )
     del A_ab, A_ak
     h, v_new, final_state = chunk_dplr_fwd_h(
@@ -70,7 +67,7 @@ def chunk_dplr_fwd(
         initial_state=initial_state,
         output_final_state=output_final_state,
         cu_seqlens=cu_seqlens,
-        chunk_size=BT
+        chunk_size=chunk_size
     )
     del u, kg, bg, gi

@@ -82,7 +79,7 @@ def chunk_dplr_fwd(
         A_qb=A_qb,
         h=h,
         cu_seqlens=cu_seqlens,
-        chunk_size=BT
+        chunk_size=chunk_size
     )
     del v_new, h, A_qk, A_qb

@@ -136,12 +133,12 @@ def backward(
         dht: torch.Tensor
     ):
         q, k, v, a, b, gk, initial_state = ctx.saved_tensors
-        BT = ctx.chunk_size
+        chunk_size = ctx.chunk_size
         cu_seqlens = ctx.cu_seqlens
         scale = ctx.scale

         # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted *******
-        gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens)
+        gi, ge = chunk_rwkv6_fwd_cumsum(gk, chunk_size, cu_seqlens=cu_seqlens)

         A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
             q=q,
@@ -152,15 +149,15 @@ def backward(
             ge=ge,
             scale=scale,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT,
+            chunk_size=chunk_size,
         )
         w, u, A_ab_inv = prepare_wy_repr_fwd(
             ag=ag,
             A_ab=A_ab,
             A_ak=A_ak,
             v=v,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT
+            chunk_size=chunk_size
         )
         del A_ab
         h, v_new, _ = chunk_dplr_fwd_h(
@@ -172,7 +169,7 @@ def backward(
             gk=gi,
             initial_state=initial_state,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT
+            chunk_size=chunk_size
         )
         del u
         # ******* end of recomputation *******
@@ -186,7 +183,7 @@ def backward(
             A_qb=A_qb,
             scale=scale,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT
+            chunk_size=chunk_size
         )

         dh, dh0, dv_new = chunk_dplr_bwd_dhu(
@@ -199,7 +196,7 @@ def backward(
             do=do,
             dv=dv_new_intra,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT
+            chunk_size=chunk_size
         )

         dv = chunk_dplr_bwd_dv(
@@ -208,7 +205,7 @@ def backward(
             do=do,
             dh=dh,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT
+            chunk_size=chunk_size
         )
         del A_qk

@@ -224,7 +221,7 @@ def backward(
             w=w,
             gk=gi,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT,
+            chunk_size=chunk_size,
             scale=scale,
         )
         del v_new
@@ -238,7 +235,7 @@ def backward(
             du=dv_new,
             dv0=dv,
             cu_seqlens=cu_seqlens,
-            chunk_size=BT
+            chunk_size=chunk_size
         )
         del A_ak

@@ -258,7 +255,7 @@ def backward(
             dkg=dkg,
             dag=dag,
             dbg=dbg,
-            chunk_size=BT,
+            chunk_size=chunk_size,
             scale=scale,
             cu_seqlens=cu_seqlens,
         )
```
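Note the section bracketed by the `start/end of recomputation` comments: `backward` rebuilds the forward intermediates from the tensors saved in `ctx` instead of keeping them alive across the autograd graph, trading extra FLOPs for memory, the same idea as activation checkpointing. A toy sketch of the pattern with a hypothetical op (not the dplr kernels themselves):

```python
import torch


class RecomputeSquare(torch.autograd.Function):
    # Toy recompute-in-backward pattern: save only the cheap inputs,
    # re-derive anything large inside backward().

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Recompute what the gradient needs from the saved input.
        return grad_out * 2 * x


x = torch.randn(4, requires_grad=True)
RecomputeSquare.apply(x).sum().backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True
```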

fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -303,7 +303,7 @@ def chunk_dplr_bwd_dqk_intra(
     chunk_size: int = 64,
 ):
     B, T, H, K = q.shape
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K))

     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
```

fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -153,7 +153,7 @@ def chunk_dplr_fwd_intra(
     cu_seqlens: Optional[torch.LongTensor] = None,
 ):
     B, T, H, K = k.shape
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size

     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
```

fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -121,7 +121,7 @@ def chunk_dplr_bwd_dhu(
     chunk_size: int = 64
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *qg.shape, do.shape[-1]
-    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+    BT = chunk_size
     BK = max(triton.next_power_of_2(K), 16)
     assert BK <= 256, "current kernel does not support head dimension being larger than 256."
     # H100
```

fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ def chunk_dplr_fwd_h(
     chunk_size: int = 64
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *kg.shape, u.shape[-1]
-    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+    BT = chunk_size

     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     # N: the actual number of sequences in the batch with either equal or variable lengths
```

fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -301,7 +301,7 @@ def chunk_dplr_bwd_dv(
     chunk_size: int = 64
 ) -> torch.Tensor:
     B, T, H, K, V = *kg.shape, do.shape[-1]
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size

     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
@@ -344,7 +344,7 @@ def chunk_dplr_bwd_o(

     B, T, H, K, V = *w.shape, v.shape[-1]

-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

@@ -398,7 +398,7 @@ def chunk_dplr_bwd_dAu(
     chunk_size: int = 64
 ) -> torch.Tensor:
     B, T, H, V = v.shape
-    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+    BT = chunk_size
     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
```
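A recurring pair of lines across these hunks sizes the launch grid: for fixed-length batches the number of chunks is a ceiling division, while for variable-length batches (`cu_seqlens` set) the library's `prepare_chunk_indices` enumerates chunks per sequence instead. The fixed-length half in isolation:

```python
import triton

T, BT = 1000, 64
NT = triton.cdiv(T, BT)  # ceil(1000 / 64) -> 16 chunks, the last one partial
print(NT)  # 16
```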
