diff --git a/veomni/models/auto.py b/veomni/models/auto.py
index 5f8b6c7c..a7f6431a 100644
--- a/veomni/models/auto.py
+++ b/veomni/models/auto.py
@@ -28,8 +28,8 @@
 
 from ..distributed.parallel_state import get_parallel_state
 from ..utils import logging
-from ..utils.device import is_torch_npu_available
+from .auto_patch import auto_patch_npu_attention
 from .loader import BaseModelLoader, get_loader
 
 
 if TYPE_CHECKING:
@@ -126,22 +126,6 @@ def build_foundation_model(
         init_device=init_device,
     )
 
-    if is_torch_npu_available():
-        # We override the forward method (on NPU devices) instead of passing CPU FA kwargs directly to the model in the trainer,
-        # due to the behavior in https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L130
-        logger.info_rank0(
-            "We override the model’s forward method on NPU devices to ensure that the FA kwargs are on CPU, since the npu_fused_attention requires cpu FA kwargs"
-        )
-        original_forward = model.forward
-
-        @functools.wraps(original_forward)
-        def wrapped_forward(*args, **kwargs):
-            if "cu_seq_lens_q" in kwargs and kwargs["cu_seq_lens_q"] is not None:
-                kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_q"].cpu()
-            if "cu_seq_lens_k" in kwargs and kwargs["cu_seq_lens_k"] is not None:
-                kwargs["cu_seq_lens_k"] = kwargs["cu_seq_lens_k"].cpu()
-            return original_forward(*args, **kwargs)
-
-        model.forward = wrapped_forward
+    auto_patch_npu_attention(model)
 
     return model
diff --git a/veomni/models/auto_patch.py b/veomni/models/auto_patch.py
new file mode 100644
index 00000000..bc717b91
--- /dev/null
+++ b/veomni/models/auto_patch.py
@@ -0,0 +1,75 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# npu_fused_attention requires the cu_seq_lens_* FA kwargs to be on CPU. We patch the
+# attention modules' forward here instead of passing CPU FA kwargs to the model in the
+# trainer, due to the behavior in
+# https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L130
+
+import functools
+import types
+
+import torch
+
+from ..utils.device import is_torch_npu_available
+
+
+ATTENTION_KEYWORDS = ("attention", "Attention", "ATTENTION")
+
+
+def _is_attention_module(module):
+    name = module.__class__.__name__
+    return any(keyword in name for keyword in ATTENTION_KEYWORDS)
+
+
+def _wrap_attention_forward(module):
+    """Patch ``module.forward`` to move the ``cu_seq_lens_*`` kwargs to CPU."""
+    if getattr(module, "_original_forward_patched", False):
+        # Avoid patching the same module twice
+        return
+
+    original_forward = module.forward
+
+    @functools.wraps(original_forward)
+    def wrapped_forward(self, *args, **kwargs):
+        for key in ("cu_seq_lens_q", "cu_seq_lens_k"):
+            value = kwargs.get(key)
+            # Avoid an unnecessary sync: only move tensors that live on NPU
+            if isinstance(value, torch.Tensor) and value.device.type == "npu":
+                kwargs[key] = value.cpu()
+        return original_forward(*args, **kwargs)
+
+    # Monkey-patch the bound forward and mark the module as patched
+    module.forward = types.MethodType(wrapped_forward, module)
+    module._original_forward_patched = True
+
+
+def auto_patch_npu_attention(model):
+    """Find all attention modules in the model and patch their forward() so that the
+    cu_seq_lens_* kwargs are moved to CPU, as required by npu_fused_attention.
+
+    This is a no-op on non-NPU devices.
+
+    Args:
+        model: torch.nn.Module
+    """
+    if not is_torch_npu_available():
+        return
+
+    for _, module in model.named_modules():
+        if _is_attention_module(module):
+            _wrap_attention_forward(module)
+
+
+__all__ = ["auto_patch_npu_attention"]
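
A minimal usage sketch of the new helper (the ToyAttention/ToyModel classes below are hypothetical and only for illustration; the real call site is build_foundation_model in veomni/models/auto.py):

    import torch.nn as nn

    from veomni.models.auto_patch import auto_patch_npu_attention


    class ToyAttention(nn.Module):  # matched because the class name contains "Attention"
        def forward(self, hidden_states, **kwargs):
            return hidden_states


    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = ToyAttention()

        def forward(self, hidden_states, **kwargs):
            return self.attn(hidden_states, **kwargs)


    model = ToyModel()
    auto_patch_npu_attention(model)
    # On NPU devices, ToyAttention.forward is now wrapped so that any cu_seq_lens_q /
    # cu_seq_lens_k tensor kwargs are moved to CPU before the original forward runs;
    # on other devices the call above is a no-op.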