Commit 5e2584c

feat: audio server and audio multimodal support

1 parent 28bf517 commit 5e2584c

File tree

14 files changed, +643 -22 lines changed

lightllm/models/internvl/model.py

Lines changed: 48 additions & 1 deletion

@@ -6,7 +6,7 @@
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 from lightllm.models.deepseek2.model import Deepseek2TpPartModel
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
-from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
 from lightllm.common.build_utils import repair_config
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import (
     InternVLLlamaPreAndPostLayerWeight,

@@ -26,6 +26,8 @@
 IMG_START_TOKEN = "<img>"
 IMG_END_TOKEN = "</img>"
 IMG_TOKEN = "<image>"
+AUDIO_START_TOKEN = "<audio>"
+AUDIO_END_TOKEN = "</audio>"


 # Warp of the origal tokenizer

@@ -40,6 +42,12 @@ def __init__(self, tokenizer, model_cfg, **kwargs):

         self.image_end_tag = IMG_END_TOKEN
         self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
+        self.audio_start_tag = AUDIO_START_TOKEN
+        self.audio_start_id = tokenizer.convert_tokens_to_ids(self.audio_start_tag)
+
+        self.audio_end_tag = AUDIO_END_TOKEN
+        self.audio_end_id = tokenizer.convert_tokens_to_ids(self.audio_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])

     def init_imageItem_extral_params(

@@ -69,6 +77,20 @@ def get_image_token_length(self, img: ImageItem):
             * self.image_length
         )

+    def get_audio_token_length(self, audio: AudioItem):
+        L = audio.audio_length
+        L = L if L <= 480000 else 480000  # max_length < 30s
+        mel_len = L // 160
+        dilation = 1
+        L_in = mel_len
+        for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
+            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+            L_out = 1 + L_out // stride
+            L_in = L_out
+        audio_len_after_cnn = L_out
+        audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
+        return audio_token_num
+
     # only change the impl of the encode func:
     def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
         # TEXT<image>TEXT<image>TEXT --> TEXT<img></img>TEXT<img></img>TEXT
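The token-count arithmetic above is easy to check by hand. The helper below is not part of the commit (the function name is mine, and the inline list stands in for the eval()'d string); it reproduces the same computation for a full 30 s clip and for 1 s of audio at 16 kHz:

def audio_tokens(num_samples: int) -> int:
    num_samples = min(num_samples, 480000)        # cap at 30 s of 16 kHz audio
    l_in = num_samples // 160                     # hop size 160 -> mel frames
    for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:  # Whisper's two front-end convs
        l_out = l_in + 2 * padding - (kernel_size - 1) - 1
        l_in = 1 + l_out // stride
    return (l_in - 2) // 2 + 1                    # same formula used for token_num downstream

print(audio_tokens(480000))  # 750 placeholder tokens for a 30 s clip
print(audio_tokens(16000))   # 25 tokens for 1 s of audio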
@@ -103,6 +125,31 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
             except ValueError:
                 break
         input_ids.extend(origin_ids[start_idx:])
+
+        # audio
+        origin_ids = input_ids
+        input_ids = []
+        audio_id = 0
+        start_idx = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.audio_start_id, start_idx)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.audio_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.audios[audio_id].token_id
+                    token_num = multimodal_params.audios[audio_id].token_num
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.audio_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    start_idx = 0
+                    audio_id += 1
+                else:
+                    raise ValueError("audio token error")
+            except ValueError:
+                break
+        input_ids.extend(origin_ids[start_idx:])
         return input_ids

     def __getattr__(self, name):
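For a single <audio></audio> pair, the rewrite in encode() boils down to the following; the ids here are made up purely for illustration (in the real flow token_id and token_num come from the embed-cache server via MultimodalParams):

def expand_audio_span(ids, audio_start_id, audio_end_id, token_id, token_num):
    i = ids.index(audio_start_id)
    assert ids[i + 1] == audio_end_id, "audio token error"
    return ids[: i + 1] + list(range(token_id, token_id + token_num)) + [audio_end_id] + ids[i + 2 :]

# prompt ids [1, 2, 700, 701, 3] where 700=<audio>, 701=</audio>, token_id=9000, token_num=3
print(expand_audio_span([1, 2, 700, 701, 3], 700, 701, 9000, 3))
# -> [1, 2, 700, 9000, 9001, 9002, 701, 3]

The placeholder range is what the language model actually sees; its embeddings are overwritten with the projected Whisper features in the multimodal pre-layer infer below.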

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
         for batch_id, p in enumerate(infer_state.multimodal_params):
-            for img in p["images"]:
+            for img in p["images"] + p["audios"]:
                 # skip the same image
                 if img["token_id"] in img_start_token_ids:
                     continue
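The loop treats audio items exactly like image items, so an audio entry only has to honour the contract images already use. Roughly (field names are taken from the diff, the values are illustrative):

p = {
    "images": [{"token_id": 8000, "token_num": 256}],   # one image placeholder block
    "audios": [{"token_id": 9000, "token_num": 750}],   # one audio placeholder block
}

Because both kinds of items expose token_id and token_num, the existing embedding scatter written for images picks up audio embeddings without further changes.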

lightllm/models/whisper/__init__.py

Whitespace-only changes.

lightllm/models/whisper/defaults.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+MIN_AUDIO_LEN = 480  # minimum audio length
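(At the 16 kHz sample rate used throughout this commit, 480 samples is 30 ms of audio; clips shorter than this are zero-padded up to MIN_AUDIO_LEN before feature extraction, see the encode() path below.)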
Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+import torch
+import librosa
+from io import BytesIO
+from typing import List, Union
+import numpy as np
+from torch import nn
+from safetensors.torch import load_file
+
+import json
+import torch.nn.functional as F
+import math
+import os
+import rpyc
+from transformers.processing_utils import ProcessorMixin
+from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
+
+
+class WhisperProcessor(ProcessorMixin):
+    r"""
+    Constructs a Whisper processor which wraps a Whisper feature extractor and a Whisper tokenizer into a single
+    processor.
+
+    [`WhisperProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`WhisperTokenizer`]. See
+    the [`~WhisperProcessor.__call__`] and [`~WhisperProcessor.decode`] for more information.
+
+    Args:
+        feature_extractor (`WhisperFeatureExtractor`):
+            An instance of [`WhisperFeatureExtractor`]. The feature extractor is a required input.
+        tokenizer (`WhisperTokenizer`):
+            An instance of [`WhisperTokenizer`]. The tokenizer is a required input.
+    """
+    attributes = ["feature_extractor"]
+    feature_extractor_class = "WhisperFeatureExtractor"
+    # tokenizer_class = "WhisperTokenizer"
+
+    def __init__(self, feature_extractor):
+        super().__init__(feature_extractor)
+        self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
+
+    def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
+        return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)
+
+    def get_T_after_cnn(self, L_in, dilation=1):
+        for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
+            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+            L_out = 1 + L_out // stride
+            L_in = L_out
+        return L_out
+
+    def __call__(self, audios, audio_lens, *args, **kwargs):
+        """
+        Forwards the `audios` argument to WhisperFeatureExtractor's [`~WhisperFeatureExtractor.__call__`] and the `text`
+        argument to [`~WhisperTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
+        information.
+        """
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        sampling_rate = kwargs.pop("sampling_rate", 16000)
+
+        audio_lens = np.where(audio_lens <= 480000, audio_lens, 480000)
+        audio_lens = audio_lens // 160
+        audio_lens_after_cnn = self.get_T_after_cnn(audio_lens)
+        padded_inputs = self.feature_extractor(audios, *args, sampling_rate=sampling_rate, **kwargs)
+
+        return padded_inputs["input_features"], audio_lens_after_cnn
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
+
+
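Only the feature-extractor path of this trimmed-down processor is actually exercised: __call__ never touches a tokenizer, and get_decoder_prompt_ids / decode / batch_decode reference a self.tokenizer that this class never sets, so they would fail if called. A minimal usage sketch, assuming the checkpoint directory ships a standard Whisper preprocessor_config.json (the path and shapes are illustrative, not from the commit):

import numpy as np

processor = WhisperProcessor.from_pretrained("/path/to/checkpoint")   # loads only the WhisperFeatureExtractor
wave = np.zeros(16000 * 5, dtype=np.float32)                          # 5 s of silence at 16 kHz
lens = np.array([wave.shape[0]], dtype=np.int32)
features, lens_after_cnn = processor([wave], lens, sampling_rate=16000, return_tensors="pt")
# features: log-mel spectrogram padded to 30 s; lens_after_cnn: valid frame count after the two convs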
+class WhisperAudioModel:
+    def __init__(self, kvargs):
+        self.max_seconds = 30
+        self.sampling_rate = 16000
+        self.max_length = self.max_seconds * self.sampling_rate
+        self.cache_port = kvargs["client_port"]
+        self.cache_client = rpyc.connect("localhost", self.cache_port)
+        data_type = kvargs["data_type"]
+        if data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        else:
+            self.data_type = torch.float16
+
+    def cuda(self):
+        self.audio = self.audio.cuda()
+        for k, v in self.projector_weights.items():
+            self.projector_weights[k] = v.cuda()
+        self.device = torch.device("cuda")
+        return self
+
+    def load_model(self, weight_dir, config):
+        self.audio_processor = WhisperProcessor.from_pretrained(weight_dir)
+        from transformers.models.whisper.modeling_whisper import WhisperEncoder, WhisperConfig
+
+        self.audio = WhisperEncoder(WhisperConfig(**config["audio_config"])).to(self.data_type)
+        self.device = torch.device("cpu")
+        self.projector_weights = {}
+        self.load_weight(weight_dir)
117+
def load_weight(self, weight_dir):
118+
weight_path = os.path.join(weight_dir, "model.safetensors.index.json")
119+
weight_map = json.load(open(weight_path, "r"))["weight_map"]
120+
params_map = {}
121+
audio_weight = {}
122+
for k, v in weight_map.items():
123+
if "mlp2" not in k and "audio_model" not in k:
124+
continue
125+
filename = weight_map[k]
126+
if filename not in params_map:
127+
tensor_data = load_file(os.path.join(weight_dir, filename))
128+
params_map[filename] = tensor_data
129+
if "mlp2" in k:
130+
self.projector_weights[k] = params_map[filename][k].to(self.data_type)
131+
if "audio_model" in k:
132+
audio_weight[k[len("audio_model.encoder.") :]] = params_map[filename][k].to(self.data_type)
133+
134+
self.audio.load_state_dict(audio_weight)
135+
136+
assert "mlp2.0.bias" in self.projector_weights
137+
assert "mlp2.0.weight" in self.projector_weights
138+
assert "mlp2.1.bias" in self.projector_weights
139+
assert "mlp2.1.weight" in self.projector_weights
140+
assert "mlp2.3.bias" in self.projector_weights
141+
assert "mlp2.3.weight" in self.projector_weights
+
+    def forward(self, audio_values, audio_lens_after_cnn):
+        audio_values = audio_values.to(self.data_type).to(device=self.device)
+        audio_values = audio_values.squeeze(1)
+        audio_lens_after_cnn = torch.tensor(audio_lens_after_cnn).cuda()
+        max_len_in_batch = torch.max(audio_lens_after_cnn).item()
+
+        padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(
+            dtype=audio_values.dtype, device=audio_values.device
+        )
+        for index in range(len(audio_values)):
+            padding_mask[index, : audio_lens_after_cnn[index].item()] = 0
+        last_hidden_state = self.audio(audio_values, padding_mask).last_hidden_state
+        x = F.layer_norm(
+            last_hidden_state,
+            normalized_shape=(last_hidden_state.shape[-1],),
+            weight=self.projector_weights["mlp2.0.weight"],
+            bias=self.projector_weights["mlp2.0.bias"],
+        )
+        x = F.linear(x, weight=self.projector_weights["mlp2.1.weight"], bias=self.projector_weights["mlp2.1.bias"])
+        x = F.gelu(x)
+        x = F.linear(x, weight=self.projector_weights["mlp2.3.weight"], bias=self.projector_weights["mlp2.3.bias"])
+        return x
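The mlp2.* weights applied functionally in forward() amount to a small projector head; as a module it would look roughly like the sketch below (the layer sizes are placeholders, not read from the checkpoint):

import torch.nn as nn

encoder_hidden, llm_hidden = 1280, 4096    # illustrative sizes only
mlp2 = nn.Sequential(
    nn.LayerNorm(encoder_hidden),          # mlp2.0.{weight,bias}
    nn.Linear(encoder_hidden, llm_hidden), # mlp2.1.{weight,bias}
    nn.GELU(),                             # index 2 carries no parameters
    nn.Linear(llm_hidden, llm_hidden),     # mlp2.3.{weight,bias}
)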
+
+    def encode(self, audio_items: List[Union[str, BytesIO]]):
+        batch_audios = []
+        batch_audio_lens = np.zeros(len(audio_items), dtype=np.int32)
+        uuids = []
+        for i, item in enumerate(audio_items):
+            if isinstance(item, int):
+                uuids.append(item)
+                audio_data = read_shm(get_shm_name_data(item))
+                audio = BytesIO(audio_data)
+                audio, _ = librosa.load(audio, sr=16000)
+            elif isinstance(item, BytesIO):
+                audio, _ = librosa.load(item, sr=16000)
+            elif item.startswith("http://") or item.startswith("https://"):
+                import requests
+
+                audio = BytesIO(requests.get(item, stream=True).raw.read())
+                audio, _ = librosa.load(audio, sr=16000)
+            else:
+                raise ValueError(f"cannot read audio which type is {type(item)}!")
+
+            # padding to min audio len
+            from .defaults import MIN_AUDIO_LEN
+
+            if audio.shape[0] < MIN_AUDIO_LEN:
+                audio = np.pad(audio, (0, MIN_AUDIO_LEN - len(audio)), mode="constant", constant_values=0.0)
+
+            batch_audio_lens[i] = min(audio.shape[0], self.max_length)
+            batch_audios.append(audio)
+
+        audios, audio_lens_after_cnn = self.audio_processor(
+            batch_audios, batch_audio_lens, sampling_rate=16000, return_tensors="pt"
+        )
+        audios = self.forward(audios, audio_lens_after_cnn)
+        audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
+        audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
+
+        for i in range(len(uuids)):
+            if not self.cache_client.root.get_item_embed(uuids[i]):
+                cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
+                create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
+                self.cache_client.root.set_item_embed(uuids[i])
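End to end, encode() is what the audio server calls per request: int items are shared-memory uuids handed out by the embed cache (the type hint above omits int), str items are URLs, and BytesIO items are raw audio bytes; only uuid items get their projected embeddings written back to shared memory. A rough driver sketch, with ports, paths and the config layout all assumed for illustration:

import json
from io import BytesIO

model = WhisperAudioModel({"client_port": 10002, "data_type": "bf16"})
with open("/path/to/checkpoint/config.json") as f:
    model.load_model("/path/to/checkpoint", json.load(f))
model.cuda()

with open("sample.wav", "rb") as f:
    model.encode([BytesIO(f.read())])   # encoded, but nothing is cached since a BytesIO item has no uuid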

lightllm/server/api_cli.py

Lines changed: 6 additions & 1 deletion
@@ -205,7 +205,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         use env FIRST_ALLOWED_TOKENS to set the range, like FIRST_ALLOWED_TOKENS=1,2 ..""",
     )
     parser.add_argument(
-        "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
+        "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models."
+    )
+    parser.add_argument(
+        "--enable_multimodal_audio",
+        action="store_true",
+        help="Whether or not to allow to load additional audio models (required --enable_multimodal).",
     )
     parser.add_argument(
         "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."

lightllm/server/api_start.py

Lines changed: 40 additions & 11 deletions
@@ -4,6 +4,7 @@
 import uuid
 import subprocess
 import signal
+from lightllm.server.audioserver.manager import start_audio_process
 from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
 from lightllm.utils.start_utils import process_manager
 from .metrics.manager import start_metric_manager
@@ -173,11 +174,19 @@ def normal_or_p_d_start(args):

     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=6 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
-    router_port, detokenization_port, detokenization_pub_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
-    can_use_ports = can_use_ports[6:]
+    (
+        router_port,
+        detokenization_port,
+        detokenization_pub_port,
+        visual_port,
+        audio_port,
+        cache_port,
+        metric_port,
+    ) = can_use_ports[0:7]
+    can_use_ports = can_use_ports[7:]

     visual_model_tp_ports = []
     for _ in range(args.visual_dp):
@@ -190,6 +199,7 @@ def normal_or_p_d_start(args):
     args.detokenization_port = detokenization_port
     args.detokenization_pub_port = detokenization_pub_port
     args.visual_port = visual_port
+    args.audio_port = audio_port
     args.cache_port = cache_port
     args.metric_port = metric_port

@@ -218,14 +228,33 @@
             ],
             start_args=[(cache_port, args)],
         )
-        process_manager.start_submodule_processes(
-            start_funcs=[
-                start_visual_process,
-            ],
-            start_args=[
-                (args, router_port, visual_port, cache_port, visual_model_tp_ports),
-            ],
-        )
+        if args.enable_multimodal_audio:
+            process_manager.start_submodule_processes(
+                start_funcs=[
+                    start_visual_process,
+                ],
+                start_args=[
+                    (args, audio_port, visual_port, cache_port, visual_model_tp_ports),
+                ],
+            )
+            process_manager.start_submodule_processes(
+                start_funcs=[
+                    start_audio_process,
+                ],
+                start_args=[
+                    (args, router_port, audio_port, cache_port),
+                ],
+            )
+
+        else:
+            process_manager.start_submodule_processes(
+                start_funcs=[
+                    start_visual_process,
+                ],
+                start_args=[
+                    (args, router_port, visual_port, cache_port, visual_model_tp_ports),
+                ],
+            )

         process_manager.start_submodule_processes(
             start_funcs=[
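The net effect on the process topology: with --enable_multimodal_audio set, the visual workers are started with audio_port as their downstream target instead of router_port, and the new audio server (start_audio_process) listens on audio_port and forwards to router_port, so embedding work flows visual -> audio -> router; without the flag, the original visual -> router wiring is preserved verbatim.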

lightllm/server/audioserver/__init__.py

Whitespace-only changes.
