diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
index 0cf0c7ae9..fa87630d9 100644
--- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
+++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
@@ -93,6 +93,10 @@ def __init__(self):
         self.cuda_graph_cur_batch_size = None
         self.is_cuda_graph = False
         self.managed_total_tensor_bytes = 0
+        # Guard flag to prevent GPU memory leaks caused by misuse.
+        # If the caller has not properly invoked cache_env_in and cache_env_out
+        # but still calls alloc_tensor, fall back to a plain torch.empty allocation.
+        self.cache_env_ok = False
 
     def cache_env_in(
         self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
@@ -107,6 +111,7 @@ def cache_env_in(
             assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
             self.cuda_graph_cur_batch_size = cur_batch_size
             assert cur_batch_size != 0
+        self.cache_env_ok = True
         return
 
     def cache_env_out(self):
@@ -115,6 +120,7 @@ def cache_env_out(self):
         self.free_shape_dtype_to_bufs.clear()
         self.calcu_shape_cache.clear()
         self.changed_ptr.clear()
+        self.cache_env_ok = False
         return
 
     def alloc_tensor(
@@ -129,6 +135,11 @@ def alloc_tensor(
         # shape type conversion
         if isinstance(shape, list):
             shape = torch.Size(shape)
+
+        # The cache manager is not inside a valid cache env; use the fallback path.
+        if not self.cache_env_ok:
+            return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
+
         # Under cuda graph capture, the cuda graph manager takes over.
         if self.is_cuda_graph:
             return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
index 23a63b0ce..a9e162e3f 100644
--- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
+++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -22,6 +22,8 @@
 from transformers.utils import TensorType
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
@@ -123,7 +125,7 @@ def apply_rotary_pos_emb_vision(
     return q_embed, k_embed
 
 
-class Qwen2_5_VLVisionAttention(nn.Module):
+class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -148,80 +150,15 @@ def forward(
             cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
-        attention_mask = torch.full(
-            [1, seq_length, seq_length],
-            torch.finfo(q.dtype).min,
-            device=q.device,
-            dtype=q.dtype,
-        )
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = 0
-
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
-        attn_output = torch.matmul(attn_weights, v)
-        attn_output = attn_output.transpose(0, 1)
-        attn_output = attn_output.reshape(seq_length, -1)
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
-class Qwen2_5_VLVisionSdpaAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
-
-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[
-                ...,
-                cu_seqlens[i - 1] : cu_seqlens[i],
-                cu_seqlens[i - 1] : cu_seqlens[i],
-            ] = True
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
+        flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
 
 
-QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
-    "eager": Qwen2_5_VLVisionAttention,
-    # "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
-    "sdpa": Qwen2_5_VLVisionSdpaAttention,
-}
-
-
 class Qwen2_5_VLVisionBlock(nn.Module):
     def __init__(
         self,
@@ -229,13 +166,12 @@ def __init__(
         intermediate_size,
         num_heads,
         hidden_act,
-        attn_implementation: str = "eager",
     ) -> None:
         super().__init__()
         self.norm1 = Qwen2RMSNorm(hidden_size, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(hidden_size, eps=1e-6)
 
-        self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](hidden_size, num_heads=num_heads)
+        self.attn = Qwen2_5_VLVisionFlashAttention(hidden_size, num_heads=num_heads)
         self.mlp = Qwen2_5_VLMLP(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
@@ -312,8 +248,6 @@ def __init__(
 
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
 
-        self.attn_implementation = "eager"
-
         self.patch_embed = PatchEmbed(
             patch_size=self.patch_size,
             temporal_patch_size=self.temporal_patch_size,
@@ -331,7 +265,6 @@ def __init__(
                     self.intermediate_size,
                     self.num_heads,
                     self.hidden_act,
-                    self.attn_implementation,
                 )
                 for _ in range(self.depth)
             ]
diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py
index 5c1024201..4a9012518 100644
--- a/lightllm/models/qwen2_vl/qwen2_visual.py
+++ b/lightllm/models/qwen2_vl/qwen2_visual.py
@@ -43,7 +43,8 @@
 from transformers.utils import TensorType
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
-
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from transformers.utils import is_flash_attn_2_available
 
 
@@ -210,7 +211,7 @@ def forward(
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
-class VisionFlashAttention2(nn.Module):
+class VisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -222,63 +223,31 @@ def forward(
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
+
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
-            seq_length, -1
-        )
-        attn_output = self.proj(attn_output)
-        return attn_output
-
-
-# adapted from
-# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
-class VisionSdpaAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 16) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=True)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
 
-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
 
 
-QWEN2_VL_VISION_ATTENTION_CLASSES = {
-    "eager": VisionAttention,
-    # "flash_attention_2": VisionFlashAttention2,
-    "sdpa": VisionSdpaAttention,
-}
-
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
 class Qwen2VLVisionBlock(nn.Module):
-    def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act, attn_implementation: str = "eager") -> None:
+    def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None:
         super().__init__()
         self.norm1 = LayerNorm(embed_dim, eps=1e-6)
         self.norm2 = LayerNorm(embed_dim, eps=1e-6)
         mlp_hidden_dim = int(embed_dim * mlp_ratio)
-        self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](embed_dim, num_heads=num_heads)
+        self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads)
         self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)
 
     def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
@@ -318,8 +287,6 @@ def __init__(
         self.spatial_merge_size = spatial_merge_size
         self.temporal_patch_size = temporal_patch_size
 
-        self.attn_implementation = "eager"
-
         self.patch_embed = PatchEmbed(
             patch_size=self.patch_size,
             temporal_patch_size=self.temporal_patch_size,
@@ -332,9 +299,7 @@ def __init__(
 
         self.blocks = nn.ModuleList(
             [
-                Qwen2VLVisionBlock(
-                    self.embed_dim, self.mlp_ratio, self.num_heads, self.hidden_act, self.attn_implementation
-                )
+                Qwen2VLVisionBlock(self.embed_dim, self.mlp_ratio, self.num_heads, self.hidden_act)
                 for _ in range(self.depth)
             ]
         )
diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py
index 14ba9cfed..ad8e5b6ca 100644
--- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py
@@ -1,13 +1,8 @@
 import torch
-import torch.functional as F
 import torch.distributed as dist
-import numpy as np
-from typing import Tuple
-from functools import partial
-import triton
+
 from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
@@ -103,9 +98,13 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tens
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
         out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
-        batch_size = q.shape[0]
-        seq_len = q.shape[1]
-        flash_attention_fwd(q, k, v, out)
+        batch_size, seq_len, head_num, head_dim = q.shape
+        total_len = batch_size * seq_len
+        reshape = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, out = map(reshape, (q, k, v, out))
+        cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len
+        max_seqlen = seq_len
+        flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)
         return out.reshape(batch_size, seq_len, -1)
 
     def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py
index 1cb9c9cae..34e7ed6be 100644
--- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py
+++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py
@@ -1,9 +1,10 @@
 import torch
-
 import triton
 import triton.language as tl
 import math
+import time
 import torch.nn.functional as F
+from typing import Optional, Tuple
 from lightllm.utils.device_utils import is_hopper
 
 if triton.__version__ >= "2.1.0":
@@ -14,24 +15,21 @@ def _fwd_kernel(
         K,
         V,
         sm_scale,
-        seq_len,
         Out,
-        q_stride_b,
         q_stride_s,
         q_stride_h,
         q_stride_d,
-        k_stride_b,
         k_stride_s,
         k_stride_h,
         k_stride_d,
-        v_stride_b,
         v_stride_s,
         v_stride_h,
         v_stride_d,
-        o_stride_b,
         o_stride_s,
         o_stride_h,
         o_stride_d,
+        head_dim_act,
+        cu_seqlens,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -40,12 +38,18 @@
         cur_head = tl.program_id(1)
         start_m = tl.program_id(0)
 
+        seq_start = tl.load(cu_seqlens + cur_batch).to(tl.int32)
+        seq_end = tl.load(cu_seqlens + cur_batch + 1).to(tl.int32)
+        seq_len = seq_end - seq_start
+
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
         offs_d = tl.arange(0, BLOCK_DMODEL)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        off_q = cur_batch * q_stride_b + cur_head * q_stride_h + offs_m[:, None] * q_stride_s + offs_d[None, :]
-        q = tl.load(Q + off_q, mask=offs_m[:, None] < seq_len, other=0.0)
+
+        mask_d = offs_d < head_dim_act
+        off_q = cur_head * q_stride_h + (seq_start + offs_m[:, None]) * q_stride_s + offs_d[None, :] * q_stride_d
+        q = tl.load(Q + off_q, mask=(offs_m[:, None] < seq_len) & mask_d[None, :], other=0.0)
         # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -55,12 +59,11 @@
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             off_k = (
-                cur_batch * k_stride_b
-                + (start_n + offs_n[None, :]) * k_stride_s
+                (seq_start + start_n + offs_n[None, :]) * k_stride_s
                 + cur_head * k_stride_h
-                + offs_d[:, None]
+                + offs_d[:, None] * k_stride_d
             )
-            k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < seq_len, other=0.0)
+            k = tl.load(K + off_k, mask=((start_n + offs_n[None, :]) < seq_len) & mask_d[:, None], other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
@@ -77,12 +80,11 @@
 
             # update acc
             off_v = (
-                cur_batch * v_stride_b
-                + (start_n + offs_n[:, None]) * v_stride_s
+                (seq_start + start_n + offs_n[:, None]) * v_stride_s
                 + cur_head * v_stride_h
-                + offs_d[None, :]
+                + offs_d[None, :] * v_stride_d
             )
-            v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < seq_len, other=0.0)
+            v = tl.load(V + off_v, mask=((start_n + offs_n[:, None]) < seq_len) & mask_d[None, :], other=0.0)
             p = p.to(v.dtype)
             acc += tl.dot(p, v)
             # update m_i and l_i
@@ -93,9 +95,9 @@
         o_scale = tl.exp(m_i - l_i)
         acc = acc * o_scale[:, None]
         # initialize pointers to output
-        off_o = cur_batch * o_stride_b + offs_m[:, None] * o_stride_s + cur_head * o_stride_h + offs_d[None, :]
+        off_o = (seq_start + offs_m[:, None]) * o_stride_s + cur_head * o_stride_h + offs_d[None, :] * o_stride_d
         out_ptrs = Out + off_o
-        tl.store(out_ptrs, acc, mask=offs_m[:, None] < seq_len)
+        tl.store(out_ptrs, acc, mask=(offs_m[:, None] < seq_len) & mask_d[None, :])
         return
 
 @torch.no_grad()
@@ -104,40 +106,41 @@ def _flash_attention_triton_fwd(
     k,
    v,
     o,
+    cu_seqlens,  # q k v cu_seqlens,
+    max_seqlen,
 ):
     BLOCK = 64
     # shape constraints
-    batch_size, seq_len, head_num, head_dim = q.shape
+    assert q.ndim == k.ndim == v.ndim == o.ndim == 3, "q, k, v, o must be 3D tensors"
+    _, head_num, head_dim = q.shape
+    batch_size = cu_seqlens.numel() - 1
     sm_scale = 1.0 / (head_dim ** 0.5)  # compute the scale factor
-    # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
-    grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size)  # batch, head,
+    d_pad = triton.next_power_of_2(head_dim)
+    grid = (triton.cdiv(max_seqlen, BLOCK), head_num, batch_size)  # (m_blocks, head_num, batch_size)
     num_warps = 4
     _fwd_kernel[grid](
         q,
         k,
         v,
         sm_scale,
-        seq_len,
         o,
         q.stride(0),
         q.stride(1),
         q.stride(2),
-        q.stride(3),
         k.stride(0),
         k.stride(1),
         k.stride(2),
-        k.stride(3),
         v.stride(0),
         v.stride(1),
         v.stride(2),
-        v.stride(3),
         o.stride(0),
         o.stride(1),
         o.stride(2),
-        o.stride(3),
+        head_dim,
+        cu_seqlens,
         BLOCK_M=BLOCK,
-        BLOCK_DMODEL=head_dim,
+        BLOCK_DMODEL=d_pad,
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=2,
@@ -158,6 +161,8 @@
     k,
     v,
     o,
+    cu_seqlens,
+    max_seqlen,
 ):
     head_dim = q.shape[-1]
     softmax_scale = head_dim ** -0.5
@@ -168,13 +173,13 @@
         None,
         None,  # k_new, v_new
         o,  # out
-        None,
-        None,
+        cu_seqlens,
+        cu_seqlens,
         None,  # cu_seqlens_q/k/k_new
         None,
         None,  # seqused_q/k
-        None,
-        None,  # max_seqlen_q/k
+        max_seqlen,
+        max_seqlen,  # max_seqlen_q/k
         None,
         None,
         None,  # page_table, kv_batch_idx, leftpad_k,
@@ -198,53 +203,12 @@
     _flash_attn_v3_available = False
 
 
-def flash_attention_fwd(q, k, v, o):
+def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
     """
     Unified Flash Attention interface. If _flash_attn_forward is available,
     use flash_attention_v3_fwd; otherwise fall back to the Triton version.
     """
     if _flash_attn_v3_available and is_hopper():
-        flash_attention_v3_fwd(q, k, v, o)
+        flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
-        _flash_attention_triton_fwd(q, k, v, o)
-
-
-def torch_att(q, k, v):
-    head_dim = q.shape[-1]
-    q = q.transpose(1, 2)
-    k = k.transpose(1, 2)
-    v = v.transpose(1, 2)
-    scale = head_dim ** -0.5
-    attn = (q * scale) @ k.transpose(-2, -1)
-    attn = attn.softmax(dim=-1)
-    out = attn @ v
-    out = out.transpose(1, 2).contiguous()
-    return out
-
-
-def test():
-    import torch
-    import numpy as np
-
-    B, L, H, D = 4, 1025, 7, 128
-    dtype = torch.float16
-    q = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    torch_out = torch_att(q, k, v)
-    import time
-
-    torch.cuda.synchronize()
-    a = time.time()
-    for i in range(100):
-        flash_attention_fwd(q, k, v, o)
-        # o = torch_att(q, k, v)
-    torch.cuda.synchronize()
-    b = time.time()
-    # print(o.shape, torch_out.shape)
-    print((b - a) / 100 * 1000)
-
-    print("max ", torch.max(torch.abs(torch_out - o)))
-    print("mean ", torch.mean(torch.abs(torch_out - o)))
-    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
+        _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)
diff --git a/unit_tests/models/vit/test_flash_attention_forward.py b/unit_tests/models/vit/test_flash_attention_forward.py
new file mode 100644
index 000000000..53b70e683
--- /dev/null
+++ b/unit_tests/models/vit/test_flash_attention_forward.py
@@ -0,0 +1,63 @@
+import torch
+import math
+import time
+import pytest
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+
+
+def reference_attention_varlen(q, k, v, cu):
+    """
+    q, k, v   : (total_len, n_head, D)
+    cu_seqlen : prefix sums (batch+1,)
+    """
+    total, n_head, d = q.shape
+    out = torch.empty_like(q)
+    scale = 1.0 / math.sqrt(d)
+
+    for b in range(cu.numel() - 1):
+        s, e = cu[b].item(), cu[b + 1].item()
+        q_b, k_b, v_b = q[s:e], k[s:e], v[s:e]  # (seq, head, D)
+
+        q_hsd = q_b.permute(1, 0, 2)  # (head, seq, D)
+        k_hds = k_b.permute(1, 2, 0)  # (head, D, seq)
+        v_hsd = v_b.permute(1, 0, 2)  # (head, seq, D)
+
+        scores = torch.matmul(q_hsd, k_hds) * scale  # (head, seq, seq)
+        probs = torch.softmax(scores.float(), dim=-1)
+
+        out_hsd = torch.matmul(probs, v_hsd.float())  # (head, seq, D)
+        out[s:e] = out_hsd.permute(1, 0, 2).to(q.dtype)  # back to (seq, head, D)
+
+    return out
+
+
+@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-2), (torch.bfloat16, 2e-2)])
+def test_varlen(dtype, atol, batch=4, heads=8, d=80, device="cuda:0"):
+    torch.manual_seed(0)
+    lengths = torch.randint(1, 257, (batch,))
+    max_len = int(lengths.max().item())
+
+    cu = torch.zeros(batch + 1, dtype=torch.int32, device=device)
+    cu[1:] = torch.cumsum(lengths, 0)
+    tot = int(cu[-1])
+
+    q = torch.randn(tot, heads, d, dtype=dtype, device=device)
+    k = torch.randn_like(q)
+    v = torch.randn_like(q)
+    out_tri = torch.randn_like(q)
+    flash_attention_fwd(q, k, v, out_tri, cu, max_len)
+    a = time.time()
+    for _ in range(100):
+        flash_attention_fwd(q, k, v, out_tri, cu, max_len)
+    b = time.time()
+    print(f"flash_attention_fwd time: {(b - a) / 100 * 1000:.2f} ms")
+    out_ref = reference_attention_varlen(q, k, v, cu)
+
+    max_err = (out_ref - out_tri).abs().max().item()
+    mean_err = (out_ref - out_tri).abs().mean().item()
+    print(f"{dtype}: max {max_err:.6f}, mean {mean_err:.6f}")
+    torch.testing.assert_close(out_tri, out_ref, atol=atol, rtol=0)
+
+
+if __name__ == "__main__":
+    pytest.main()
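
Usage note (added for clarity, not part of the patch): after this change the unified `flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen)` entry point takes variable-length sequences packed along the first dimension, in the same `(total_len, num_heads, head_dim)` layout exercised by the new unit test. A minimal sketch with illustrative sizes:

```python
import torch
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd

# Two sequences of length 3 and 5, packed along the first dimension.
lengths = torch.tensor([3, 5])
cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(lengths, 0)  # prefix sums: [0, 3, 8]

total_len, num_heads, head_dim = int(cu_seqlens[-1].item()), 8, 80
q = torch.randn(total_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

# The kernel writes the attention result into `out` in the same packed layout.
flash_attention_fwd(q, k, v, out, cu_seqlens, int(lengths.max().item()))
```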
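Similarly, a small sketch of the `cache_tensor_manager` behavior added here (shapes and dtypes are arbitrary illustrations): `alloc_tensor` calls made outside a `cache_env_in()` / `cache_env_out()` bracket now degrade to a plain `torch.empty`, so accidental use no longer leaks cached GPU memory, while bracketed calls keep going through the caching path as before.

```python
import torch
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

# Outside the bracket: cache_env_ok is False, so this is just torch.empty.
t0 = g_cache_manager.alloc_tensor(torch.Size([4, 16]), torch.float16, device="cuda")

# Inside the bracket: allocations are managed by the cache as before.
g_cache_manager.cache_env_in()
t1 = g_cache_manager.alloc_tensor(torch.Size([4, 16]), torch.float16, device="cuda")
g_cache_manager.cache_env_out()
```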