15 changes: 13 additions & 2 deletions fla/ops/linear_attn/fused_chunk.py
@@ -276,9 +276,10 @@ def fused_chunk_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    z_state: torch.Tensor = None,
     output_final_state: bool = False,
     normalize: bool = True,
-    head_first: bool = True
+    head_first: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
@@ -293,6 +294,8 @@ def fused_chunk_linear_attn(
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
         initial_state (Optional[torch.Tensor]):
             Initial state of shape `[B, H, K, V]`. Default: `None`.
+        z_state (Optional[torch.Tensor]):
+            Z state of shape `[B, H, 1, K]`. This is only needed when normalization is enabled. Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
         normalize (bool):
@@ -311,8 +314,16 @@
     if not head_first:
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
     o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
+
     if normalize:
-        o = normalize_output(q * scale, k, o)
+        if z_state is None:
+            k_shape = list(k.shape)
+            k_shape[-2] = 1
+            z_state = k.new_zeros(k_shape)
+        o, z_state = normalize_output(q * scale, k, o, z_state)
     if not head_first:
         o = o.transpose(1, 2)
+
+    if normalize:
+        return o, (final_state, z_state)
     return o, final_state
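The practical effect of this change is that a long sequence can be processed in segments while keeping the output normalization consistent: when `normalize=True`, the second return value becomes the pair `(final_state, z_state)`, and both tensors can be fed back into the next call. A minimal sketch of that pattern (the shapes, the two-segment split, and the import path are illustrative assumptions, not part of this PR):

```python
import torch
from fla.ops.linear_attn import fused_chunk_linear_attn  # assumed public entry point

B, H, T, K, V = 2, 4, 2048, 64, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')

state, z_state = None, None
outs = []
# Process the sequence in two segments, threading both the [B, H, K, V] state
# and the [B, H, 1, K] z_state through successive calls.
for q_i, k_i, v_i in zip(q.chunk(2, dim=2), k.chunk(2, dim=2), v.chunk(2, dim=2)):
    o_i, (state, z_state) = fused_chunk_linear_attn(
        q_i, k_i, v_i,
        initial_state=state,
        z_state=z_state,
        output_final_state=True,
        normalize=True,
    )
    outs.append(o_i)
o = torch.cat(outs, dim=2)  # intended to match one normalized call over the full sequence
```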
41 changes: 39 additions & 2 deletions fla/ops/linear_attn/fused_recurrent.py
@@ -235,17 +235,54 @@ def fused_recurrent_linear_attn(
     v: torch.Tensor,
     scale: Optional[float] = None,
     initial_state: torch.Tensor = None,
+    z_state: torch.Tensor = None,
     output_final_state: bool = False,
-    normalize: bool = False,
+    normalize: bool = True,
     head_first: bool = True
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
+        k (torch.Tensor):
+            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
+        v (torch.Tensor):
+            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
+        scale (Optional[int]):
+            Scale factor for linear attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[B, H, K, V]`. Default: `None`.
+        z_state (Optional[torch.Tensor]):
+            Z state of shape `[B, H, 1, K]`. This is only needed when normalization is enabled. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
+        normalize (bool):
+            Whether to normalize the output. Default: `True`.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format. Default: `True`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
+        final_state (torch.Tensor):
+            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
+    """
     if scale is None:
         scale = q.shape[-1] ** -0.5
     if not head_first:
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
     o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
+
     if normalize:
-        o = normalize_output(q * scale, k, o)
+        if z_state is None:
+            k_shape = list(k.shape)
+            k_shape[-2] = 1
+            z_state = k.new_zeros(k_shape)
+        o, z_state = normalize_output(q * scale, k, o, z_state)
     if not head_first:
         o = o.transpose(1, 2)
+
+    if normalize:
+        return o, (final_state, z_state)
     return o, final_state
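The same threading applies to the recurrent kernel, which is the variant typically used for token-by-token decoding. A hedged sketch of such a loop (batch and head sizes, the step count, and the import path are made up for illustration; when `normalize=True` the second return value is the pair `(final_state, z_state)`):

```python
import torch
from fla.ops.linear_attn import fused_recurrent_linear_attn  # assumed public entry point

B, H, K, V = 2, 4, 64, 64
state, z_state = None, None
for _ in range(16):  # 16 illustrative decoding steps, one token at a time
    q_t = torch.randn(B, H, 1, K, device='cuda')
    k_t = torch.randn(B, H, 1, K, device='cuda')
    v_t = torch.randn(B, H, 1, V, device='cuda')
    o_t, (state, z_state) = fused_recurrent_linear_attn(
        q_t, k_t, v_t,
        initial_state=state,
        z_state=z_state,
        output_final_state=True,
        normalize=True,
    )
    # o_t is the normalized output for the current token; state and z_state
    # now summarize everything decoded so far.
```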
7 changes: 5 additions & 2 deletions fla/ops/linear_attn/utils.py
@@ -1,10 +1,13 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

 import torch


 @torch.jit.script
-def normalize_output(q, k, o):
+def normalize_output(q, k, o, z_state):
     k = k.cumsum(-2)
+    k = k + z_state
     z = (q * k).sum(-1, keepdim=True)
-    return o / (z + 1e-10)
+    return o / (z + 1e-10), k[..., -1:, :]
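With the extra `z_state` argument, `normalize_output` both consumes the running key sum from previous calls and returns the updated one, so chunk-wise normalization composes across calls. A small self-contained check of that property, using a local copy of the function shown above (pure PyTorch, no Triton kernels involved; the tensor sizes are arbitrary):

```python
import torch

def normalize_output(q, k, o, z_state):
    # mirrors the updated fla/ops/linear_attn/utils.py
    k = k.cumsum(-2)
    k = k + z_state
    z = (q * k).sum(-1, keepdim=True)
    return o / (z + 1e-10), k[..., -1:, :]

B, H, T, K, V = 1, 2, 8, 4, 4
q, k, o = torch.randn(B, H, T, K), torch.randn(B, H, T, K), torch.randn(B, H, T, V)
z0 = k.new_zeros(B, H, 1, K)

# One pass over the full sequence...
o_full, _ = normalize_output(q, k, o, z0)

# ...matches two passes over halves when the returned z_state is carried forward.
o1, z1 = normalize_output(q[:, :, :4], k[:, :, :4], o[:, :, :4], z0)
o2, _ = normalize_output(q[:, :, 4:], k[:, :, 4:], o[:, :, 4:], z1)
torch.testing.assert_close(torch.cat([o1, o2], dim=2), o_full)
```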