
Commit 165b151

add msprobe support
1 parent 69a5161 commit 165b151

6 files changed: +297, -2 lines

docs/source/Megatron-SWIFT/Ascend.md

Lines changed: 132 additions & 0 deletions
@@ -41,3 +41,135 @@ while iteration < args.train_iters:
    ...
    prof.stop()
```

## NPU Accuracy Data Collection

### Configuration

Modify configuration items such as dump_path and level in the msprobe_config.json file under the ms-swift directory as needed.
For more options, see [Configuration Examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and [Configuration File Introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).
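As a small illustration (not part of this commit), the sketch below edits those keys programmatically before launching training; it assumes it is run from the ms-swift repository root where msprobe_config.json lives, and the dump_path value is only an example.

```python
# Illustrative sketch (not part of this commit): tweak the shipped
# msprobe_config.json before launching training. Assumes the script runs
# from the ms-swift repo root; the dump_path value is an arbitrary example.
import json

config_file = 'msprobe_config.json'

with open(config_file, 'r', encoding='utf-8') as f:
    cfg = json.load(f)

cfg['dump_path'] = './dump_qwen2_5_sft'   # where msprobe writes the collected data
cfg['level'] = 'mix'                      # keep the level shipped with the default config

with open(config_file, 'w', encoding='utf-8') as f:
    json.dump(cfg, f, indent=4, ensure_ascii=False)
```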
### Code Modification

To support accuracy debugging with the msprobe tool, the `_patch_word_embeddings` function in `swift/megatron/model/mm_gpt_model.py` needs to be modified. The main change is to adjust the function's parameters and internal logic so that the embedding layer is patched correctly.

The specific changes are as follows.

Before modification:

```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res
    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```
After modification:

```python
def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
    origin_forward = emb.word_embeddings.forward  # Modification 2

    def forward(input_):  # Modification 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # Modification 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # Modification 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # Modification 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # Modification 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```
The main changes are:

1. The `_patch_word_embeddings` method gains an `emb` parameter that receives the embedding module instance.
2. The original forward is taken from `emb.word_embeddings.forward` instead of `VocabParallelEmbedding.forward`.
3. The signature of the inner `forward` function changes from `(_self, input_)` to `(input_)`.
4. Inside the function, `_self` is obtained from `emb.word_embeddings`.
5. The original forward is called with `input_` directly.
6. Patching and restoring are done on `emb.word_embeddings.forward` (Modifications 6 and 7).
7. The `self.language_model.embedding` instance is passed when calling `_patch_word_embeddings`.
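These changes amount to patching the bound `forward` of one embedding instance and restoring it afterwards, instead of patching the class. Below is a small self-contained sketch of that pattern, not the ms-swift code itself: a plain `nn.Embedding` stands in for Megatron's `VocabParallelEmbedding`, and only the negative-id masking step is reproduced.

```python
# Standalone sketch of instance-level forward patching (illustrative only).
# Patching emb.forward touches just this one instance, so other instances
# and class-level wrappers (e.g. hooks installed by a tool such as msprobe)
# are left untouched.
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def patch_word_embeddings(emb: nn.Embedding):
    origin_forward = emb.forward  # bound forward of this particular instance

    def forward(input_):
        # Mirror the documented patch: map illegal (negative) ids to 0 before lookup.
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        return origin_forward(input_)

    emb.forward = forward  # patch only this instance, not the class
    try:
        yield
    finally:
        emb.forward = origin_forward  # always restore the original forward


emb = nn.Embedding(10, 4)
ids = torch.tensor([1, -1, 3])   # -1 would normally raise an index error
with patch_word_embeddings(emb):
    out = emb(ids)               # the patched forward maps -1 to 0 first
print(out.shape)                 # torch.Size([3, 4])
```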
### Enabling

Add `--enable_msprobe True` to the launch script.

In addition, since msprobe does not support fused computation, you also need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False`.

#### Example

```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
    --load Qwen2.5-7B-Instruct-mcore \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
              'AI-ModelScope/alpaca-gpt4-data-en#500' \
              'swift/self-cognition#500' \
    --tensor_model_parallel_size 2 \
    ...
    --no_bias_dropout_fusion True \
    --no_bias_swiglu_fusion True \
    --cross_entropy_loss_fusion False \
    --enable_msprobe True
```

docs/source_en/Megatron-SWIFT/Ascend.md

Lines changed: 132 additions & 0 deletions
@@ -41,3 +41,135 @@ while iteration < args.train_iters:
    ...
    prof.stop()
```

## NPU Accuracy Data Collection

### Configuration

Modify the dump_path, level, and other configuration items in the msprobe_config.json file under the ms-swift directory as needed.
More configuration options can be found in [Configuration Examples](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_examples.md) and [Configuration File Introduction](https://gitcode.com/Ascend/mstt/blob/master/debug/accuracy_tools/msprobe/docs/zh/dump/config_json_introduct.md).
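As a small illustration (not part of this commit), the sketch below edits those keys programmatically before launching training; it assumes it is run from the ms-swift repository root where msprobe_config.json lives, and the dump_path value is only an example.

```python
# Illustrative sketch (not part of this commit): tweak the shipped
# msprobe_config.json before launching training. Assumes the script runs
# from the ms-swift repo root; the dump_path value is an arbitrary example.
import json

config_file = 'msprobe_config.json'

with open(config_file, 'r', encoding='utf-8') as f:
    cfg = json.load(f)

cfg['dump_path'] = './dump_qwen2_5_sft'   # where msprobe writes the collected data
cfg['level'] = 'mix'                      # keep the level shipped with the default config

with open(config_file, 'w', encoding='utf-8') as f:
    json.dump(cfg, f, indent=4, ensure_ascii=False)
```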
### Code Modification

To support accuracy debugging with the msprobe tool, we need to modify the `_patch_word_embeddings` function in the `swift/megatron/model/mm_gpt_model.py` file. The main changes adjust the function's parameters and internal logic so that it patches the embedding layer correctly.

The specific changes are as follows.

Before modification:

```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res
    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```
After modification:

```python
def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
    origin_forward = emb.word_embeddings.forward  # Modification 2

    def forward(input_):  # Modification 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # Modification 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # Modification 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # Modification 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # Modification 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```
Major changes include:

1. The `_patch_word_embeddings` method adds an `emb` parameter that receives the embedding module instance.
2. The original forward is obtained from `emb.word_embeddings.forward` instead of `VocabParallelEmbedding.forward`.
3. The internal `forward` function signature changes from `(_self, input_)` to `(input_)`.
4. `_self` is obtained through `emb.word_embeddings` inside the function.
5. `input_` is passed directly when calling the original forward.
6. `emb.word_embeddings.forward` is used for patching and restoring (Modifications 6 and 7).
7. The `self.language_model.embedding` instance is passed when calling `_patch_word_embeddings`.
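These changes amount to patching the bound `forward` of one embedding instance and restoring it afterwards, instead of patching the class. The following self-contained sketch illustrates the pattern; it is not the ms-swift code itself: a plain `nn.Embedding` stands in for Megatron's `VocabParallelEmbedding`, and only the negative-id masking step is reproduced.

```python
# Standalone sketch of instance-level forward patching (illustrative only).
# Patching emb.forward touches just this one instance, so other instances
# and class-level wrappers (e.g. hooks installed by a tool such as msprobe)
# are left untouched.
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def patch_word_embeddings(emb: nn.Embedding):
    origin_forward = emb.forward  # bound forward of this particular instance

    def forward(input_):
        # Mirror the documented patch: map illegal (negative) ids to 0 before lookup.
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        return origin_forward(input_)

    emb.forward = forward  # patch only this instance, not the class
    try:
        yield
    finally:
        emb.forward = origin_forward  # always restore the original forward


emb = nn.Embedding(10, 4)
ids = torch.tensor([1, -1, 3])   # -1 would normally raise an index error
with patch_word_embeddings(emb):
    out = emb(ids)               # the patched forward maps -1 to 0 first
print(out.shape)                 # torch.Size([3, 4])
```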
### Enabling

Add `--enable_msprobe True` to the launch script.

In addition, since msprobe does not support fused computation, you also need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False`.

#### Example

```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
    --load Qwen2.5-7B-Instruct-mcore \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
              'AI-ModelScope/alpaca-gpt4-data-en#500' \
              'swift/self-cognition#500' \
    --tensor_model_parallel_size 2 \
    ...
    --no_bias_dropout_fusion True \
    --no_bias_swiglu_fusion True \
    --cross_entropy_loss_fusion False \
    --enable_msprobe True
```

msprobe_config.json

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
{
    "task": "statistics",
    "dump_path": "./dump_path",
    "rank": [],
    "step": [],
    "level": "mix",
    "async_dump": false,
    "statistics": {
        "scope": [],
        "list": [],
        "tensor_list": [],
        "data_mode": ["all"],
        "summary_mode": "statistics"
    }
}
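For orientation, here is a minimal hedged sketch of how a config like this is consumed; the toy model, data, and loop are placeholders, while the `PrecisionDebugger` calls are the same ones the patched `train_step` in `swift/megatron/trainers/base.py` below uses (start/stop/step around each training step).

```python
# Minimal sketch of the msprobe per-step lifecycle driven by msprobe_config.json.
# The toy model and loop are placeholders; the PrecisionDebugger calls mirror
# the ones added to train_step in swift/megatron/trainers/base.py.
import torch
import torch.nn as nn
from msprobe.pytorch import PrecisionDebugger

model = nn.Linear(8, 8)
debugger = PrecisionDebugger(config_path='./msprobe_config.json', model=model)

for step in range(3):
    debugger.start()                 # begin collecting data for this step
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    debugger.stop()                  # stop collection for this step
    debugger.step()                  # advance msprobe's step counter and flush the dump
```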

requirements/framework.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ importlib_metadata
 jieba
 json_repair
 matplotlib
+mindstudio-probe
 modelscope>=1.23
 nltk
 numpy

swift/megatron/argument/megatron_args.py

Lines changed: 4 additions & 0 deletions
@@ -359,6 +359,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
     # qwen3_vl, qwen3_omni
     mrope_interleaved: Optional[bool] = None

+    # dump
+    enable_msprobe: bool = False
+    msprobe_config: str = './msprobe_config.json'
+
     @staticmethod
     def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
         res = {}

swift/megatron/trainers/base.py

Lines changed: 13 additions & 2 deletions
@@ -522,8 +522,19 @@ def _all_reduce_metric(self,
     def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
                    **kwargs):
         new_data_iterator = self._replace_data_iterator(data_iterator, model)
-        return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
-                                       config, *args, **kwargs)
+        debugger_on = self.args.enable_msprobe
+        if debugger_on:
+            from msprobe.pytorch import PrecisionDebugger
+            debugger = PrecisionDebugger(config_path=self.args.msprobe_config, model=model)
+            debugger.start()
+        try:
+            origin_train_step_out = self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
+                                                            config, *args, **kwargs)
+        finally:
+            if debugger_on:
+                debugger.stop()
+                debugger.step()
+        return origin_train_step_out

     # Code borrowed from NVIDIA/Megatron-LM
     def evaluate(
