
Commit f74c46a

[Inference] Delete head_first warnings during generation
1 parent 6f232e0 commit f74c46a

7 files changed: +38 −184 lines changed


fla/ops/delta_rule/fused_recurrent.py

Lines changed: 5 additions & 26 deletions
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.language as tl
-from einops import rearrange

 from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
 from fla.utils import input_guard
@@ -448,19 +446,18 @@ def fused_recurrent_delta_rule(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
-    head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
-            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            values of shape `[B, T, H, V]`.
         beta (torch.Tensor):
-            betas of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`.
+            betas of shape `[B, T, H]`.
         scale (Optional[int]):
             Scale factor for the RetNet attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
@@ -473,13 +470,10 @@ def fused_recurrent_delta_rule(
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
-        head_first (Optional[bool]):
-            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
-            Default: `False`.

     Returns:
         o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
@@ -513,19 +507,6 @@ def fused_recurrent_delta_rule(
     >>> assert o.allclose(o_var.view(o.shape))
     >>> assert ht.allclose(ht_var)
     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v, beta = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta))
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -554,6 +535,4 @@ def fused_recurrent_delta_rule(
         cu_seqlens,
         use_qk_l2norm_in_kernel
     )
-    if head_first:
-        o = rearrange(o, 'b t h v -> b h t v')
     return o, final_state
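
Note: with `head_first` gone, `fused_recurrent_delta_rule` accepts only `[B, T, H, ...]` inputs, and the kernel no longer transposes or warns on suspicious shapes. A minimal migration sketch, assuming the public import path `fla.ops.delta_rule` and illustrative sizes:

import torch
from fla.ops.delta_rule import fused_recurrent_delta_rule  # assumed import path

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H, K, device='cuda')  # [B, T, H, K] is now the only layout
k = torch.randn(B, T, H, K, device='cuda')
v = torch.randn(B, T, H, V, device='cuda')
beta = torch.rand(B, T, H, device='cuda')   # betas of shape [B, T, H]

# Callers still holding head-first [B, H, T, ...] tensors must transpose
# them themselves, mirroring the removed compatibility path:
#   q = einops.rearrange(q_head_first, 'b h t ... -> b t h ...')

o, final_state = fused_recurrent_delta_rule(
    q=q, k=k, v=v, beta=beta, output_final_state=True,
)
assert o.shape == (B, T, H, V)            # outputs stay time-first
assert final_state.shape == (B, H, K, V)  # [N, H, K, V] with N = B here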

fla/ops/gated_delta_rule/fused_recurrent.py

Lines changed: 6 additions & 24 deletions
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.language as tl
-from einops import rearrange

 from fla.ops.utils.op import exp
 from fla.utils import input_guard
@@ -218,20 +216,19 @@ def fused_recurrent_gated_delta_rule(
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
     use_qk_l2norm_in_kernel: bool = False,
-    head_first: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
-            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            values of shape `[B, T, H, V]`.
         g (torch.Tensor):
-            g (decays) of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`.
+            g (decays) of shape `[B, T, H]`.
         beta (torch.Tensor):
-            betas of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`.
+            betas of shape `[B, T, H]`.
         scale (Optional[int]):
             Scale factor for the RetNet attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
@@ -247,7 +244,7 @@ def fused_recurrent_gated_delta_rule(

     Returns:
         o (torch.Tensor):
-            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
@@ -282,19 +279,6 @@ def fused_recurrent_gated_delta_rule(
     >>> assert o.allclose(o_var.view(o.shape))
     >>> assert ht.allclose(ht_var)
     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g))
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -324,6 +308,4 @@ def fused_recurrent_gated_delta_rule(
         cu_seqlens,
         use_qk_l2norm_in_kernel
     )
-    if head_first:
-        o = rearrange(o, 'b t h v -> b h t v')
     return o, final_state
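
The gated variant follows the same layout contract; its extra decay `g` is documented as living in log space. A hedged usage sketch (keyword names follow the docstring above; the import path is an assumption):

import torch
import torch.nn.functional as F
from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule  # assumed path

B, T, H, K, V = 2, 64, 4, 64, 64
q = torch.randn(B, T, H, K, device='cuda')
k = torch.randn(B, T, H, K, device='cuda')
v = torch.randn(B, T, H, V, device='cuda')
g = F.logsigmoid(torch.randn(B, T, H, device='cuda'))  # decays in log space (<= 0)
beta = torch.rand(B, T, H, device='cuda')

o, ht = fused_recurrent_gated_delta_rule(
    q=q, k=k, v=v, g=g, beta=beta, output_final_state=True,
)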

fla/ops/generalized_delta_rule/dplr/fused_recurrent.py

Lines changed: 6 additions & 27 deletions
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.language as tl
-from einops import rearrange

 from fla.ops.utils.op import exp
 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
@@ -212,24 +210,23 @@ def fused_recurrent_dplr_delta_rule(
     output_final_state: bool = False,
     reverse: bool = False,
     cu_seqlens: Optional[torch.Tensor] = None,
-    head_first: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

     Args:
         q (torch.Tensor):
-            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
-            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
-            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+            values of shape `[B, T, H, V]`.
         a (torch.Tensor):
-            a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            a of shape `[B, T, H, K]`.
         b (torch.Tensor):
-            b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            b of shape `[B, T, H, K]`.
         gk (torch.Tensor):
-            gk of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. decay term in log space!
+            gk of shape `[B, T, H, K]`. decay term in log space!
         scale (Optional[int]):
             Scale factor for the RetNet attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: 1.
@@ -244,23 +241,7 @@ def fused_recurrent_dplr_delta_rule(
         cu_seqlens (Optional[torch.Tensor]):
             Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
             consistent with the FlashAttention API.
-        head_first (Optional[bool]):
-            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
-            Default: `False`.
     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v, a, b, gk = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b, gk))
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -289,6 +270,4 @@ def fused_recurrent_dplr_delta_rule(
         reverse,
         cu_seqlens,
     )
-    if head_first:
-        o = rearrange(o, 'b t h ... -> b h t ...')
     return o, final_state
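
The `cu_seqlens` guard kept in this function (and its siblings) requires variable-length inputs to be packed into a single batch. A sketch of that calling convention for the DPLR kernel, with the same caveats about import path and illustrative sizes:

import torch
import torch.nn.functional as F
from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule  # assumed path

# Pack N = 3 variable-length sequences along time; cu_seqlens has shape [N + 1].
lengths = [35, 100, 64]
cu_seqlens = torch.tensor([0, 35, 135, 199], dtype=torch.long, device='cuda')

T, H, K, V = sum(lengths), 4, 64, 64         # 199 tokens in total
q = torch.randn(1, T, H, K, device='cuda')   # batch dim must be 1 when packed
k = torch.randn(1, T, H, K, device='cuda')
v = torch.randn(1, T, H, V, device='cuda')
a = torch.randn(1, T, H, K, device='cuda')
b = torch.randn(1, T, H, K, device='cuda')
gk = F.logsigmoid(torch.randn(1, T, H, K, device='cuda'))  # decay term in log space

o, ht = fused_recurrent_dplr_delta_rule(
    q=q, k=k, v=v, a=a, b=b, gk=gk,
    cu_seqlens=cu_seqlens, output_final_state=True,
)  # ht has shape [N, H, K, V]: one final state per packed sequence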

fla/ops/generalized_delta_rule/iplr/fused_recurrent.py

Lines changed: 6 additions & 24 deletions
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
-# Copyright (c) 2024-2025, Songlin Yang, Yu Zhang
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

-import warnings
 from typing import Optional, Tuple

 import torch
 import triton
 import triton.language as tl
-from einops import rearrange

 from fla.utils import input_guard

@@ -398,22 +396,21 @@ def fused_recurrent_iplr_delta_rule(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.Tensor] = None,
-    head_first: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.

     Args:
         q (torch.Tensor):
-            queries of shape `[B, H, T, K]`
+            queries of shape `[B, T, H, K]`
         k (torch.Tensor):
-            keys of shape `[B, H, T, K]`
+            keys of shape `[B, T, H, K]`
         v (torch.Tensor):
-            values of shape `[B, H, T, V]`
+            values of shape `[B, T, H, V]`
         a (torch.Tensor):
-            as of shape `[B, H, T, K]`
+            as of shape `[B, T, H, K]`
         b (torch.Tensor):
-            bs of shape `[B, H, T, K]`
+            bs of shape `[B, T, H, K]`
         scale (Optional[int]):
             Scale factor for the RetNet attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
@@ -426,19 +423,6 @@ def fused_recurrent_iplr_delta_rule(
             consistent with the FlashAttention API.

     """
-    if head_first:
-        raise DeprecationWarning(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
-        q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
-    if not head_first and q.shape[1] < q.shape[2]:
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
@@ -465,6 +449,4 @@ def fused_recurrent_iplr_delta_rule(
         output_final_state,
         cu_seqlens
     )
-    if head_first:
-        o = rearrange(o, 'b t h ... -> b h t ...')
     return o, final_state
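
For reference, the docstring recurrence (reading `S_t` on the right-hand side as `S_{t-1}`) admits a naive per-step implementation that is handy for sanity-checking the fused kernel on small inputs. This is one plausible reading of the formula, not code from the repo; the state is held as `[B, H, V, K]` and `scale` defaults to `1/sqrt(K)` as documented:

import torch

def naive_iplr_delta_rule(q, k, v, a, b, scale=None):
    # S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T, then o_t = S_t (q_t * scale)
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    S = q.new_zeros(B, H, V, K)                    # per-head state
    o = q.new_zeros(B, T, H, V)
    I = torch.eye(K, dtype=q.dtype, device=q.device)
    for t in range(T):
        ab = torch.einsum('bhi,bhj->bhij', a[:, t], b[:, t])  # a_t b_t^T
        vk = torch.einsum('bhv,bhk->bhvk', v[:, t], k[:, t])  # v_t k_t^T
        S = S @ (I + ab) + vk
        o[:, t] = torch.einsum('bhvk,bhk->bhv', S, q[:, t] * scale)
    return o

Comparing `naive_iplr_delta_rule(q, k, v, a, b)` against the fused kernel's output (e.g. with `torch.allclose` at a loose tolerance) is a quick correctness check after layout migrations like this one.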
