
Conversation

@zheliuyu (Contributor) commented Dec 30, 2025

What does this PR do?

I noticed that VeOmni plans to update transformers to v5: #268

Coincidentally, while recently conducting SFT tests for Qwen3 in VeOmni, I discovered some breaking changes between Transformers v4 and v5. This PR addresses a subset of these issues.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include misc, ci, config, docs, data, dist, omni, logging, model, optim, ckpt, release, task, perf, ops, parallel
    • If this PR involves multiple modules, separate them with , like [ci, data, model]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][parallel, model] feat: dynamic batching

Test

  • Run Qwen3 SFT; loss trends remain consistent before and after the modification.
  • All CI checks should pass.

API and Usage Example

bash train.sh tasks/train.py config/sft/qwen3_sft.yaml

Design & Code Changes

  1. v5 no longer provides is_safetensors_available(); since safetensors is now a required dependency, the conditional check has been removed.
  2. The rope_scaling config attribute has been removed and replaced with rope_parameters (see the sketch after this list).
  3. The "default" option for rope_type has been removed; its logic now lives in a compute_default_rope_parameters() function in each affected modeling file.

Checklist Before Submitting

Important

Please check all the following items before requesting a review; otherwise, the reviewer may deprioritize this PR.

@github-actions bot added the hf_v5 (Related for transformers v5) label on Dec 30, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request refactors several models to be compatible with the RoPE implementation in transformers v5. The changes mostly involve updating the RotaryEmbedding classes to use the new configuration structure and utility functions. While most models are updated correctly, there are some critical issues with models that use custom 3D position_ids for RoPE.

Specifically:

  • For qwen2_5vl and qwen2_vl models, the dynamic RoPE scaling functionality has been removed.
  • For qwen3_vl and qwen3_vl_moe models, the new @dynamic_rope_update decorator from transformers is used, but it is incompatible with the 3D position_ids of these models, which will likely lead to runtime errors.

These issues should be addressed to ensure correct functionality and avoid regressions.

return inv_freq, attention_factor

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)

critical

The @dynamic_rope_update decorator from transformers.modeling_rope_utils is not compatible with the 3D position_ids tensor of shape (3, batch_size, sequence_length) used in this model's forward method. The decorator's internal logic expects a 2D position_ids tensor to calculate the sequence length, and will likely fail or produce incorrect results with a 3D tensor. This could lead to runtime errors or incorrect RoPE scaling when dynamic scaling is triggered.

A custom implementation for dynamic frequency updates that can handle 3D position_ids is required if this feature is to be supported.
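
As a rough illustration only (not part of this PR), a dynamic-frequency update along the lines of the removed v4 helper would still tolerate 3D position_ids, since it infers the sequence length from the maximum position id; this sketch assumes the v5 ROPE_INIT_FUNCTIONS entries still accept a seq_len keyword:

    def _dynamic_frequency_update(self, position_ids, device):
        # position_ids may be (batch, seq_len) or (3, batch, seq_len); the global max works for both
        seq_len = int(torch.max(position_ids)) + 1
        if seq_len > self.max_seq_len_cached:  # grow beyond the cached length
            inv_freq, self.attention_scaling = ROPE_INIT_FUNCTIONS[self.rope_type](
                self.config, device, seq_len=seq_len
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len
        elif seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            # back in the original range: restore the original frequencies
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len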

return inv_freq, attention_factor

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)

critical

The @dynamic_rope_update decorator is incompatible with the 3D position_ids tensor used in Qwen3VLMoeTextRotaryEmbedding. The transformers utility expects a 2D tensor to correctly infer sequence length for dynamic scaling. Using it with a 3D tensor will likely cause runtime errors or incorrect scaling behavior. If dynamic RoPE scaling is needed, a custom update mechanism that correctly handles the 3D position_ids is necessary.

Comment on lines 648 to 692
     def __init__(self, config: Qwen2_5_VLConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.rope_type = self.config.rope_parameters["rope_type"]
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.original_inv_freq = inv_freq

-    def _dynamic_frequency_update(self, position_ids, device):
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[Qwen2_5_VLConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
         """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
+        attention_factor = 1.0  # Unused in this type of RoPE

-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor

high

The refactoring to align with transformers v5 has removed the _dynamic_frequency_update method and its invocation from the forward method. This effectively disables dynamic RoPE scaling for this model, which was previously supported. This is a significant change in functionality and might be unintentional. If dynamic scaling is still desired, a custom implementation compatible with the 3D position_ids used in this model should be provided, as the standard @dynamic_rope_update decorator from transformers is not suitable for this case.

Comment on lines 129 to 173
     def __init__(self, config: Qwen2VLConfig, device=None):
         super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings

         self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.rope_type = self.config.rope_parameters["rope_type"]
+        rope_init_fn: Callable = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
+        self.original_inv_freq = inv_freq

-    def _dynamic_frequency_update(self, position_ids, device):
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: Optional[Qwen2VLConfig] = None,
+        device: Optional["torch.device"] = None,
+        seq_len: Optional[int] = None,
+    ) -> tuple["torch.Tensor", float]:
         """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        Computes the inverse frequencies according to the original RoPE implementation
+        Args:
+            config ([`~transformers.PreTrainedConfig`]):
+                The model configuration.
+            device (`torch.device`):
+                The device to use for initialization of the inverse frequencies.
+            seq_len (`int`, *optional*):
+                The current sequence length. Unused for this type of RoPE.
+        Returns:
+            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
         """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(
-                self.config, device, seq_len=seq_len, **self.rope_kwargs
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
+        attention_factor = 1.0  # Unused in this type of RoPE

-    @torch.no_grad()
-    def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
+        # Compute the inverse frequencies
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+        )
+        return inv_freq, attention_factor

high

Similar to other models in this PR, the refactoring of Qwen2VLRotaryEmbedding has removed the _dynamic_frequency_update method and its call within forward. This disables the dynamic RoPE scaling feature for this model. This could be a regression if the feature is still intended to be supported. Please consider re-implementing this functionality in a way that is compatible with the new transformers API and the model's custom 3D position_ids.

@zheliuyu marked this pull request as ready for review on December 30, 2025 at 08:57
@zheliuyu changed the title from [model] refactor: Compatible with Transformers v5 RoPE to [WIP] [model] refactor: Compatible with Transformers v5 RoPE on Dec 30, 2025
@github-actions bot added the wip label on Dec 30, 2025
@zheliuyu changed the title from [WIP] [model] refactor: Compatible with Transformers v5 RoPE to [model] refactor: Compatible with Transformers v5 RoPE on Jan 5, 2026
@Crystal-jiang (Collaborator) commented Jan 5, 2026

The WAN model also employs RoPE—consider refactoring it in tandem to ensure consistency and reduce code duplication.

@Crystal-jiang (Collaborator)

Could you please update the review with the precision alignment results after the refactoring? This will help ensure consistency and correctness in the redesigned implementation.

@zheliuyu (Contributor, Author) commented Jan 5, 2026

> Could you please update the review with the precision alignment results after the refactoring? This will help ensure consistency and correctness in the redesigned implementation.

Sure, I'll set up a comparison experiment later.

@zheliuyu (Contributor, Author) commented Jan 5, 2026

> The WAN model also employs RoPE—consider refactoring it in tandem to ensure consistency and reduce code duplication.

The RoPE in the WAN model is not imported from the transformers library, so it is not affected by transformers version upgrades.

@zheliuyu (Contributor, Author) commented Jan 9, 2026

A quick experiment: Loss curves should be consistent between Transformers v4.57.3 and v5.0.0rc0 for the same algorithm.

bash train.sh tasks/omni/train_qwen_vl.py configs/multimodal/qwen3_vl/qwen3_vl_dense.yaml \
    --model.model_path ./Qwen3-VL-8B-Instruct \
    --data.train_path ./sharegpt4v_instruct_gpt4-vision_cap100k_coco.json \
    --data.dataloader_type native \
    --data.datasets_type iterable \
    --data.source_name sharegpt4v_sft \
    --data.num_workers 8 \
    --train.global_batch_size 24 \
    --train.micro_batch_size 1 \
    --train.max_steps 20 \
    --train.enable_profiling false
[image: loss curves for Transformers v4.57.3 vs v5.0.0rc0]
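
A hypothetical helper (not part of this PR) for the kind of check shown above, assuming each run writes one JSON record per training step:

    import json

    def max_loss_gap(log_a: str, log_b: str) -> float:
        """Return the largest per-step loss difference between two JSONL training logs
        (assumed format: one {"step": ..., "loss": ...} object per line)."""
        curves = []
        for path in (log_a, log_b):
            with open(path) as f:
                curves.append([json.loads(line)["loss"] for line in f])
        return max(abs(a - b) for a, b in zip(*curves))

    # e.g. max_loss_gap("loss_v4.57.3.jsonl", "loss_v5.0.0rc0.jsonl") should stay near zero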

@zheliuyu (Contributor, Author) commented Jan 9, 2026

I think this PR is ready for review. @FoolPlayer @Luosuu @Coach257 @Crystal-jiang

@zheliuyu (Contributor, Author)

It appears that VeOmni has developed a new approach for adapting to transformers v5, so we will temporarily close this PR.

@zheliuyu closed this on Jan 15, 2026
@FoolPlayer (Collaborator)

> It appears that VeOmni has developed a new approach for adapting to transformers v5, so we will temporarily close this PR.

Thanks for your PR! Yes, we plan to adapt to v5 as in #392.
