
Commit ff28fd6

add msprobe support
1 parent 69a5161 commit ff28fd6

File tree

2 files changed: +317 -0 lines changed


docs/source/Megatron-SWIFT/Ascend.md

Lines changed: 157 additions & 0 deletions
@@ -41,3 +41,160 @@ while iteration < args.train_iters:
...
prof.stop()
```

## NPU Accuracy Data Collection
### Installing msprobe
```shell
pip install mindstudio-probe
```

### Code Modification
To support accuracy debugging with the msprobe tool, modify the `_patch_word_embeddings` function in `swift/megatron/model/mm_gpt_model.py`. The main change is to adjust the function's parameters and internal logic so that the embedding layer is patched correctly.

The specific changes are as follows:

Before modification:
```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res
    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

After modification:
```python
def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
    origin_forward = emb.word_embeddings.forward  # Modification 2

    def forward(input_):  # Modification 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # Modification 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # Modification 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # Modification 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # Modification 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

The main changes are:
1. The `_patch_word_embeddings` method takes an additional `emb` parameter that receives the embedding module instance.
2. `emb.word_embeddings.forward` is captured instead of `VocabParallelEmbedding.forward`.
3. The inner `forward` function's signature changes from `(_self, input_)` to `(input_)`.
4. `_self` is obtained inside the function via `emb.word_embeddings`.
5. The original forward is called with `input_` directly.
6. Patching and restoring are done on `emb.word_embeddings.forward` (Modifications 6 and 7).
7. The `self.language_model.embedding` instance is passed when calling `_patch_word_embeddings`.
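
Conceptually, the new version swaps class-level patching for instance-level patching: only the bound `forward` of one embedding module is replaced, and it is restored in the `finally` branch. Below is a minimal, self-contained sketch of that pattern under illustrative names (`ToyEmbedding` and `patch_forward` are not part of Megatron-SWIFT or msprobe); it sketches the technique, not the project's implementation.

```python
# Minimal sketch of instance-level forward patching (illustrative names only).
from contextlib import contextmanager

import torch
import torch.nn as nn


class ToyEmbedding(nn.Module):
    def __init__(self, vocab_size: int = 16, dim: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, dim))

    def forward(self, input_):
        return self.weight[input_]


@contextmanager
def patch_forward(emb: nn.Module):
    origin_forward = emb.forward  # bound method of this one instance

    def forward(input_):
        # Preprocess the input, then delegate to the original bound forward.
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        return origin_forward(input_)

    emb.forward = forward  # shadow only this instance, not the class
    try:
        yield
    finally:
        emb.forward = origin_forward  # always restore the original forward


emb = ToyEmbedding()
with patch_forward(emb):
    out = emb(torch.tensor([[1, -1, 3]]))  # -1 is masked to 0 before the lookup
print(out.shape)  # torch.Size([1, 3, 4])
```

Because the replacement lives on the instance, every other embedding instance keeps its original behavior, and the `finally` branch guarantees the patch is undone even if the embedding call raises.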

Also modify the `train_step` function in `swift/megatron/trainers/base.py`.
Before modification:
```python
def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
               **kwargs):
    new_data_iterator = self._replace_data_iterator(data_iterator, model)
    return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
                                   config, *args, **kwargs)
```
After modification:
```python
def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
               **kwargs):
    new_data_iterator = self._replace_data_iterator(data_iterator, model)
    from msprobe.pytorch import PrecisionDebugger
    debugger = PrecisionDebugger(dump_path='./dump_path', level='mix', model=model)
    debugger.start()
    try:
        origin_train_step_out = self._origin_train_step(
            forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, config, *args, **kwargs)
    finally:
        debugger.stop()
        debugger.step()
    return origin_train_step_out
```
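
For reference, the `PrecisionDebugger` calls used above (construct with `dump_path`, `level`, and `model`, then `start`, `stop`, `step`) can be exercised outside Megatron-SWIFT around an ordinary PyTorch step. The sketch below only reuses the calls shown in the modified `train_step`; the toy model, optimizer, and loss are illustrative assumptions, not from the source.

```python
# Illustrative only: the PrecisionDebugger lifecycle from the modified
# train_step, exercised around a single plain-PyTorch training step.
import torch
import torch.nn as nn
from msprobe.pytorch import PrecisionDebugger

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Same arguments as in the modified train_step above.
debugger = PrecisionDebugger(dump_path='./dump_path', level='mix', model=model)
debugger.start()                      # begin collecting accuracy data
try:
    x = torch.randn(4, 8)
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
finally:
    debugger.stop()                   # stop collection even if the step fails
    debugger.step()                   # mark the end of this dump step
```

The `dump_path` argument above is where the collected data is written.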


### Enable

In addition, because msprobe does not support fused computation, add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False` to the launch script.

#### Example
186+
```shell
187+
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
188+
NPROC_PER_NODE=2 \
189+
CUDA_VISIBLE_DEVICES=0,1 \
190+
megatron sft \
191+
--load Qwen2.5-7B-Instruct-mcore \
192+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
193+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
194+
'swift/self-cognition#500' \
195+
--tensor_model_parallel_size 2 \
196+
...
197+
--no_bias_dropout_fusion True \
198+
--no_bias_swiglu_fusion True \
199+
--cross_entropy_loss_fusion False
200+
```

docs/source_en/Megatron-SWIFT/Ascend.md

Lines changed: 160 additions & 0 deletions
@@ -41,3 +41,163 @@ while iteration < args.train_iters:
...
prof.stop()
```

## NPU Accuracy Data Collection
### Installing msprobe
```shell
pip install mindstudio-probe
```

### Code Modification
To support accuracy debugging with the msprobe tool, we need to modify the `_patch_word_embeddings` function in the `swift/megatron/model/mm_gpt_model.py` file. The main changes adjust the function's parameters and internal logic so that it patches the embedding layer correctly.

The specific modifications are as follows:

Before modification:
```python
def _patch_word_embeddings(self, kwargs):
    origin_forward = VocabParallelEmbedding.forward

    def forward(_self, input_):
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(_self, input_)
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res
    VocabParallelEmbedding.forward = forward
    try:
        yield
    finally:
        VocabParallelEmbedding.forward = origin_forward

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs):
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

After modification:
```python
def _patch_word_embeddings(self, kwargs, emb):  # Modification 1
    origin_forward = emb.word_embeddings.forward  # Modification 2

    def forward(input_):  # Modification 3
        from ..trainers.utils import split_cp_inputs
        args = get_args()
        _self = emb.word_embeddings  # Modification 4
        reduce_scatter_embeddings = _self.reduce_scatter_embeddings
        _self.reduce_scatter_embeddings = False
        input_ = torch.masked_fill(input_, input_ < 0, 0)
        res = origin_forward(input_)  # Modification 5
        _self.reduce_scatter_embeddings = reduce_scatter_embeddings
        packed_seq_params = kwargs.get('packed_seq_params')
        # ...other logic...
        return res

    emb.word_embeddings.forward = forward  # Modification 6
    try:
        yield
    finally:
        emb.word_embeddings.forward = origin_forward  # Modification 7

def forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor = None,
    decoder_input: torch.Tensor = None,
    labels: torch.Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    **kwargs,
) -> torch.Tensor:
    if decoder_input is not None:
        pass
    elif self.pre_process:
        kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
        with self._patch_word_embeddings(kwargs, self.language_model.embedding):  # Modification 8
            decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)

    # ...other logic...
```

Major changes include:
1. The `_patch_word_embeddings` method adds an `emb` parameter to receive the embedding module instance
2. Directly obtain `emb.word_embeddings.forward` instead of `VocabParallelEmbedding.forward`
3. The internal `forward` function signature changed from `(_self, input_)` to `(input_)`
4. Get `_self` through `emb.word_embeddings` inside the function
5. Pass `input_` directly when calling the original forward
6. Use `emb.word_embeddings.forward` for replacement and recovery operations (Modifications 6, 7)
7. Pass the `self.language_model.embedding` instance when calling `_patch_word_embeddings`
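
In short, the patch now targets the bound `forward` of a single embedding instance rather than the `VocabParallelEmbedding` class, and it is restored on exit. A compact, hedged sketch of that pattern with an illustrative `nn.Embedding` instance (not Megatron-SWIFT code):

```python
# Illustrative sketch: patch and restore the forward of one module instance.
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def patch_instance_forward(emb: nn.Module):
    origin_forward = emb.forward          # bound method of this instance

    def forward(input_):
        input_ = torch.masked_fill(input_, input_ < 0, 0)  # same preprocessing as above
        return origin_forward(input_)

    emb.forward = forward                 # shadow only this instance's forward
    try:
        yield
    finally:
        emb.forward = origin_forward      # restore on exit


emb = nn.Embedding(16, 4)
with patch_instance_forward(emb):
    print(emb(torch.tensor([1, -1, 3])).shape)  # torch.Size([3, 4])
```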


Modify the `train_step` function in `swift/megatron/trainers/base.py`.

Before modification:
```python
def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
               **kwargs):
    new_data_iterator = self._replace_data_iterator(data_iterator, model)
    return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
                                   config, *args, **kwargs)
```

After modification:
```python
def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
               **kwargs):
    new_data_iterator = self._replace_data_iterator(data_iterator, model)
    from msprobe.pytorch import PrecisionDebugger
    debugger = PrecisionDebugger(dump_path='./dump_path', level='mix', model=model)
    debugger.start()
    try:
        origin_train_step_out = self._origin_train_step(
            forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, config, *args, **kwargs)
    finally:
        debugger.stop()
        debugger.step()
    return origin_train_step_out
```
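
The start/try/finally/step structure above ensures collection is closed out even if the step raises. A hedged, standalone sketch of the same `PrecisionDebugger` lifecycle around a plain PyTorch step (the toy model and optimizer are illustrative assumptions, not from the source):

```python
# Illustrative only: PrecisionDebugger lifecycle around a plain PyTorch step.
import torch
import torch.nn as nn
from msprobe.pytorch import PrecisionDebugger

model = nn.Linear(8, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

debugger = PrecisionDebugger(dump_path='./dump_path', level='mix', model=model)
debugger.start()
try:
    loss = model(torch.randn(4, 8)).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
finally:
    debugger.stop()   # always stop collection
    debugger.step()   # finalize this dump step
```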

### Enable

Additionally, since msprobe does not support fused computation, you need to add `--no_bias_dropout_fusion True`, `--no_bias_swiglu_fusion True`, and `--cross_entropy_loss_fusion False` to the launch script.

#### Example
```shell
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=2 \
CUDA_VISIBLE_DEVICES=0,1 \
megatron sft \
    --load Qwen2.5-7B-Instruct-mcore \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
              'AI-ModelScope/alpaca-gpt4-data-en#500' \
              'swift/self-cognition#500' \
    --tensor_model_parallel_size 2 \
    ...
    --no_bias_dropout_fusion True \
    --no_bias_swiglu_fusion True \
    --cross_entropy_loss_fusion False
```

0 commit comments
