Commit 6ee2bc8

feat: audio server and audio multimodal support
1 parent 28bf517 commit 6ee2bc8

14 files changed: +630, -22 lines

lightllm/models/internvl/model.py

Lines changed: 48 additions & 1 deletion
@@ -6,7 +6,7 @@
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 from lightllm.models.deepseek2.model import Deepseek2TpPartModel
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
-from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
 from lightllm.common.build_utils import repair_config
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import (
     InternVLLlamaPreAndPostLayerWeight,
@@ -26,6 +26,8 @@
 IMG_START_TOKEN = "<img>"
 IMG_END_TOKEN = "</img>"
 IMG_TOKEN = "<image>"
+AUDIO_START_TOKEN = "<audio>"
+AUDIO_END_TOKEN = "</audio>"


 # Warp of the origal tokenizer
@@ -40,6 +42,12 @@ def __init__(self, tokenizer, model_cfg, **kwargs):

         self.image_end_tag = IMG_END_TOKEN
         self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
+        self.audio_start_tag = AUDIO_START_TOKEN
+        self.audio_start_id = tokenizer.convert_tokens_to_ids(self.audio_start_tag)
+
+        self.audio_end_tag = AUDIO_END_TOKEN
+        self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])

     def init_imageItem_extral_params(
@@ -68,6 +76,20 @@ def get_image_token_length(self, img: ImageItem):
             )
             * self.image_length
         )
+
+    def get_audio_token_length(self, audio: AudioItem):
+        L = audio.audio_length
+        L = (L if L <= 480000 else 480000)  # max_length < 30s
+        mel_len = L // 160
+        dilation = 1
+        L_in = mel_len
+        for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
+            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+            L_out = 1 + L_out // stride
+            L_in = L_out
+        audio_len_after_cnn = L_out
+        audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
+        return audio_token_num

     # only change the impl of the encode func:
     def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
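As a quick check of the arithmetic in get_audio_token_length: at the 16 kHz sampling rate used throughout, a clip capped at 480000 samples (30 s) gives mel_len = 3000 frames; the two convolutions with (padding, kernel, stride) = (1, 3, 1) and (1, 3, 2) leave 3000 and then 1500 frames, and (1500 - 2) // 2 + 1 = 750 audio tokens. A minimal standalone sketch of the same computation (illustration only, not part of the diff):

def audio_token_len(num_samples: int) -> int:
    # mirrors get_audio_token_length above, with dilation fixed at 1
    num_samples = min(num_samples, 480000)      # cap at 30 s of 16 kHz audio
    length = num_samples // 160                 # mel frames (hop size 160)
    for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return (length - 2) // 2 + 1

assert audio_token_len(480000) == 750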
@@ -103,6 +125,31 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
             except ValueError:
                 break
         input_ids.extend(origin_ids[start_idx:])
+
+        # audio
+        origin_ids = input_ids
+        input_ids = []
+        audio_id = 0
+        start_idx = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.audio_start_id, start_idx)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.audio_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.audios[audio_id].token_id
+                    token_num = multimodal_params.audios[audio_id].token_num
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.audio_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    start_idx = 0
+                    audio_id += 1
+                else:
+                    raise ValueError("audio token error")
+            except ValueError:
+                break
+        input_ids.extend(origin_ids[start_idx:])
         return input_ids

     def __getattr__(self, name):
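The slicing-heavy loop above is easier to read as a single pass: for a well-formed prompt, every adjacent <audio>/</audio> pair is widened so that the reserved embedding ids sit between the two marker tokens. A simplified standalone sketch (attribute names follow the code above; the function itself is illustrative and drops the error path that aborts on a malformed pair):

def expand_audio_placeholders(ids, audio_start_id, audio_end_id, audios):
    # ids: prompt token ids; audios: AudioItem-like objects carrying
    # a reserved id range (token_id, token_num)
    out, audio_idx, i = [], 0, 0
    while i < len(ids):
        if ids[i] == audio_start_id and i + 1 < len(ids) and ids[i + 1] == audio_end_id:
            item = audios[audio_idx]
            out.append(audio_start_id)
            out.extend(range(item.token_id, item.token_id + item.token_num))
            out.append(audio_end_id)
            audio_idx += 1
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

The reserved range token_id .. token_id + token_num - 1 is presumably the same range the multimodal pre-layer below fills with the cached audio embedding.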

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
         for batch_id, p in enumerate(infer_state.multimodal_params):
-            for img in p["images"]:
+            for img in p["images"] + p["audios"]:
                 # skip the same image
                 if img["token_id"] in img_start_token_ids:
                     continue

lightllm/models/whisper/__init__.py

Whitespace-only changes.
lightllm/models/whisper/defaults.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
MIN_AUDIO_LEN = 480  # minimum audio length
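A quick check of where 480 comes from: 480 samples is 30 ms at the 16 kHz rate assumed elsewhere in this commit, i.e. 3 mel frames at hop 160; after the two convolutions used in get_audio_token_length / get_T_after_cnn that leaves 2 frames, and (2 - 2) // 2 + 1 = 1, so this padding floor appears to be the smallest input that still yields one audio token.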
Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import torch
import librosa
from io import BytesIO
from typing import List, Union
import numpy as np
from torch import nn
from safetensors.torch import load_file

import json
import torch.nn.functional as F
import math
import os
import rpyc
from transformers.processing_utils import ProcessorMixin
from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed


class WhisperProcessor(ProcessorMixin):
    r"""
    Constructs a Whisper processor which wraps a Whisper feature extractor and a Whisper tokenizer into a single
    processor.

    [`WhisperProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`WhisperTokenizer`]. See
    the [`~WhisperProcessor.__call__`] and [`~WhisperProcessor.decode`] for more information.

    Args:
        feature_extractor (`WhisperFeatureExtractor`):
            An instance of [`WhisperFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`WhisperTokenizer`):
            An instance of [`WhisperTokenizer`]. The tokenizer is a required input.
    """
    attributes = ["feature_extractor"]
    feature_extractor_class = "WhisperFeatureExtractor"
    # tokenizer_class = "WhisperTokenizer"

    def __init__(self, feature_extractor):
        super().__init__(feature_extractor)
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
        return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)

    def get_T_after_cnn(self, L_in, dilation=1):
        for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
            L_out = 1 + L_out // stride
            L_in = L_out
        return L_out

    def __call__(self, audios, audio_lens, *args, **kwargs):
        """
        Forwards the `audios` argument to WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] and the `text`
        argument to [`~WhisperTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more
        information.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        sampling_rate = kwargs.pop("sampling_rate", 16000)


        audio_lens = np.where(audio_lens <=480000, audio_lens, 480000)
        audio_lens = audio_lens // 160
        audio_lens_after_cnn = self.get_T_after_cnn(audio_lens)
        padded_inputs = self.feature_extractor(audios, *args, sampling_rate=sampling_rate, **kwargs)

        return padded_inputs['input_features'], audio_lens_after_cnn

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def get_prompt_ids(self, text: str, return_tensors="np"):
        return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)


class WhisperAudioModel:
    def __init__(self, kvargs):
        self.max_seconds = 30
        self.sampling_rate = 16000
        self.max_length = self.max_seconds * self.sampling_rate
        self.cache_port = kvargs["client_port"]
        self.cache_client = rpyc.connect("localhost", self.cache_port)
        data_type = kvargs["data_type"]
        if data_type in ["bf16", "bfloat16"]:
            self.data_type = torch.bfloat16
        else:
            self.data_type = torch.float16

    def cuda(self):
        self.audio = self.audio.cuda()
        for k, v in self.projector_weights.items():
            self.projector_weights[k] = v.cuda()
        self.device = torch.device("cuda")
        return self


    def load_model(self, weight_dir, config):
        self.audio_processor = WhisperProcessor.from_pretrained(weight_dir)
        from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperConfig
        self.audio = WhisperEncoder(WhisperConfig(**config['audio_config'])).to(self.data_type)
        self.device = torch.device("cpu")
        self.projector_weights = {}
        self.load_weight(weight_dir)

    def load_weight(self, weight_dir):
        weight_path = os.path.join(weight_dir, 'model.safetensors.index.json')
        weight_map = json.load(open(weight_path, "r"))['weight_map']
        params_map = {}
        audio_weight = {}
        for k,v in weight_map.items():
            if "mlp2" not in k and "audio_model" not in k:
                continue
            filename = weight_map[k]
            if filename not in params_map:
                tensor_data = load_file(os.path.join(weight_dir, filename))
                params_map[filename] = tensor_data
            if "mlp2" in k:
                self.projector_weights[k] = params_map[filename][k].to(self.data_type)
            if "audio_model" in k:
                audio_weight[k[len("audio_model.encoder."):]] = params_map[filename][k].to(self.data_type)

        self.audio.load_state_dict(audio_weight)

        assert "mlp2.0.bias" in self.projector_weights
        assert "mlp2.0.weight" in self.projector_weights
        assert "mlp2.1.bias" in self.projector_weights
        assert "mlp2.1.weight" in self.projector_weights
        assert "mlp2.3.bias" in self.projector_weights
        assert "mlp2.3.weight" in self.projector_weights

    def forward(self, audio_values, audio_lens_after_cnn):
        audio_values = audio_values.to(self.data_type).to(device=self.device)
        audio_values = audio_values.squeeze(1)
        audio_lens_after_cnn = torch.tensor(audio_lens_after_cnn).cuda()
        max_len_in_batch = torch.max(audio_lens_after_cnn).item()

        padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(dtype=audio_values.dtype,
                                                                               device=audio_values.device)
        for index in range(len(audio_values)):
            padding_mask[index, :audio_lens_after_cnn[index].item()] = 0
        last_hidden_state = self.audio(audio_values, padding_mask).last_hidden_state
        x = F.layer_norm(
            last_hidden_state,
            normalized_shape=(last_hidden_state.shape[-1],),
            weight=self.projector_weights["mlp2.0.weight"],
            bias=self.projector_weights["mlp2.0.bias"]
        )
        x = F.linear(
            x,
            weight=self.projector_weights["mlp2.1.weight"],
            bias=self.projector_weights["mlp2.1.bias"]
        )
        x = F.gelu(x)
        x = F.linear(
            x,
            weight=self.projector_weights["mlp2.3.weight"],
            bias=self.projector_weights["mlp2.3.bias"]
        )
        return x

    def encode(self, audio_items: List[Union[str, BytesIO]]):
        batch_audios = []
        batch_audio_lens = np.zeros(len(audio_items), dtype=np.int32)
        uuids = []
        for i, item in enumerate(audio_items):
            if isinstance(item, int):
                uuids.append(item)
                audio_data = read_shm(get_shm_name_data(item))
                audio = BytesIO(audio_data)
                audio, _ = librosa.load(audio, sr=16000)
            elif isinstance(item, BytesIO):
                audio, _ = librosa.load(item, sr=16000)
            elif item.startswith("http://") or item.startswith("https://"):
                import requests
                audio = BytesIO(requests.get(item, stream=True).raw.read())
                audio, _ = librosa.load(audio, sr=16000)
            else:
                raise ValueError(f"cannot read audio which type is {type(item)}!")

            # padding to min audio len
            from .defaults import MIN_AUDIO_LEN
            if audio.shape[0] < MIN_AUDIO_LEN:
                audio = np.pad(audio, (0, MIN_AUDIO_LEN - len(audio)), mode='constant', constant_values=0.0)

            batch_audio_lens[i] = min(audio.shape[0], self.max_length)
            batch_audios.append(audio)

        audios, audio_lens_after_cnn = self.audio_processor(batch_audios, batch_audio_lens, sampling_rate=16000, return_tensors="pt")
        audios = self.forward(audios, audio_lens_after_cnn)
        audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
        audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1

        for i in range(len(uuids)):
            if not self.cache_client.root.get_item_embed(uuids[i]):
                cur_embed_bytes = tensor2bytes(audios[i][:audio_token_num[i]])
                create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
                self.cache_client.root.set_item_embed(uuids[i])
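A rough sketch of how this module is meant to be driven (the audioserver manager started from api_start.py would do this in-process; the paths, port number, and the presence of an audio_config entry in the checkpoint config are placeholders and assumptions, not taken from the commit):

import json
from io import BytesIO

kvargs = {"client_port": 10002, "data_type": "bf16"}    # hypothetical cache-server port
model = WhisperAudioModel(kvargs)                       # connects to the embed cache over rpyc

config = json.load(open("/path/to/internvl_weights/config.json"))
model.load_model("/path/to/internvl_weights", config)   # expects config["audio_config"]
model = model.cuda()

# BytesIO / URL items are encoded directly; only int uuid items get their
# embeddings written back to the shared-memory embed cache at the end of encode().
model.encode([BytesIO(open("sample.wav", "rb").read())])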

lightllm/server/api_cli.py

Lines changed: 4 additions & 1 deletion
@@ -205,7 +205,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
         use env FIRST_ALLOWED_TOKENS to set the range, like FIRST_ALLOWED_TOKENS=1,2 ..""",
     )
     parser.add_argument(
-        "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
+        "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models."
+    )
+    parser.add_argument(
+        "--enable_multimodal_audio", action="store_true", help="Whether or not to allow to load additional audio models (requird --enable_multimodal)."
     )
     parser.add_argument(
         "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."

lightllm/server/api_start.py

Lines changed: 32 additions & 11 deletions
@@ -4,6 +4,7 @@
 import uuid
 import subprocess
 import signal
+from lightllm.server.audioserver.manager import start_audio_process
 from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
 from lightllm.utils.start_utils import process_manager
 from .metrics.manager import start_metric_manager
@@ -173,11 +174,11 @@ def normal_or_p_d_start(args):

     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=6 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
-    router_port, detokenization_port, detokenization_pub_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
-    can_use_ports = can_use_ports[6:]
+    router_port, detokenization_port, detokenization_pub_port, visual_port, audio_port, cache_port, metric_port = can_use_ports[0:7]
+    can_use_ports = can_use_ports[7:]

     visual_model_tp_ports = []
     for _ in range(args.visual_dp):
@@ -190,6 +191,7 @@ def normal_or_p_d_start(args):
     args.detokenization_port = detokenization_port
     args.detokenization_pub_port = detokenization_pub_port
     args.visual_port = visual_port
+    args.audio_port = audio_port
     args.cache_port = cache_port
     args.metric_port = metric_port

@@ -218,14 +220,33 @@ def normal_or_p_d_start(args):
             ],
             start_args=[(cache_port, args)],
         )
-        process_manager.start_submodule_processes(
-            start_funcs=[
-                start_visual_process,
-            ],
-            start_args=[
-                (args, router_port, visual_port, cache_port, visual_model_tp_ports),
-            ],
-        )
+        if args.enable_multimodal_audio:
+            process_manager.start_submodule_processes(
+                start_funcs=[
+                    start_visual_process,
+                ],
+                start_args=[
+                    (args, audio_port, visual_port, cache_port, visual_model_tp_ports),
+                ],
+            )
+            process_manager.start_submodule_processes(
+                start_funcs=[
+                    start_audio_process,
+                ],
+                start_args=[
+                    (args, router_port, audio_port, cache_port),
+                ],
+            )
+
+        else:
+            process_manager.start_submodule_processes(
+                start_funcs=[
+                    start_visual_process,
+                ],
+                start_args=[
+                    (args, router_port, visual_port, cache_port, visual_model_tp_ports),
+                ],
+            )

     process_manager.start_submodule_processes(
         start_funcs=[
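Reading the new wiring (inferred from the argument positions above rather than stated explicitly in the commit): without --enable_multimodal_audio the visual process keeps publishing straight to router_port as before; with the flag set, the visual process publishes to the new audio_port, and a separate audio process started via start_audio_process listens there and forwards to router_port. The multimodal pipeline therefore gains one hop, visual -> audio -> router, only when audio support is enabled.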

lightllm/server/audioserver/__init__.py

Whitespace-only changes.
