55from typing import Optional
66
77import torch
8- import triton
98
109from fla .ops .generalized_delta_rule .dplr .chunk_A_bwd import chunk_dplr_bwd_dqk_intra
1110from fla .ops .generalized_delta_rule .dplr .chunk_A_fwd import chunk_dplr_fwd_intra
@@ -32,9 +31,7 @@ def chunk_dplr_fwd(
3231 cu_seqlens : Optional [torch .LongTensor ] = None ,
3332 chunk_size : int = 64
3433):
35- T = q .shape [1 ]
36- BT = min (chunk_size , max (triton .next_power_of_2 (T ), 16 ))
37- gi , ge = chunk_rwkv6_fwd_cumsum (gk , BT , cu_seqlens = cu_seqlens )
34+ gi , ge = chunk_rwkv6_fwd_cumsum (gk , chunk_size , cu_seqlens = cu_seqlens )
3835
3936 A_ab , A_qk , A_ak , A_qb , qg , kg , ag , bg = chunk_dplr_fwd_intra (
4037 q = q ,
@@ -45,7 +42,7 @@ def chunk_dplr_fwd(
4542 ge = ge ,
4643 scale = scale ,
4744 cu_seqlens = cu_seqlens ,
48- chunk_size = BT ,
45+ chunk_size = chunk_size ,
4946 )
5047 del ge
5148
@@ -57,7 +54,7 @@ def chunk_dplr_fwd(
5754 A_ak = A_ak ,
5855 v = v ,
5956 cu_seqlens = cu_seqlens ,
60- chunk_size = BT
57+ chunk_size = chunk_size
6158 )
6259 del A_ab , A_ak
6360 h , v_new , final_state = chunk_dplr_fwd_h (
@@ -70,7 +67,7 @@ def chunk_dplr_fwd(
7067 initial_state = initial_state ,
7168 output_final_state = output_final_state ,
7269 cu_seqlens = cu_seqlens ,
73- chunk_size = BT
70+ chunk_size = chunk_size
7471 )
7572 del u , kg , bg , gi
7673
@@ -82,7 +79,7 @@ def chunk_dplr_fwd(
8279 A_qb = A_qb ,
8380 h = h ,
8481 cu_seqlens = cu_seqlens ,
85- chunk_size = BT
82+ chunk_size = chunk_size
8683 )
8784 del v_new , h , A_qk , A_qb
8885
@@ -136,12 +133,12 @@ def backward(
136133 dht : torch .Tensor
137134 ):
138135 q , k , v , a , b , gk , initial_state = ctx .saved_tensors
139- BT = ctx .chunk_size
136+ chunk_size = ctx .chunk_size
140137 cu_seqlens = ctx .cu_seqlens
141138 scale = ctx .scale
142139
143140 # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted *******
144- gi , ge = chunk_rwkv6_fwd_cumsum (gk , BT , cu_seqlens = cu_seqlens )
141+ gi , ge = chunk_rwkv6_fwd_cumsum (gk , chunk_size , cu_seqlens = cu_seqlens )
145142
146143 A_ab , A_qk , A_ak , A_qb , qg , kg , ag , bg = chunk_dplr_fwd_intra (
147144 q = q ,
@@ -152,15 +149,15 @@ def backward(
152149 ge = ge ,
153150 scale = scale ,
154151 cu_seqlens = cu_seqlens ,
155- chunk_size = BT ,
152+ chunk_size = chunk_size ,
156153 )
157154 w , u , A_ab_inv = prepare_wy_repr_fwd (
158155 ag = ag ,
159156 A_ab = A_ab ,
160157 A_ak = A_ak ,
161158 v = v ,
162159 cu_seqlens = cu_seqlens ,
163- chunk_size = BT
160+ chunk_size = chunk_size
164161 )
165162 del A_ab
166163 h , v_new , _ = chunk_dplr_fwd_h (
@@ -172,7 +169,7 @@ def backward(
172169 gk = gi ,
173170 initial_state = initial_state ,
174171 cu_seqlens = cu_seqlens ,
175- chunk_size = BT
172+ chunk_size = chunk_size
176173 )
177174 del u
178175 # ******* end of recomputation *******
@@ -186,7 +183,7 @@ def backward(
186183 A_qb = A_qb ,
187184 scale = scale ,
188185 cu_seqlens = cu_seqlens ,
189- chunk_size = BT
186+ chunk_size = chunk_size
190187 )
191188
192189 dh , dh0 , dv_new = chunk_dplr_bwd_dhu (
@@ -199,7 +196,7 @@ def backward(
199196 do = do ,
200197 dv = dv_new_intra ,
201198 cu_seqlens = cu_seqlens ,
202- chunk_size = BT
199+ chunk_size = chunk_size
203200 )
204201
205202 dv = chunk_dplr_bwd_dv (
@@ -208,7 +205,7 @@ def backward(
208205 do = do ,
209206 dh = dh ,
210207 cu_seqlens = cu_seqlens ,
211- chunk_size = BT
208+ chunk_size = chunk_size
212209 )
213210 del A_qk
214211
@@ -224,7 +221,7 @@ def backward(
224221 w = w ,
225222 gk = gi ,
226223 cu_seqlens = cu_seqlens ,
227- chunk_size = BT ,
224+ chunk_size = chunk_size ,
228225 scale = scale ,
229226 )
230227 del v_new
@@ -238,7 +235,7 @@ def backward(
238235 du = dv_new ,
239236 dv0 = dv ,
240237 cu_seqlens = cu_seqlens ,
241- chunk_size = BT
238+ chunk_size = chunk_size
242239 )
243240 del A_ak
244241
@@ -258,7 +255,7 @@ def backward(
258255 dkg = dkg ,
259256 dag = dag ,
260257 dbg = dbg ,
261- chunk_size = BT ,
258+ chunk_size = chunk_size ,
262259 scale = scale ,
263260 cu_seqlens = cu_seqlens ,
264261 )
0 commit comments