20 changes: 2 additions & 18 deletions veomni/models/auto.py
@@ -28,8 +28,8 @@

from ..distributed.parallel_state import get_parallel_state
from ..utils import logging
-from ..utils.device import is_torch_npu_available
from .loader import BaseModelLoader, get_loader
+from .auto_patch import auto_patch_npu_attention


if TYPE_CHECKING:
@@ -126,22 +126,6 @@ def build_foundation_model(
        init_device=init_device,
    )

-    if is_torch_npu_available():
-        # We override the forward method (on NPU devices) instead of passing CPU FA kwargs directly to the model in the trainer,
-        # due to the behavior in https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L130
-        logger.info_rank0(
-            "We override the model’s forward method on NPU devices to ensure that the FA kwargs are on CPU, since the npu_fused_attention requires cpu FA kwargs"
-        )
-        original_forward = model.forward
-
-        @functools.wraps(original_forward)
-        def wrapped_forward(*args, **kwargs):
-            if "cu_seq_lens_q" in kwargs and kwargs["cu_seq_lens_q"] is not None:
-                kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_q"].cpu()
-            if "cu_seq_lens_k" in kwargs and kwargs["cu_seq_lens_k"] is not None:
-                kwargs["cu_seq_lens_k"] = kwargs["cu_seq_lens_k"].cpu()
-            return original_forward(*args, **kwargs)
-
-        model.forward = wrapped_forward
+    auto_patch_npu_attention(model)

    return model
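For context: the comment removed above explains that the NPU fused attention expects the FA kwargs (cu_seq_lens_q / cu_seq_lens_k) on CPU, and links the FSDP fully_shard behavior that motivated wrapping the root model.forward. With this change the root forward is left untouched and only attention submodules are wrapped. A minimal sketch of that difference, using hypothetical toy module names rather than anything from this PR:

from torch import nn

from veomni.models.auto_patch import auto_patch_npu_attention


class ToyAttention(nn.Module):  # hypothetical module; matched by its class name
    def forward(self, x, cu_seq_lens_q=None, cu_seq_lens_k=None):
        return x


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(self, x, **kwargs):
        return self.attn(x, **kwargs)


model = ToyModel()
root_forward = model.forward
auto_patch_npu_attention(model)
assert model.forward == root_forward  # the root forward is not wrapped
assert getattr(model.attn, "_original_forward_patched", False)  # the attention submodule is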
65 changes: 65 additions & 0 deletions veomni/models/auto_patch.py
@@ -0,0 +1,65 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import types

import torch

from ..utils.device import is_torch_npu_available

ATTENTION_KEYWORDS = ("attention", "Attention", "ATTENTION")


def _is_attention_module(module):
    name = module.__class__.__name__
    return any(keyword in name for keyword in ATTENTION_KEYWORDS)


def _wrap_attention_forward(module):
    """Patch ``forward`` so that cu_seq_lens_q / cu_seq_lens_k are moved to CPU on NPU."""
    if hasattr(module, "_original_forward_patched"):
        # Avoid double patching.
        return

    original_forward = module.forward

    def wrapped_forward(self, *args, **kwargs):
        # Only adjust the kwargs on NPU, where the fused-attention kernel expects CPU cu_seq_lens.
        if is_torch_npu_available():
            for key in ("cu_seq_lens_q", "cu_seq_lens_k"):
                if key in kwargs and kwargs[key] is not None:
                    # Avoid an unnecessary sync: only convert tensors that live on NPU.
                    v = kwargs[key]
                    if isinstance(v, torch.Tensor) and v.device.type == "npu":
                        kwargs[key] = v.cpu()

        return original_forward(*args, **kwargs)

    # Monkey-patch the bound forward of this module instance.
    module.forward = types.MethodType(wrapped_forward, module)
    module._original_forward_patched = True


def auto_patch_npu_attention(model):
    """
    Find all attention modules in the model and patch their forward() so that
    cu_seq_lens_q / cu_seq_lens_k are moved to CPU before the NPU fused-attention call.

    Args:
        model: torch.nn.Module
    """
    for module in model.modules():
        if _is_attention_module(module):
            _wrap_attention_forward(module)


__all__ = ["auto_patch_npu_attention"]
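A short usage sketch to close: the module name below is hypothetical and chosen only to match the class-name heuristic, and the tensors are illustrative. On NPU the wrapped forward moves any cu_seq_lens_q / cu_seq_lens_k kwargs to CPU before calling the original forward; elsewhere it is a pass-through, and repeated patching is a no-op thanks to the guard flag.

import torch
from torch import nn

from veomni.models.auto_patch import auto_patch_npu_attention


class LlamaAttention(nn.Module):  # hypothetical stand-in for an HF-style attention class
    def forward(self, hidden_states, cu_seq_lens_q=None, cu_seq_lens_k=None):
        # On NPU, cu_seq_lens_q / cu_seq_lens_k arrive here as CPU tensors.
        return hidden_states


model = nn.ModuleList([LlamaAttention()])
auto_patch_npu_attention(model)
auto_patch_npu_attention(model)  # second call is a no-op: the module is already flagged

out = model[0](
    torch.randn(1, 8, 16),
    cu_seq_lens_q=torch.tensor([0, 4, 8], dtype=torch.int32),
    cu_seq_lens_k=torch.tensor([0, 4, 8], dtype=torch.int32),
)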