|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang |
| 4 | +# |
| 5 | +# This file contains code copied from the flash-linear-attention project. |
| 6 | +# The original source code was licensed under the MIT license and included |
| 7 | +# the following copyright notice: |
| 8 | +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang |
| 9 | +# ruff: noqa: E501 |
| 10 | +import warnings |
| 11 | +from typing import Optional |
| 12 | + |
| 13 | +import torch |
| 14 | +from einops import rearrange |
| 15 | + |
| 16 | +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h |
| 17 | +from .chunk_o import chunk_fwd_o |
| 18 | +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd |
| 19 | +from .cumsum import chunk_local_cumsum |
| 20 | +from .l2norm import l2norm_fwd |
| 21 | +from .solve_tril import solve_tril |
| 22 | +from .utils import SUPPRESS_LEVEL, input_guard |
| 23 | +from .wy_fast import recompute_w_u_fwd |
| 24 | + |
| 25 | + |
| 26 | +def chunk_gated_delta_rule_fwd(q: torch.Tensor, |
| 27 | + k: torch.Tensor, |
| 28 | + v: torch.Tensor, |
| 29 | + g: torch.Tensor, |
| 30 | + beta: torch.Tensor, |
| 31 | + scale: float, |
| 32 | + initial_state: torch.Tensor, |
| 33 | + output_final_state: bool, |
| 34 | + cu_seqlens: Optional[torch.LongTensor] = None): |
| 35 | + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) |
| 36 | + # obtain WY representation. u is actually the new v. |
| 37 | + A = chunk_scaled_dot_kkt_fwd(k=k, |
| 38 | + beta=beta, |
| 39 | + g_cumsum=g, |
| 40 | + cu_seqlens=cu_seqlens, |
| 41 | + output_dtype=torch.float32) |
| 42 | + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) |
| 43 | + w, u = recompute_w_u_fwd( |
| 44 | + k=k, |
| 45 | + v=v, |
| 46 | + beta=beta, |
| 47 | + A=A, |
| 48 | + g_cumsum=g, |
| 49 | + cu_seqlens=cu_seqlens, |
| 50 | + ) |
| 51 | + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( |
| 52 | + k=k, |
| 53 | + w=w, |
| 54 | + u=u, |
| 55 | + g=g, |
| 56 | + initial_state=initial_state, |
| 57 | + output_final_state=output_final_state, |
| 58 | + cu_seqlens=cu_seqlens, |
| 59 | + ) |
| 60 | + o = chunk_fwd_o( |
| 61 | + q=q, |
| 62 | + k=k, |
| 63 | + v=v_new, |
| 64 | + h=h, |
| 65 | + g=g, |
| 66 | + scale=scale, |
| 67 | + cu_seqlens=cu_seqlens, |
| 68 | + ) |
| 69 | + if SUPPRESS_LEVEL < 3: |
| 70 | + return g, o, A, final_state, None, None, None |
| 71 | + elif SUPPRESS_LEVEL >= 3: |
| 72 | + return g, o, A, final_state, w, h, v_new |
| 73 | + |
| 74 | + |
| 75 | +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): |
| 76 | + |
| 77 | + @staticmethod |
| 78 | + @input_guard |
| 79 | + @torch.amp.custom_fwd(device_type='cuda') |
| 80 | + def forward(ctx, |
| 81 | + q: torch.Tensor, |
| 82 | + k: torch.Tensor, |
| 83 | + v: torch.Tensor, |
| 84 | + g: torch.Tensor, |
| 85 | + beta: torch.Tensor, |
| 86 | + scale: float, |
| 87 | + initial_state: torch.Tensor, |
| 88 | + output_final_state: bool, |
| 89 | + cu_seqlens: Optional[torch.LongTensor] = None, |
| 90 | + use_qk_l2norm_in_kernel: bool = False): |
| 91 | + if use_qk_l2norm_in_kernel: |
| 92 | + q = l2norm_fwd(q) |
| 93 | + k = l2norm_fwd(k) |
| 94 | + |
| 95 | + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( |
| 96 | + q=q, |
| 97 | + k=k, |
| 98 | + v=v, |
| 99 | + g=g, |
| 100 | + beta=beta, |
| 101 | + scale=scale, |
| 102 | + initial_state=initial_state, |
| 103 | + output_final_state=output_final_state, |
| 104 | + cu_seqlens=cu_seqlens, |
| 105 | + ) |
| 106 | + ctx.scale = scale |
| 107 | + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel |
| 108 | + return o.to(q.dtype), final_state |
| 109 | + |
| 110 | + |
| 111 | +@torch.compiler.disable |
| 112 | +def chunk_gated_delta_rule(q: torch.Tensor, |
| 113 | + k: torch.Tensor, |
| 114 | + v: torch.Tensor, |
| 115 | + g: torch.Tensor, |
| 116 | + beta: torch.Tensor, |
| 117 | + scale: float = None, |
| 118 | + initial_state: torch.Tensor = None, |
| 119 | + output_final_state: bool = False, |
| 120 | + cu_seqlens: Optional[torch.LongTensor] = None, |
| 121 | + head_first: bool = False, |
| 122 | + use_qk_l2norm_in_kernel: bool = False): |
| 123 | + r""" |
| 124 | + Args: |
| 125 | + q (torch.Tensor): |
| 126 | + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. |
| 127 | + k (torch.Tensor): |
| 128 | + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. |
| 129 | + v (torch.Tensor): |
| 130 | + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. |
| 131 | + g (torch.Tensor): |
| 132 | + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. |
| 133 | + beta (torch.Tensor): |
| 134 | + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. |
| 135 | + scale (Optional[int]): |
| 136 | + Scale factor for the RetNet attention scores. |
| 137 | + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. |
| 138 | + initial_state (Optional[torch.Tensor]): |
| 139 | + Initial state of shape `[N, H, K, V]` for `N` input sequences. |
| 140 | + For equal-length input sequences, `N` equals the batch size `B`. |
| 141 | + Default: `None`. |
| 142 | + output_final_state (Optional[bool]): |
| 143 | + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. |
| 144 | + cu_seqlens (torch.LongTensor): |
| 145 | + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, |
| 146 | + consistent with the FlashAttention API. |
| 147 | + head_first (Optional[bool]): |
| 148 | + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. |
| 149 | + Default: `False`. |
| 150 | +
|
| 151 | + Returns: |
| 152 | + o (torch.Tensor): |
| 153 | + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. |
| 154 | + final_state (torch.Tensor): |
| 155 | + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. |
| 156 | +
|
| 157 | + Examples:: |
| 158 | + >>> import torch |
| 159 | + >>> import torch.nn.functional as F |
| 160 | + >>> from einops import rearrange |
| 161 | + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule |
| 162 | + # inputs with equal lengths |
| 163 | + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 |
| 164 | + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') |
| 165 | + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) |
| 166 | + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') |
| 167 | + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() |
| 168 | + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) |
| 169 | + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') |
| 170 | + >>> o, ht = chunk_gated_delta_rule( |
| 171 | + q, k, v, g, beta, |
| 172 | + initial_state=h0, |
| 173 | + output_final_state=True |
| 174 | + ) |
| 175 | + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required |
| 176 | + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) |
| 177 | + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected |
| 178 | + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) |
| 179 | + >>> o_var, ht_var = chunk_gated_delta_rule( |
| 180 | + q, k, v, g, beta, |
| 181 | + initial_state=h0, |
| 182 | + output_final_state=True, |
| 183 | + cu_seqlens=cu_seqlens |
| 184 | + ) |
| 185 | + """ |
| 186 | + assert q.dtype == k.dtype == v.dtype |
| 187 | + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." |
| 188 | + assert len( |
| 189 | + beta.shape |
| 190 | + ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." |
| 191 | + |
| 192 | + if head_first: |
| 193 | + raise DeprecationWarning( |
| 194 | + "head_first is deprecated and will be removed in a future version. " |
| 195 | + "Please use head_first=False for now instead.", |
| 196 | + stacklevel=2) |
| 197 | + q, k, v, beta, g = map( |
| 198 | + lambda x: rearrange(x, 'b h t ... -> b t h ...'), |
| 199 | + (q, k, v, beta, g)) |
| 200 | + if not head_first and q.shape[1] < q.shape[2]: |
| 201 | + warnings.warn( |
| 202 | + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " |
| 203 | + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " |
| 204 | + "when head_first=False was specified. " |
| 205 | + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", |
| 206 | + stacklevel=2) |
| 207 | + if cu_seqlens is not None: |
| 208 | + if q.shape[0] != 1: |
| 209 | + raise ValueError( |
| 210 | + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." |
| 211 | + f"Please flatten variable-length inputs before processing.") |
| 212 | + if initial_state is not None and initial_state.shape[0] != len( |
| 213 | + cu_seqlens) - 1: |
| 214 | + raise ValueError( |
| 215 | + f"The number of initial states is expected to be equal to the number of input sequences, " |
| 216 | + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." |
| 217 | + ) |
| 218 | + if scale is None: |
| 219 | + scale = k.shape[-1]**-0.5 |
| 220 | + o, final_state = ChunkGatedDeltaRuleFunction.apply( |
| 221 | + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, |
| 222 | + use_qk_l2norm_in_kernel) |
| 223 | + if head_first: |
| 224 | + o = rearrange(o, 'b t h ... -> b h t ...') |
| 225 | + return o, final_state |
0 commit comments