11 changes: 11 additions & 0 deletions lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
@@ -93,6 +93,10 @@ def __init__(self):
self.cuda_graph_cur_batch_size = None
self.is_cuda_graph = False
self.managed_total_tensor_bytes = 0
# Guard flag to prevent misuse from leaking GPU memory.
# If the caller has not properly invoked cache_env_in and cache_env_out,
# calls to alloc_tensor fall back to a plain torch.empty allocation.
self.cache_env_ok = False

def cache_env_in(
self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
@@ -107,6 +111,7 @@ def cache_env_in(
assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
self.cuda_graph_cur_batch_size = cur_batch_size
assert cur_batch_size != 0
self.cache_env_ok = True
return

def cache_env_out(self):
@@ -115,6 +120,7 @@ def cache_env_out(self):
self.free_shape_dtype_to_bufs.clear()
self.calcu_shape_cache.clear()
self.changed_ptr.clear()
self.cache_env_ok = False
return

def alloc_tensor(
@@ -129,6 +135,11 @@ def alloc_tensor(
# convert shape to torch.Size
if isinstance(shape, list):
shape = torch.Size(shape)

# the cache manager has not been properly entered
if not self.cache_env_ok:
return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)

# under cuda graph capture, the cuda graph manager takes over allocation
if self.is_cuda_graph:
return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
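Note on the new `cache_env_ok` guard: `alloc_tensor` now degrades to a plain `torch.empty` allocation when it is called outside a `cache_env_in` / `cache_env_out` scope, instead of handing out pooled buffers that would never be reclaimed. The sketch below is illustrative only and not part of the diff; it assumes a CUDA device and that calling `cache_env_in()` with its default, non-cuda-graph arguments is valid in the caller's context.

```python
# Illustrative sketch (not part of the PR): how cache_env_ok changes
# alloc_tensor behavior. Assumes a CUDA device and that the default
# (non-cuda-graph) arguments to cache_env_in are acceptable here.
import torch
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

shape = (4, 16, 64)

# Outside a cache_env_in/cache_env_out scope, cache_env_ok is False, so this
# falls back to a plain torch.empty allocation instead of a pooled buffer.
t_plain = g_cache_manager.alloc_tensor(shape, torch.float16, device="cuda")

# Inside the scope, allocations are served (and reused) by the cache manager.
g_cache_manager.cache_env_in()
t_cached = g_cache_manager.alloc_tensor(shape, torch.float16, device="cuda")
# ... run layer inference that reuses pooled buffers ...
g_cache_manager.cache_env_out()  # clears bookkeeping and resets cache_env_ok
```

Falling back to a plain allocation trades a little buffer reuse for safety: misuse no longer leaks pooled memory.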
83 changes: 8 additions & 75 deletions lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -22,6 +22,8 @@
from transformers.utils import TensorType
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

# adapted from
# https://github.com/huggingface/transformers/blob/
@@ -123,7 +125,7 @@ def apply_rotary_pos_emb_vision(
return q_embed, k_embed


class Qwen2_5_VLVisionAttention(nn.Module):
class Qwen2_5_VLVisionFlashAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
@@ -148,94 +150,28 @@ def forward(
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

attention_mask = torch.full(
[1, seq_length, seq_length],
torch.finfo(q.dtype).min,
device=q.device,
dtype=q.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = 0

q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output


class Qwen2_5_VLVisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)

def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is None:
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
cu_seqlens = cu_seqlens.to(q.device, torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output


QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
"eager": Qwen2_5_VLVisionAttention,
# "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
"sdpa": Qwen2_5_VLVisionSdpaAttention,
}


class Qwen2_5_VLVisionBlock(nn.Module):
def __init__(
self,
hidden_size,
intermediate_size,
num_heads,
hidden_act,
attn_implementation: str = "eager",
) -> None:
super().__init__()
self.norm1 = Qwen2RMSNorm(hidden_size, eps=1e-6)
self.norm2 = Qwen2RMSNorm(hidden_size, eps=1e-6)

self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](hidden_size, num_heads=num_heads)
self.attn = Qwen2_5_VLVisionFlashAttention(hidden_size, num_heads=num_heads)
self.mlp = Qwen2_5_VLMLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
@@ -312,8 +248,6 @@ def __init__(

self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

self.attn_implementation = "eager"

self.patch_embed = PatchEmbed(
patch_size=self.patch_size,
temporal_patch_size=self.temporal_patch_size,
@@ -331,7 +265,6 @@ def __init__(
self.intermediate_size,
self.num_heads,
self.hidden_act,
self.attn_implementation,
)
for _ in range(self.depth)
]
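For reviewers comparing the removed eager/SDPA paths with the new `Qwen2_5_VLVisionFlashAttention`: both express the same per-image block-diagonal attention pattern. The old code materialized it as an explicit `[1, seq_length, seq_length]` mask; the new code passes the block boundaries to `flash_attention_fwd` via `cu_seqlens`. A small illustration with hypothetical token counts (not part of the diff):

```python
# Two images with 4 and 6 patch tokens packed into one sequence of length 10.
import torch

seq_lens = torch.tensor([4, 6])
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
# cu_seqlens == tensor([0, 4, 10], dtype=torch.int32)
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())  # 6

# The removed SDPA path built the equivalent boolean block-diagonal mask:
seq_length = int(cu_seqlens[-1])
mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True

# flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) restricts
# attention to the same blocks without building this O(seq_length^2) mask.
```

Skipping the dense mask is what makes the kernel path lighter on memory for long packed sequences.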
61 changes: 13 additions & 48 deletions lightllm/models/qwen2_vl/qwen2_visual.py
@@ -43,7 +43,8 @@
from transformers.utils import TensorType
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor

from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

from transformers.utils import is_flash_attn_2_available

@@ -210,7 +211,7 @@ def forward(

# adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
class VisionFlashAttention2(nn.Module):
class VisionFlashAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
@@ -222,63 +223,31 @@ def forward(
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
q = q.squeeze(0)
k = k.squeeze(0)

cu_seqlens = cu_seqlens.to(q.device, torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
seq_length, -1
)
attn_output = self.proj(attn_output)
return attn_output


# adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
class VisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
self.num_heads = num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)

def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)

attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output


QWEN2_VL_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention,
# "flash_attention_2": VisionFlashAttention2,
"sdpa": VisionSdpaAttention,
}

# adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
class Qwen2VLVisionBlock(nn.Module):
def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act, attn_implementation: str = "eager") -> None:
def __init__(self, embed_dim, mlp_ratio, num_heads, hidden_act) -> None:
super().__init__()
self.norm1 = LayerNorm(embed_dim, eps=1e-6)
self.norm2 = LayerNorm(embed_dim, eps=1e-6)
mlp_hidden_dim = int(embed_dim * mlp_ratio)

self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](embed_dim, num_heads=num_heads)
self.attn = VisionFlashAttention(embed_dim, num_heads=num_heads)
self.mlp = VisionMlp(dim=embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=hidden_act)

def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
@@ -318,8 +287,6 @@ def __init__(
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size

self.attn_implementation = "eager"

self.patch_embed = PatchEmbed(
patch_size=self.patch_size,
temporal_patch_size=self.temporal_patch_size,
@@ -332,9 +299,7 @@ def __init__(

self.blocks = nn.ModuleList(
[
Qwen2VLVisionBlock(
self.embed_dim, self.mlp_ratio, self.num_heads, self.hidden_act, self.attn_implementation
)
Qwen2VLVisionBlock(self.embed_dim, self.mlp_ratio, self.num_heads, self.hidden_act)
for _ in range(self.depth)
]
)
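A quick way to sanity-check the new `VisionFlashAttention` path is to compare the triton varlen kernel against a per-block SDPA reference on random data. The sketch below is hypothetical, not part of the PR; it assumes a CUDA device and the `(q, k, v, attn_output, cu_seqlens, max_seqlen)` signature used in this diff, with `q`/`k`/`v` shaped `[total_tokens, head_num, head_dim]`.

```python
# Parity sketch (illustrative, with assumed shapes and tolerances).
import torch
import torch.nn.functional as F
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd

torch.manual_seed(0)
head_num, head_dim = 4, 64
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32, device="cuda")
total, max_seqlen = 10, 6
q, k, v = (torch.randn(total, head_num, head_dim, device="cuda", dtype=torch.float16)
           for _ in range(3))

out = torch.empty_like(q)
flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)

# Reference: run SDPA independently on each image block and concatenate.
ref = torch.cat([
    F.scaled_dot_product_attention(
        q[s:e].transpose(0, 1), k[s:e].transpose(0, 1), v[s:e].transpose(0, 1)
    ).transpose(0, 1)
    for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())
])
print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
```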
17 changes: 8 additions & 9 deletions lightllm/models/vit/layer_infer/transformer_layer_infer.py
@@ -1,13 +1,8 @@
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple
from functools import partial
import triton


from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
@@ -103,9 +98,13 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tens

def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
batch_size = q.shape[0]
seq_len = q.shape[1]
flash_attention_fwd(q, k, v, out)
batch_size, seq_len, head_num, head_dim = q.shape
total_len = batch_size * seq_len
reshape = lambda t: t.view(total_len, head_num, head_dim)
q, k, v, out = map(reshape, (q, k, v, out))
cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len
max_seqlen = seq_len
flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)
return out.reshape(batch_size, seq_len, -1)

def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
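On the `_context_attention_kernel` change: the ViT path receives a batch of equal-length sequences in `[batch_size, seq_len, head_num, head_dim]` layout, so flattening with `.view` and building `cu_seqlens` as `arange(batch_size + 1) * seq_len` reproduces the packed layout that `flash_attention_fwd` consumes. A small worked example with made-up sizes (illustrative only, plain CPU torch):

```python
# Illustration of the cu_seqlens construction for equal-length sequences.
import torch

batch_size, seq_len = 3, 5
cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32) * seq_len
print(cu_seqlens)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)

head_num, head_dim = 2, 4
q = torch.randn(batch_size, seq_len, head_num, head_dim)
q_flat = q.view(batch_size * seq_len, head_num, head_dim)
# Tokens of each image stay contiguous after flattening, so the offsets in
# cu_seqlens mark the same per-image boundaries the batched layout implied.
assert torch.equal(q_flat[5:10], q[1])  # the second image's tokens
```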