Commit f3f29e6

Merge remote-tracking branch 'origin/main' into pd_master
2 parents e9d8758 + c2dbc8f

7 files changed: +154 -6 lines

.github/workflows/docker-publish.yml
Lines changed: 6 additions & 0 deletions

@@ -33,6 +33,12 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v3
 
+      - name: Set up swap space
+        if: runner.os == 'Linux'
+        uses: pierotofy/[email protected]
+        with:
+          swap-size-gb: 10
+
       # clean cache image
       - name: Clean up Docker space
         run: |

lightllm/models/llama/model.py
Lines changed: 49 additions & 0 deletions

@@ -88,6 +88,11 @@ def _init_custom(self):
             and self.config.get("rope_scaling", {}).get("rope_type", "base") == "llama3"
         ):
             self._init_to_get_llama3_rotary()
+        elif (
+            self.config.get("rope_scaling", None) is not None
+            and self.config.get("rope_scaling", {}).get("type", "base") == "mrope"
+        ):
+            self._init_to_get_mrope_rotary()
         else:
             self._init_to_get_rotary()
         return

@@ -332,3 +337,47 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
         self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
+
+    def _init_to_get_mrope_rotary(self, default_base=10000):
+        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
+        if self.config.get("rope_scaling", {}) is None:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+        base = self.config.get("rope_theta", float(default_base))
+
+        if "max_sequence_length" in self.config:
+            max_seq_len = self.config["max_sequence_length"]
+        else:
+            max_position_embeddings = self.config.get(
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
+            )
+            max_seq_len = max_position_embeddings * rope_scaling_factor
+
+        # NTK
+        try:
+            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+            assert ntk_alpha >= 1
+            if ntk_alpha > 1:
+                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+            max_seq_len *= ntk_alpha
+            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
+        except:
+            pass
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
+
+        t = (
+            torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
+            / rope_scaling_factor
+        )
+        freqs = torch.outer(t, inv_freq).unsqueeze(0).expand(3, -1, -1)
+        freqs = torch.cat((freqs, freqs), dim=-1)
+
+        self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+        self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+
+        return
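For reference, a minimal standalone sketch of what _init_to_get_mrope_rotary precomputes, using hypothetical Qwen2-VL-like values (head_dim 128, rope_theta 1e6, 32K positions, which are assumptions, not values from this diff): the inverse frequencies, the NTK base change that only kicks in when LIGHTLLM_NTK_ALPHA > 1, and a cos/sin table replicated across a leading axis of size 3 so it can later be indexed with per-axis (temporal/height/width) position ids.

import torch

head_dim, base, max_seq_len = 128, 1000000.0, 32768   # hypothetical config values
ntk_alpha = 1.0                                        # LIGHTLLM_NTK_ALPHA default
if ntk_alpha > 1:
    base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))  # NTK base-change formula

inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq).unsqueeze(0).expand(3, -1, -1)  # one copy per t/h/w axis
freqs = torch.cat((freqs, freqs), dim=-1)                        # duplicate for both rotary halves
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)
print(cos_cached.shape)   # torch.Size([3, 32768, 128])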
lightllm/models/qwen2_vl/infer_struct.py
Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+
+
+class Qwen2VLInferStateInfo(LlamaInferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.position_cos = None
+        self.position_sin = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        if self.is_prefill:
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            self.max_seq_len = b_seq_len_numpy.max()
+            position_ids = torch.from_numpy(
+                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))])
+            ).cuda()
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            position_ids = None
+        else:
+            position_ids = self.b_seq_len - 1
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+        return
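A small sketch (synthetic lengths and a CPU stand-in for the CUDA cache; names and sizes here are assumptions) of what the prefill branch does: flatten per-request position ids and gather the matching rows from the 3-axis cos/sin cache, yielding tensors of shape (3, 1, total_tokens, head_dim) that later broadcast against the (1, heads, total_tokens, head_dim) queries and keys.

import numpy as np
import torch

b_seq_len = np.array([3, 2])                           # two prefill requests of lengths 3 and 2
position_ids = torch.from_numpy(
    np.concatenate([np.arange(0, n) for n in b_seq_len])
)                                                      # tensor([0, 1, 2, 0, 1])

cos_cached = torch.randn(3, 4096, 128)                 # stand-in for model._cos_cached: (axes, max_len, head_dim)
position_cos = cos_cached[:, position_ids, :].unsqueeze(1)
print(position_cos.shape)                              # torch.Size([3, 1, 5, 128])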
lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py
Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+import torch
+import torch.functional as F
+import torch.distributed as dist
+import numpy as np
+
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from functools import partial
+
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    mrope_section = mrope_section * 2
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed, k_embed
+
+
+class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
+    def __init__(self, layer_num, network_config, mode=[]):
+        super().__init__(layer_num, network_config, mode)
+        self.mrope_section = network_config["rope_scaling"]["mrope_section"]
+
+    def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
+        q = layer_weight.q_proj.mm(input)
+        cache_kv = layer_weight.kv_proj.mm(
+            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        seq_len, _ = q.shape
+        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
+        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
+        new_q, new_k = apply_multimodal_rotary_pos_emb(
+            q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section
+        )
+        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
+        cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)
+
+        return new_q, cache_kv
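A usage sketch of apply_multimodal_rotary_pos_emb with synthetic random tensors (the head count, sequence length, and values are assumptions). Shapes mirror _get_qkv and Qwen2VLInferStateInfo: q/k are (1, heads, seq, head_dim) and the cached cos/sin are (3, 1, seq, head_dim); mrope_section = [16, 24, 24] is the split shipped in Qwen2-VL configs and sums to half the head dim, so doubling it inside the function makes the temporal/height/width interleaving cover both halves that rotate_half swaps.

import torch
from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import apply_multimodal_rotary_pos_emb

heads, seq, head_dim = 4, 5, 128
mrope_section = [16, 24, 24]                  # temporal / height / width split of the rotary dims

q = torch.randn(1, heads, seq, head_dim)
k = torch.randn(1, heads, seq, head_dim)
cos = torch.randn(3, 1, seq, head_dim)        # per-axis tables, as built by Qwen2VLInferStateInfo
sin = torch.randn(3, 1, seq, head_dim)

q_rot, k_rot = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
print(q_rot.shape, k_rot.shape)               # torch.Size([1, 4, 5, 128]) torch.Size([1, 4, 5, 128])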

lightllm/models/qwen2_vl/model.py
Lines changed: 5 additions & 6 deletions

@@ -12,17 +12,16 @@
 from typing import List, Optional, Union
 from transformers.utils import TensorType, logging
 from lightllm.common.build_utils import repair_config
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
 
-# from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
 import torch
 from PIL import Image
 from .vision_process import smart_resize
 from lightllm.models.qwen2.layer_weights import transformer_layer_weight, pre_and_post_layer_weight
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 import os
 
-# from lightllm.models.qwen2_vl.layer_weight.pre_and_post_layer_weight import Qwen2VLPreAndPostLayerWeight
-
 # Warp of the origal tokenizer
 class QWen2VLTokenizer:
     def __init__(self, tokenizer=None, image_processor=None, **kwargs):

@@ -89,10 +88,10 @@ def __getattr__(self, name):
 
 class Qwen2VLTpPartModel(Qwen2TpPartModel):
 
-    # weight class
-    # pre_and_post_weight_class = Qwen2VLPreAndPostLayerWeight
-    # infer class
     pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+    transformer_layer_infer_class = Qwen2VLTransformerLayerInfer
+
+    infer_state_class = Qwen2VLInferStateInfo
 
     def __init__(self, kvargs):
         super().__init__(kvargs)

lightllm/server/httpserver/manager.py
Lines changed: 9 additions & 0 deletions

@@ -25,10 +25,12 @@
 from lightllm.server.core.objs.io_objs import GroupReqObjs
 from fastapi import Request
 from lightllm.server.core.objs.shm_req_manager import ShmReqManager
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.utils.statics_utils import MovingAverage
 from lightllm.utils.config_utils import get_vocab_size
+from lightllm.utils.envs_utils import get_unique_server_name
 
 logger = init_logger(__name__)
 

@@ -103,6 +105,10 @@ def __init__(
         # for some models, the vocab size read from the tokenizer and from config.json are inconsistent
         self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size)
 
+        # Time mark of the latest successful inference (prefill/decode), used to check the health of the system.
+        # If the mark is not updated within a preset window, a probe request is sent to the backend.
+        self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
+        self.latest_success_infer_time_mark.set_value(int(time.time()))
         return
 
     # connect cache server, calculate md5, alloc resource, return uuid

@@ -483,6 +489,9 @@ async def _wait_to_token_package(
 
                 out_token_counter += 1
 
+                # update the inference time mark
+                self.latest_success_infer_time_mark.set_value(int(time.time()))
+
                 yield sub_req_id, out_str, metadata, finish_status
                 # if a sub-request has finished, update the counter
                 if finish_status.is_finished():
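The hunks above are the writer half of a simple cross-process watchdog: the HTTP server stamps a shared integer after every emitted token, and the health checker (next file) treats the service as live while that stamp is fresh. A condensed sketch of the pattern, with a hypothetical shared-memory name instead of the get_unique_server_name()-derived one:

import time
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt

mark = SharedInt("demo_latest_success_infer_time_mark")  # hypothetical name for the sketch

# writer side (httpserver manager, after a token is produced)
mark.set_value(int(time.time()))

# reader side (health check)
def has_recent_inference(timeout_s: int = 100) -> bool:
    return time.time() - mark.get_value() < timeout_s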

lightllm/utils/health_check.py
Lines changed: 13 additions & 0 deletions

@@ -1,13 +1,16 @@
 import os
+import time
 import asyncio
 import numpy as np
 from dataclasses import dataclass
 from lightllm.server.core.objs import SamplingParams
 from lightllm.server.multimodal_params import MultimodalParams
 from lightllm.server.httpserver.manager import HttpServerManager
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from fastapi import Request
 from lightllm.server.req_id_generator import ReqIDGenerator
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.envs_utils import get_unique_server_name
 
 logger = init_logger(__name__)
 

@@ -24,6 +27,7 @@ class HealthObj:
     _failure_threshold: int = int(os.getenv("HEALTH_FAILURE_THRESHOLD", 3))
     timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100))
     dynamic_timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100))
+    latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
 
     def begin_check(self):
         self._is_health_checking = True

@@ -48,13 +52,22 @@ def is_health(self):
     def is_checking(self):
         return self._is_health_checking
 
+    def has_latest_inference(self):
+        last_timemark = self.latest_success_infer_time_mark.get_value()
+        time_diff = time.time() - last_timemark
+        return time_diff < self.timeout
+
 
 health_obj = HealthObj()
 
 
 async def health_check(args, httpserver_manager: HttpServerManager, request: Request):
     if health_obj.is_checking():
         return health_obj.is_health()
+
+    if health_obj.is_health() and health_obj.has_latest_inference():
+        return health_obj.is_health()
+
     health_obj.begin_check()
     try:
         request_dict = {"inputs": "你好!", "parameters": {"do_sample": True, "temperature": 0.8, "max_new_tokens": 2}}
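The effect of the new fast path: as long as the previous check passed and the shared time mark shows a successful token emission younger than HEALTH_TIMEOUT seconds (default 100), health_check returns immediately instead of spending a real probe generation on every poll; the probe request is only issued once the server has been idle for longer than that window or the last check failed.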
