Commit 8616300

[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (vllm-project#25854)
Signed-off-by: zhoukz <[email protected]>
1 parent edbaadd commit 8616300

1 file changed: 122 additions (+) and 71 deletions (-)

vllm/model_executor/models/midashenglm.py

Lines changed: 122 additions & 71 deletions
@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
@@ -30,18 +31,17 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torchaudio.transforms as audio_transforms
+import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature
 
-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -147,15 +147,19 @@ def __init__(
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -171,7 +175,6 @@ def __init__(
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -205,33 +208,30 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
 
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-
-        x, _ = self.proj(attn_out)
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )
 
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)
         return x
 
 
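Note on the attention rewrite above: the fused QKV output is reshaped to a per-head layout and passed straight to torch's scaled_dot_product_attention, and the optional key-padding mask is now actually applied (broadcast from [B, N] to [B, 1, 1, N]). A minimal standalone sketch of the same pattern, with plain nn.Linear standing in for the parallel linear layers (the TinyQKVAttention module below is illustrative, not part of the commit):

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention


class TinyQKVAttention(nn.Module):
    """Illustrative stand-in: fused QKV projection + SDPA + output projection."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, heads, N, head_dim]
        # Boolean key-padding mask [B, N] broadcasts to [B, 1, 1, N]:
        # every query may only attend to valid key positions.
        out = scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask[:, None, None, :] if mask is not None else None)
        return self.proj(out.transpose(1, 2).reshape(B, N, C))


x = torch.randn(2, 16, 64)
key_padding_mask = torch.ones(2, 16, dtype=torch.bool)
print(TinyQKVAttention(dim=64)(x, key_padding_mask).shape)  # torch.Size([2, 16, 64])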
@@ -280,6 +280,63 @@ def forward(
         return x
 
 
+class DashengFrontend(nn.Module):
+
+    def __init__(self, config: DashengConfig):
+        super().__init__()
+        self.config = config
+
+        spectrogram_window = torch.hann_window(self.config.win_length)
+        self.register_buffer(
+            "spectrogram_window",
+            spectrogram_window,
+            persistent=False,
+        )
+        self.spectrogram_window: torch.Tensor
+
+        melscale_fbanks = F.melscale_fbanks(
+            n_freqs=self.config.n_fft // 2 + 1,
+            f_min=self.config.f_min,
+            f_max=self.config.f_max,
+            n_mels=self.config.n_mels,
+            sample_rate=self.config.sample_rate,
+        )
+        self.register_buffer("melscale_fbanks",
+                             melscale_fbanks,
+                             persistent=False)
+        self.melscale_fbanks: torch.Tensor
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        spectrogram = F.spectrogram(
+            waveform=waveform.to(torch.float32),
+            pad=0,
+            window=self.spectrogram_window,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            power=2,
+            normalized=False,
+            center=self.config.center,
+        )
+        mel_spectrogram = (
+            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        # x has shape [batch, freq, time].
+        # F.amplitude_to_DB accepts inputs shaped as:
+        # - [freq, time]
+        # - [channel, freq, time]
+        # - [..., channel, freq, time]
+        # Here we insert a channel dimension of size 1 before calling it,
+        # then remove that extra dimension afterward.
+        log_mel_spectrogram = F.amplitude_to_DB(
+            mel_spectrogram.unsqueeze(1),
+            multiplier=10,
+            amin=1e-10,
+            db_multiplier=0,
+            top_db=120,
+        ).squeeze(1)
+        return log_mel_spectrogram.to(waveform.dtype)
+
+
 class DashengAudioTransformer(nn.Module):
 
     def __init__(
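Note on DashengFrontend above: the Hann window and mel filterbank are kept as non-persistent float32 buffers, the STFT/mel/dB pipeline runs in float32, and the result is cast back to the waveform dtype, replacing the old torchaudio.transforms front end and its bfloat16 round-trip of the filterbank and window (removed further down). A rough standalone sketch of the same torchaudio.functional pipeline; the parameter values are illustrative assumptions, not the model config:

import torch
import torchaudio.functional as F

# Assumed, illustrative values (not taken from the model config).
sample_rate, n_fft, win_length, hop_length, n_mels = 16000, 512, 512, 160, 64

window = torch.hann_window(win_length)
fbanks = F.melscale_fbanks(
    n_freqs=n_fft // 2 + 1,
    f_min=0.0,
    f_max=sample_rate / 2,
    n_mels=n_mels,
    sample_rate=sample_rate,
)

waveform = torch.randn(1, sample_rate)  # one second of dummy audio
power_spec = F.spectrogram(
    waveform=waveform,
    pad=0,
    window=window,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    power=2,
    normalized=False,
)  # [batch, n_freqs, frames]
mel = (power_spec.mT @ fbanks).mT  # project onto mel bins -> [batch, n_mels, frames]
log_mel = F.amplitude_to_DB(
    mel.unsqueeze(1), multiplier=10, amin=1e-10, db_multiplier=0, top_db=120,
).squeeze(1)
print(log_mel.shape)  # torch.Size([1, 64, 101])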
@@ -293,7 +350,7 @@ def __init__(
         self.target_length = config.target_length
         self.hop_length = config.hop_length
 
-        self._init_front_end(config)
+        self.front_end = DashengFrontend(config)
 
         self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
@@ -318,34 +375,10 @@ def __init__(
             qkv_bias=config.qkv_bias,
             init_values=config.init_values,
             quant_config=quant_config,
-            prefix=f"{prefix}.block{i}",
+            prefix=f"{prefix}.blocks.{i}",
         ) for i in range(config.depth))
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
-    def _init_front_end(self, config):
-        with set_default_torch_dtype(torch.float32):
-            self.front_end = nn.Sequential(
-                audio_transforms.MelSpectrogram(
-                    f_min=config.f_min,
-                    f_max=config.f_max,
-                    center=config.center,
-                    win_length=config.win_length,
-                    hop_length=config.hop_length,
-                    sample_rate=config.sample_rate,
-                    n_fft=config.n_fft,
-                    n_mels=config.n_mels,
-                ),
-                audio_transforms.AmplitudeToDB(top_db=120),
-            )
-
-        mel_spectrogram = self.front_end[0]
-        fb = mel_spectrogram.mel_scale.fb
-        win = mel_spectrogram.spectrogram.window
-        mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
-            torch.float32)
-        mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
-            torch.float32)
-
     def forward_features(
         self,
         x: torch.Tensor,
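Note on the prefix fix above (block{i} -> blocks.{i}): quantized layers and checkpoint weights are matched by their dotted module path, and nn.Sequential / nn.ModuleList name their children "0", "1", ..., so the prefix must include the containing attribute name and a dot. A quick illustration (the TinyEncoder below is a toy, not vLLM code):

import torch.nn as nn


class TinyEncoder(nn.Module):

    def __init__(self):
        super().__init__()
        # nn.Sequential / nn.ModuleList name their children "0", "1", ...
        self.blocks = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))


print([name for name, _ in TinyEncoder().named_parameters()])
# ['blocks.0.weight', 'blocks.0.bias', 'blocks.1.weight', 'blocks.1.bias']
# A prefix like "encoder.block0" would never match "encoder.blocks.0.weight".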
@@ -430,14 +463,16 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )
 
     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ def _call_hf_processor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]
 
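The padding branch above simply right-pads any waveform shorter than the minimum length with zeros; a tiny standalone example (min_audio_len here is an arbitrary illustrative value):

import numpy as np

min_audio_len = 1600  # arbitrary illustrative value
audio = np.zeros(1000, dtype=np.float32)
if isinstance(audio, np.ndarray) and audio.shape[-1] < min_audio_len:
    audio = np.pad(
        audio,
        (0, min_audio_len - audio.shape[-1]),
        mode="constant",
        constant_values=0,
    )
print(audio.shape)  # (1600,)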
@@ -585,8 +623,8 @@ def _get_prompt_updates(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -617,6 +655,17 @@ def get_replacement_midashenglm(item_idx: int):
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
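Note on packed_modules_mapping above: it declares how the per-projection names in a checkpoint (q_proj/k_proj/v_proj, gate_proj/up_proj) fold into the decoder's fused qkv_proj and gate_up_proj layers, which quantization-aware weight loading relies on. A toy illustration of the kind of name remapping this enables (the remap helper below is a sketch, not vLLM's loader):

from typing import Optional

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}


def remap(checkpoint_name: str) -> Optional[tuple]:
    """Return (fused_name, shard_index) if the weight belongs to a packed module."""
    for fused, shards in packed_modules_mapping.items():
        for idx, shard in enumerate(shards):
            if shard in checkpoint_name:
                return checkpoint_name.replace(shard, fused), idx
    return None


print(remap("decoder.layers.0.self_attn.q_proj.weight"))
# ('decoder.layers.0.self_attn.qkv_proj.weight', 0)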
@@ -660,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
 
@@ -710,8 +759,8 @@ def _process_audio_input(
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -720,11 +769,11 @@ def _process_audio_input(
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)
 
-        audio_feature_mask = (torch.arange(
+        audio_feature_mask = torch.arange(
             max_audio_tokens,
             device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-                < audio_output_lengths.unsqueeze(1))
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)
 
         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
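The audio_feature_mask above is the usual arange-vs-lengths broadcast that marks which of the padded token positions are real; a tiny standalone example with made-up lengths:

import torch

audio_output_lengths = torch.tensor([3, 5, 1])
max_audio_tokens = 5
audio_feature_mask = torch.arange(max_audio_tokens).unsqueeze(0).expand(
    audio_output_lengths.numel(),
    max_audio_tokens) < audio_output_lengths.unsqueeze(1)
print(audio_feature_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True, False, False, False, False]])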
@@ -762,10 +811,12 @@ def forward(
             )
             input_ids = None
 
-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
 
     def compute_logits(
         self,
