qwen_asr/core/transformers_backend/configuration_qwen3_asr.py
@@ -321,6 +321,7 @@ class Qwen3ASRThinkerConfig(PretrainedConfig):
model_type = "qwen3_asr_thinker"

attribute_map = {}
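# Explicit class-level default: no pad token is defined for the thinker config.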
pad_token_id = None
sub_configs = {
"audio_config": Qwen3ASRAudioEncoderConfig,
"text_config": Qwen3ASRTextConfig,
qwen_asr/core/transformers_backend/modeling_qwen3_asr.py (20 changes: 18 additions & 2 deletions)
@@ -38,7 +38,13 @@
from transformers.processing_utils import Unpack
from transformers.utils import auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import TransformersKwargs, check_model_inputs
from transformers.utils.generic import TransformersKwargs

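# Compatibility shim: older transformers releases do not export check_model_inputs,
# so fall back to a no-op decorator that returns the wrapped function unchanged.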
try:
from transformers.utils.generic import check_model_inputs
except ImportError:
def check_model_inputs(func=None, **_kw):
return (lambda f: f) if func is None else func
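# Illustrative usage (hypothetical call sites, not from this file): the fallback accepts
# both decoration styles, `@check_model_inputs` (bare) and `@check_model_inputs()` (called),
# and in either case the decorated method behaves exactly as if it were undecorated.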

from .configuration_qwen3_asr import (
Qwen3ASRAudioEncoderConfig,
@@ -791,14 +797,24 @@ def __init__(self, config: Qwen3ASRConfig, device=None):
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
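# Fall back to a locally defined initializer when this rope_type is not registered
# in the installed transformers version.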
if self.rope_type not in ROPE_INIT_FUNCTIONS:
self.rope_init_fn = self._compute_default_rope_parameters
else:
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])

@staticmethod
def _compute_default_rope_parameters(config, device=None, seq_len=None, **_kw):
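# Plain RoPE inverse frequencies, inv_freq[i] = base ** (-2i / dim), with attention scaling fixed at 1.0.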
base = getattr(config, "rope_theta", 10000.0)
dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32).to(device) / dim))
return inv_freq, 1.0

def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to