diff --git a/pyproject.toml b/pyproject.toml
index 97a4ae1..74532e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
 
 [project.optional-dependencies]
 vllm = [
-    "vllm==0.14.0",
+    "vllm>=0.15.2rc1",
 ]
 
 [project.urls]
diff --git a/qwen_asr/core/vllm_backend/qwen3_asr.py b/qwen_asr/core/vllm_backend/qwen3_asr.py
index db36219..631dfe0 100644
--- a/qwen_asr/core/vllm_backend/qwen3_asr.py
+++ b/qwen_asr/core/vllm_backend/qwen3_asr.py
@@ -198,7 +198,6 @@ def __init__(
             num_heads=self.num_local_heads,
             head_size=self.head_dim,
             scale=self.scaling,
-            multimodal_config=multimodal_config,
         )
 
     def forward(
@@ -359,15 +358,9 @@ def __init__(
         self.proj2 = nn.Linear(config.d_model, config.output_dim)
 
         # Get attention backend
-        attn_backend_override = (
-            multimodal_config.mm_encoder_attn_backend
-            if multimodal_config is not None
-            else None
-        )
         self.attn_backend = get_vit_attn_backend(
             head_size=config.d_model // config.encoder_attention_heads,
             dtype=torch.get_default_dtype(),
-            attn_backend_override=attn_backend_override,
         )
 
     def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
@@ -553,6 +546,13 @@ def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"audio": None}
 
+    def build_data_parser(self) -> MultiModalDataParser:
+        """Build data parser for vllm >= 0.15.2"""
+        feature_extractor = self.get_feature_extractor()
+        return Qwen3ASRMultiModalDataParser(
+            target_sr=feature_extractor.sampling_rate,
+        )
+
 
 class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
@@ -622,12 +622,6 @@ def _parse_audio_data(
 class Qwen3ASRMultiModalProcessor(
     Qwen3OmniMoeThinkerMultiModalProcessor,
 ):
-    def _get_data_parser(self) -> MultiModalDataParser:
-        feature_extractor = self.info.get_feature_extractor()
-        return Qwen3ASRMultiModalDataParser(
-            target_sr=feature_extractor.sampling_rate,
-        )
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -994,4 +988,4 @@ def get_generation_prompt(
         "prompt_token_ids": prompt_token_ids,
         "multi_modal_data": {"audio": audio},
     }
-    return cast(PromptType, prompt_dict)
\ No newline at end of file
+    return cast(PromptType, prompt_dict)
diff --git a/qwen_asr/inference/qwen3_asr.py b/qwen_asr/inference/qwen3_asr.py
index d99b915..c6639d3 100644
--- a/qwen_asr/inference/qwen3_asr.py
+++ b/qwen_asr/inference/qwen3_asr.py
@@ -47,10 +47,18 @@
 )
 
 try:
-    from qwen_asr.core.vllm_backend import Qwen3ASRForConditionalGeneration
     from vllm import ModelRegistry
-    ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", Qwen3ASRForConditionalGeneration)
-except:
+
+    try:
+        from vllm.transformers_utils.config import _CONFIG_REGISTRY
+        _CONFIG_REGISTRY["qwen3_asr"] = Qwen3ASRConfig
+    except (ImportError, AttributeError):
+        pass
+
+    model_class_path = "qwen_asr.core.vllm_backend.qwen3_asr:Qwen3ASRForConditionalGeneration"
+    ModelRegistry.register_model("qwen3_asr", model_class_path)
+    ModelRegistry.register_model("Qwen3ASRForConditionalGeneration", model_class_path)
+except Exception:
     pass
 
 
@@ -827,4 +835,4 @@ def finish_streaming_transcribe(self, state: ASRStreamingState) -> ASRStreamingS
         state.text = txt
         state.chunk_id += 1
 
-    return state
\ No newline at end of file
+    return state