diff --git a/veomni/models/module_utils.py b/veomni/models/module_utils.py
index 057189bf..2997c3e3 100644
--- a/veomni/models/module_utils.py
+++ b/veomni/models/module_utils.py
@@ -32,12 +32,13 @@
 from ..utils.hdfs_io import copy
 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME as DIFFUSERS_SAFE_WEIGHTS_INDEX_NAME
 from diffusers.utils import SAFETENSORS_WEIGHTS_NAME as DIFFUSERS_SAFETENSORS_WEIGHTS_NAME
+from safetensors import safe_open
+from safetensors.torch import save_file
 from torch import distributed as dist
 from torch import nn
 from tqdm import tqdm
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
-from transformers.utils.import_utils import is_safetensors_available
 
 from ..distributed.parallel_state import get_parallel_state
 from ..utils import logging
@@ -45,11 +46,6 @@
 from ..utils.helper import empty_cache, get_cache_dir, get_dtype_size
 
 
-if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.torch import save_file
-
-
 if TYPE_CHECKING:
     from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
diff --git a/veomni/models/transformers/llama/modeling_llama.py b/veomni/models/transformers/llama/modeling_llama.py
index f670af6f..47951399 100644
--- a/veomni/models/transformers/llama/modeling_llama.py
+++ b/veomni/models/transformers/llama/modeling_llama.py
@@ -33,7 +33,7 @@
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.processing_utils import Unpack
@@ -85,62 +85,74 @@ def extra_repr(self):
 
 
 class LlamaRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: LlamaConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        if hasattr(self.config, "rope_parameters"):
+            self.rope_type = self.config.rope_parameters["rope_type"]
+        elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+            self.rope_type = self.config.rope_scaling["rope_type"]
+        else:
+            self.rope_type = "default"
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.original_inv_freq = inv_freq
 
