Commit 44ebcbe

feat: auto-patch all Attention layers to ensure cu_seq_lens stays on CPU for NPU fused-attention.
1 parent 742f88f commit 44ebcbe

File tree

3 files changed (+116, -18 lines)

configs/sft/qwen3_sft.yaml

Lines changed: 49 additions & 0 deletions
model:
  model_path: Qwen/Qwen3-32B
  attn_implementation: flash_attention_2

data:
  train_path: None
  data_type: conversation
  datasets_type: iterable
  dataloader_type: native
  chat_template: default
  max_seq_len: 2048
  train_size: 40000000
  text_keys: messages

train:
  num_train_epochs: 2
  max_steps: 2000
  use_wandb: false
  output_dir: qwen3_sft
  data_parallel_mode: fsdp2
  ulysses_parallel_size: 1
  expert_parallel_size: 1
  global_batch_size: 16
  micro_batch_size: 1
  rmpad: false
  rmpad_with_pos_ids: true
  bsz_warmup_ratio: 0.007
  optimizer: adamw
  lr: 1.0e-5
  lr_warmup_ratio: 0.007
  lr_decay_style: constant
  lr_decay_ratio: 1.0
  weight_decay: 0.01
  max_grad_norm: 1.0
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  enable_fsdp_offload: false
  enable_activation_offload: false
  init_device: meta
  enable_full_determinism: false
  empty_cache_steps: 500
  ckpt_manager: dcp
  load_checkpoint_path: ""
  save_steps: 2000
  save_epochs: 2
  save_hf_weights: true
  wandb_project: Qwen3_32B_sft
  wandb_name: Qwen3_32B_fsdp2
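
For orientation, a minimal sketch (plain PyYAML here, not VeOmni's own config machinery) of reading this file and checking the fields most relevant to this commit; flash_attention_2 together with rmpad_with_pos_ids is the packed-sequence path that passes the cu_seq_lens_* kwargs patched in the files below:

import yaml

with open("configs/sft/qwen3_sft.yaml") as f:
    cfg = yaml.safe_load(f)

# The three top-level sections mirror the file: model / data / train.
print(cfg["model"]["attn_implementation"])  # flash_attention_2
print(cfg["train"]["rmpad_with_pos_ids"])   # True
print(cfg["train"]["data_parallel_mode"])   # fsdp2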

veomni/models/auto.py

Lines changed: 2 additions & 18 deletions
@@ -28,8 +28,8 @@
 
 from ..distributed.parallel_state import get_parallel_state
 from ..utils import logging
-from ..utils.device import is_torch_npu_available
 from .loader import BaseModelLoader, get_loader
+from .auto_patch import auto_patch_npu_attention
 
 
 if TYPE_CHECKING:

@@ -126,22 +126,6 @@ def build_foundation_model(
         init_device=init_device,
     )
 
-    if is_torch_npu_available():
-        # We override the forward method (on NPU devices) instead of passing CPU FA kwargs directly to the model in the trainer,
-        # due to the behavior in https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L130
-        logger.info_rank0(
-            "We override the model’s forward method on NPU devices to ensure that the FA kwargs are on CPU, since the npu_fused_attention requires cpu FA kwargs"
-        )
-        original_forward = model.forward
-
-        @functools.wraps(original_forward)
-        def wrapped_forward(*args, **kwargs):
-            if "cu_seq_lens_q" in kwargs and kwargs["cu_seq_lens_q"] is not None:
-                kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_q"].cpu()
-            if "cu_seq_lens_k" in kwargs and kwargs["cu_seq_lens_k"] is not None:
-                kwargs["cu_seq_lens_k"] = kwargs["cu_seq_lens_k"].cpu()
-            return original_forward(*args, **kwargs)
-
-        model.forward = wrapped_forward
+    auto_patch_npu_attention(model)
 
     return model
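
For context, the kwargs being moved follow the usual varlen flash-attention convention: when rmpad_with_pos_ids packs several samples into one sequence, cumulative sequence lengths are passed alongside the hidden states. A minimal illustration (toy numbers, not VeOmni code):

import torch
import torch.nn.functional as F

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # lengths of the packed samples
cu_seq_lens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seq_lens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)

# On CUDA these offsets normally live on the GPU; the NPU fused-attention
# kernel targeted by this commit expects them as CPU tensors, hence the
# .cpu() move in the patched forward.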

veomni/models/auto_patch.py

Lines changed: 65 additions & 0 deletions
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import types

import torch

ATTENTION_KEYWORDS = ("attention", "Attention", "ATTENTION")


def _is_attention_module(module):
    name = module.__class__.__name__
    return any(keyword in name for keyword in ATTENTION_KEYWORDS)


def _wrap_attention_forward(module):
    """Patch forward() to move cu_seq_lens_q / cu_seq_lens_k to CPU, only on NPU."""
    if hasattr(module, "_original_forward_patched"):
        # Avoid double patching
        return

    original_forward = module.forward

    def wrapped_forward(self, *args, **kwargs):
        # Only rewrite kwargs on NPU builds; the hasattr check keeps the wrapper
        # a no-op on devices where torch_npu is not installed.
        if hasattr(torch, "npu") and torch.npu.is_available():
            for key in ("cu_seq_lens_q", "cu_seq_lens_k"):
                if key in kwargs and kwargs[key] is not None:
                    # Avoid an unnecessary sync: only move tensors that live on NPU
                    v = kwargs[key]
                    if isinstance(v, torch.Tensor) and v.device.type == "npu":
                        kwargs[key] = v.cpu()

        return original_forward(*args, **kwargs)

    # Monkey-patch the bound forward of this module instance
    module.forward = types.MethodType(wrapped_forward, module)
    module._original_forward_patched = True


def auto_patch_npu_attention(model):
    """
    Automatically find all attention modules in the model and patch their
    forward() so that cu_seq_lens_q / cu_seq_lens_k stay on CPU.

    Args:
        model: torch.nn.Module
    """
    for name, module in model.named_modules():
        if _is_attention_module(module):
            _wrap_attention_forward(module)


__all__ = ["auto_patch_npu_attention"]
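
Because the matching is purely class-name based, a quick toy check (hypothetical ToyAttention / ToyMLP modules, not VeOmni code) shows which submodules end up wrapped:

import torch.nn as nn

from veomni.models.auto_patch import auto_patch_npu_attention


class ToyAttention(nn.Module):  # class name matches the "Attention" keyword
    def forward(self, x, **kwargs):
        return x


class ToyMLP(nn.Module):  # not matched, forward is left untouched
    def forward(self, x):
        return x


model = nn.Sequential(ToyAttention(), ToyMLP())
auto_patch_npu_attention(model)

print(hasattr(model[0], "_original_forward_patched"))  # True
print(hasattr(model[1], "_original_forward_patched"))  # False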
