 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence

 import numpy as np
 import torch
 import torch.nn as nn
-import torchaudio.transforms as audio_transforms
+import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature

-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -147,15 +147,19 @@ def __init__(
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -171,7 +175,6 @@ def __init__(
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -205,33 +208,30 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape

-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
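+        # Fused QKV projection, reshaped and split into per-head q/k/v of
+        # shape [B, num_heads, N, head_dim].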
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)

-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-
-        x, _ = self.proj(attn_out)
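+        # `mask` is a [B, N] key-padding mask; broadcasting it to
+        # [B, 1, 1, N] applies it to the key dimension for every head and
+        # query position.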
226
+ x = scaled_dot_product_attention (
227
+ q ,
228
+ k ,
229
+ v ,
230
+ attn_mask = mask [:, None , None , :] if mask is not None else None ,
231
+ )
234
232
233
+ x = x .transpose (1 , 2 ).reshape (B , N , C )
234
+ x , _ = self .proj (x )
235
235
return x
236
236
237
237
@@ -280,6 +280,63 @@ def forward(
         return x


+class DashengFrontend(nn.Module):
+
+    def __init__(self, config: DashengConfig):
+        super().__init__()
+        self.config = config
+
+        spectrogram_window = torch.hann_window(self.config.win_length)
+        self.register_buffer(
+            "spectrogram_window",
+            spectrogram_window,
+            persistent=False,
+        )
+        self.spectrogram_window: torch.Tensor
+
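+        # F.melscale_fbanks returns a [n_freqs, n_mels] matrix that projects
+        # linear-frequency bins onto the mel scale.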
+        melscale_fbanks = F.melscale_fbanks(
+            n_freqs=self.config.n_fft // 2 + 1,
+            f_min=self.config.f_min,
+            f_max=self.config.f_max,
+            n_mels=self.config.n_mels,
+            sample_rate=self.config.sample_rate,
+        )
+        self.register_buffer("melscale_fbanks",
+                             melscale_fbanks,
+                             persistent=False)
+        self.melscale_fbanks: torch.Tensor
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
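+        # Power (magnitude-squared) spectrogram; the waveform is upcast to
+        # float32 before the STFT.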
+        spectrogram = F.spectrogram(
+            waveform=waveform.to(torch.float32),
+            pad=0,
+            window=self.spectrogram_window,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            power=2,
+            normalized=False,
+            center=self.config.center,
+        )
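+        # Apply the mel filterbank along the frequency axis:
+        # [batch, freq, time].mT @ [freq, n_mels] -> [batch, time, n_mels],
+        # then transpose back to [batch, n_mels, time].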
+        mel_spectrogram = (
+            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        # mel_spectrogram has shape [batch, n_mels, time].
+        # F.amplitude_to_DB accepts inputs shaped as:
+        #   - [freq, time]
+        #   - [channel, freq, time]
+        #   - [..., channel, freq, time]
+        # Here we insert a channel dimension of size 1 before calling it,
+        # then remove that extra dimension afterward.
+        log_mel_spectrogram = F.amplitude_to_DB(
+            mel_spectrogram.unsqueeze(1),
+            multiplier=10,
+            amin=1e-10,
+            db_multiplier=0,
+            top_db=120,
+        ).squeeze(1)
+        return log_mel_spectrogram.to(waveform.dtype)
+
+

 class DashengAudioTransformer(nn.Module):

     def __init__(
@@ -293,7 +350,7 @@ def __init__(
         self.target_length = config.target_length
         self.hop_length = config.hop_length

-        self._init_front_end(config)
+        self.front_end = DashengFrontend(config)

         self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)

@@ -318,34 +375,10 @@ def __init__(
             qkv_bias=config.qkv_bias,
             init_values=config.init_values,
             quant_config=quant_config,
-            prefix=f"{prefix}.block{i}",
+            prefix=f"{prefix}.blocks.{i}",
         ) for i in range(config.depth))
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

-    def _init_front_end(self, config):
-        with set_default_torch_dtype(torch.float32):
-            self.front_end = nn.Sequential(
-                audio_transforms.MelSpectrogram(
-                    f_min=config.f_min,
-                    f_max=config.f_max,
-                    center=config.center,
-                    win_length=config.win_length,
-                    hop_length=config.hop_length,
-                    sample_rate=config.sample_rate,
-                    n_fft=config.n_fft,
-                    n_mels=config.n_mels,
-                ),
-                audio_transforms.AmplitudeToDB(top_db=120),
-            )
-
-        mel_spectrogram = self.front_end[0]
-        fb = mel_spectrogram.mel_scale.fb
-        win = mel_spectrogram.spectrogram.window
-        mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
-            torch.float32)
-        mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
-            torch.float32)
-
     def forward_features(
         self,
         x: torch.Tensor,
@@ -430,14 +463,16 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )

     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ def _call_hf_processor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]

@@ -585,8 +623,8 @@ def _get_prompt_updates(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -617,6 +655,17 @@ def get_replacement_midashenglm(item_idx: int):
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
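+    # Tell the weight loader which per-projection checkpoint weights are
+    # stacked into each fused decoder parameter.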
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -660,8 +709,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])

@@ -710,8 +759,8 @@ def _process_audio_input(
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape

-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -720,11 +769,11 @@ def _process_audio_input(
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)

-        audio_feature_mask = (torch.arange(
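+        # Boolean [batch_size, max_audio_tokens] mask that is True for frames
+        # backed by real (non-padded) audio in each batch item.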
+        audio_feature_mask = torch.arange(
             max_audio_tokens,
             device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-            < audio_output_lengths.unsqueeze(1))
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)

         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
@@ -762,10 +811,12 @@ def forward(
         )
         input_ids = None

-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )

     def compute_logits(
         self,