-    def _dynamic_frequency_update(self, position_ids, device):
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[LlamaConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
         """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
+        if hasattr(config, "rope_parameters"):
+            base = config.rope_parameters["rope_theta"]
+        else:
+            base = config.rope_theta
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
diff --git a/veomni/models/transformers/qwen2/modeling_qwen2.py b/veomni/models/transformers/qwen2/modeling_qwen2.py
index cbd9117a..b7a676ae 100644
--- a/veomni/models/transformers/qwen2/modeling_qwen2.py
+++ b/veomni/models/transformers/qwen2/modeling_qwen2.py
@@ -27,7 +27,7 @@
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
 from transformers.processing_utils import Unpack
@@ -292,62 +292,74 @@ def forward(
 
 
 class Qwen2RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: Qwen2Config, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        if hasattr(self.config, "rope_parameters"):
+            self.rope_type = self.config.rope_parameters["rope_type"]
+        elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+            self.rope_type = self.config.rope_scaling["rope_type"]
+        else:
+            self.rope_type = "default"
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
-    def _dynamic_frequency_update(self, position_ids, device):
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = inv_freq
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[Qwen2Config] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
         """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
""" - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/veomni/models/transformers/qwen2_5_omni/modeling_qwen2_5_omni.py b/veomni/models/transformers/qwen2_5_omni/modeling_qwen2_5_omni.py index 31a5c286..9511656f 100644 --- a/veomni/models/transformers/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/veomni/models/transformers/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from functools import partial from types import SimpleNamespace -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -1548,25 +1548,63 @@ def dummy_forward(self): class Qwen2_5OmniRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + if hasattr(self.config, "rope_parameters"): + self.rope_type = self.config.rope_parameters["rope_type"] + elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + self.rope_type = self.config.rope_scaling["rope_type"] + else: + self.rope_type = "default" + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + self.original_inv_freq = inv_freq - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + @staticmethod + def compute_default_rope_parameters( + config: Optional[Qwen2_5OmniConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
+ """ + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + # Ignore copy def forward(self, x, position_ids): # In contrast to other models, Qwen2_5Omni has different position ids for the grids # So we expand the inv_freq to shape (3, ...) diff --git a/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py b/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py index 6cd53499..336f317f 100644 --- a/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py +++ b/veomni/models/transformers/qwen2_5vl/modeling_qwen2_5_vl.py @@ -23,7 +23,7 @@ from dataclasses import dataclass from functools import partial from types import SimpleNamespace -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -643,62 +643,75 @@ def dummy_forward(self): class Qwen2_5_VLRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: Qwen2_5_VLConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + if hasattr(self.config, "rope_parameters"): + self.rope_type = self.config.rope_parameters["rope_type"] + elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + self.rope_type = self.config.rope_scaling["rope_type"] + else: + self.rope_type = "default" + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + self.original_inv_freq = inv_freq - def _dynamic_frequency_update(self, position_ids, device): + @staticmethod + def compute_default_rope_parameters( + config: Optional[Qwen2_5_VLConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. 
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
+        if hasattr(config, "rope_parameters"):
+            base = config.rope_parameters["rope_theta"]
+        else:
+            base = config.rope_theta
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
 
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
+        attention_factor = 1.0  # Unused in this type of RoPE
 
-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
-        # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
+    # Ignore copy
+    def forward(self, x, position_ids):
+        # In contrast to other models, Qwen2_5_VL has different position ids for the grids
         # So we expand the inv_freq to shape (3, ...)
         inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
diff --git a/veomni/models/transformers/qwen2_vl/modeling_qwen2_vl.py b/veomni/models/transformers/qwen2_vl/modeling_qwen2_vl.py
index 48faa109..d41f9bc2 100644
--- a/veomni/models/transformers/qwen2_vl/modeling_qwen2_vl.py
+++ b/veomni/models/transformers/qwen2_vl/modeling_qwen2_vl.py
@@ -124,62 +124,75 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
 
 
 class Qwen2VLRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: Qwen2VLConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        if hasattr(self.config, "rope_parameters"):
+            self.rope_type = self.config.rope_parameters["rope_type"]
+        elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+            self.rope_type = self.config.rope_scaling["rope_type"]
+        else:
+            self.rope_type = "default"
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.original_inv_freq = inv_freq
 
-    def _dynamic_frequency_update(self, position_ids, device):
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[Qwen2VLConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
         """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
""" - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + attention_factor = 1.0 # Unused in this type of RoPE - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor - # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids + # Ignore copy + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_VL has different position ids for the grids # So we expand the inv_freq to shape (3, ...) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/veomni/models/transformers/qwen3/modeling_qwen3.py b/veomni/models/transformers/qwen3/modeling_qwen3.py index eafcbc55..2246a510 100644 --- a/veomni/models/transformers/qwen3/modeling_qwen3.py +++ b/veomni/models/transformers/qwen3/modeling_qwen3.py @@ -304,22 +304,61 @@ def forward( class Qwen3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: Qwen3Config, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + if hasattr(self.config, "rope_parameters"): + self.rope_type = self.config.rope_parameters["rope_type"] + elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + self.rope_type = self.config.rope_scaling["rope_type"] + else: + self.rope_type = "default" + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[Qwen3Config] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) diff --git a/veomni/models/transformers/qwen3_moe/modeling_qwen3_moe.py b/veomni/models/transformers/qwen3_moe/modeling_qwen3_moe.py index a7fb64b8..027b84a1 100644 --- a/veomni/models/transformers/qwen3_moe/modeling_qwen3_moe.py +++ b/veomni/models/transformers/qwen3_moe/modeling_qwen3_moe.py @@ -31,7 +31,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import ( @@ -86,62 +86,74 @@ def extra_repr(self): class Qwen3MoeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: Qwen3MoeConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + if hasattr(self.config, "rope_parameters"): + self.rope_type = self.config.rope_parameters["rope_type"] + elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + self.rope_type = self.config.rope_scaling["rope_type"] + else: + self.rope_type = "default" + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - def _dynamic_frequency_update(self, position_ids, device): + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[Qwen3MoeConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
""" - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + if hasattr(config, "rope_parameters"): + base = config.rope_parameters["rope_theta"] + else: + base = config.rope_theta + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py b/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py index a97f1e54..d7a045a4 100644 --- a/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py +++ b/veomni/models/transformers/qwen3_vl/modeling_qwen3_vl.py @@ -314,38 +314,62 @@ class Qwen3VLTextRotaryEmbedding(nn.Module): def __init__(self, config: Qwen3VLTextConfig, device=None): super().__init__() - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", "default") - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + if hasattr(self.config, "rope_parameters"): + self.rope_type = self.config.rope_parameters["rope_type"] + elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + self.rope_type = self.config.rope_scaling["rope_type"] + else: + self.rope_type = "default" + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq - def apply_interleaved_mrope(self, freqs, mrope_section): - """Apply interleaved MRoPE to 3D rotary embeddings. - Reorganizes frequency layout from chunked [TTT...HHH...WWW] to - interleaved [THTHWHTHW...TT], preserving frequency continuity. - args: - x: (3, bs, seq_len, head_dim // 2) - mrope_section: (3,) - returns: - x_t: (bs, seq_len, head_dim // 2) + if hasattr(config, "rope_parameters"): + self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20]) + else: + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + @staticmethod + def compute_default_rope_parameters( + config: Optional[Qwen3VLTextConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: """ - freqs_t = freqs[0] # just overwrite the first dimension T - for dim, offset in enumerate((1, 2), start=1): # H, W - length = mrope_section[dim] * 3 - idx = slice(offset, length, 3) - freqs_t[..., idx] = freqs[dim, ..., idx] - return freqs_t + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. 
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        if hasattr(config, "rope_parameters"):
+            base = config.rope_parameters["rope_theta"]
+        else:
+            base = config.rope_theta
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -367,6 +391,23 @@ def forward(self, x, position_ids):
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+        args:
+            x: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            x_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen3VLTextRMSNorm(nn.Module):
diff --git a/veomni/models/transformers/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/veomni/models/transformers/qwen3_vl_moe/modeling_qwen3_vl_moe.py
index 18789dc8..e9317280 100644
--- a/veomni/models/transformers/qwen3_vl_moe/modeling_qwen3_vl_moe.py
+++ b/veomni/models/transformers/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -932,38 +932,62 @@ class Qwen3VLMoeTextRotaryEmbedding(nn.Module):
     def __init__(self, config: Qwen3VLMoeTextConfig, device=None):
         super().__init__()
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", "default")
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        if hasattr(self.config, "rope_parameters"):
+            self.rope_type = self.config.rope_parameters["rope_type"]
+        elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+            self.rope_type = self.config.rope_scaling["rope_type"]
+        else:
+            self.rope_type = "default"
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
-        self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = inv_freq
 
-    def apply_interleaved_mrope(self, freqs, mrope_section):
-        """Apply interleaved MRoPE to 3D rotary embeddings.
-        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
-        interleaved [THTHWHTHW...TT], preserving frequency continuity.
-        args:
-            x: (3, bs, seq_len, head_dim // 2)
-            mrope_section: (3,)
-        returns:
-            x_t: (bs, seq_len, head_dim // 2)
+        if hasattr(config, "rope_parameters"):
+            self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])
+        else:
+            self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[Qwen3VLMoeTextConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
         """
-        freqs_t = freqs[0]  # just overwrite the first dimension T
-        for dim, offset in enumerate((1, 2), start=1):  # H, W
-            length = mrope_section[dim] * 3
-            idx = slice(offset, length, 3)
-            freqs_t[..., idx] = freqs[dim, ..., idx]
-        return freqs_t
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        if hasattr(config, "rope_parameters"):
+            base = config.rope_parameters["rope_theta"]
+        else:
+            base = config.rope_theta
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -985,6 +1009,23 @@ def forward(self, x, position_ids):
 
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+        args:
+            x: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            x_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
 
 @auto_docstring(
     custom_intro=(
diff --git a/veomni/models/transformers/seed_oss/modeling_seed_oss.py b/veomni/models/transformers/seed_oss/modeling_seed_oss.py
index ea0ecedd..069adcf8 100644
--- a/veomni/models/transformers/seed_oss/modeling_seed_oss.py
+++ b/veomni/models/transformers/seed_oss/modeling_seed_oss.py
@@ -310,20 +310,57 @@ class SeedOssRotaryEmbedding(nn.Module):
     def __init__(self, config: SeedOssConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        if hasattr(self.config, "rope_parameters"):
+            self.rope_type = self.config.rope_parameters["rope_type"]
+        elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+            self.rope_type = self.config.rope_scaling["rope_type"]
+        else:
+            self.rope_type = "default"
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.original_inv_freq = inv_freq
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[SeedOssConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
+        """
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+        """
+        if hasattr(config, "rope_parameters"):
+            base = config.rope_parameters["rope_theta"]
+        else:
+            base = config.rope_theta
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        attention_factor = 1.0  # Unused in this type of RoPE
+
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor
 
     @torch.no_grad()
     @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)