Skip to content

Commit 10cffdb

Browse files
committed
add dump tool
1 parent 6c130eb commit 10cffdb

File tree

4 files changed

+31
-1
lines changed

4 files changed

+31
-1
lines changed

msprobe_config.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"task": "statistics",
3+
"dump_path": "./dump_path",
4+
"rank": [],
5+
"step": [],
6+
"level": "mix",
7+
"async_dump": false,
8+
9+
"statistics": {
10+
"scope": [],
11+
"list": [],
12+
"tensor_list": [],
13+
"data_mode": ["all"],
14+
"summary_mode": "statistics"
15+
}
16+
}

requirements/framework.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ importlib_metadata
1414
jieba
1515
json_repair
1616
matplotlib
17+
mindstudio-probe
1718
modelscope>=1.23
1819
nltk
1920
numpy

swift/megatron/argument/megatron_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
327327
# qwen3_vl, qwen3_omni
328328
mrope_interleaved: Optional[bool] = None
329329

330+
# dump
331+
enable_msprobe: bool = False
332+
msprobe_config: str = './msprobe_config.json'
333+
330334
@staticmethod
331335
def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
332336
res = {}

swift/megatron/trainers/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,17 @@ def _all_reduce_metric(self,
503503
def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config, *args,
504504
**kwargs):
505505
new_data_iterator = self._replace_data_iterator(data_iterator, model)
506-
return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
506+
debugger_on = self.args.enable_msprobe
507+
if debugger_on:
508+
from msprobe.pytorch import PrecisionDebugger
509+
debugger = PrecisionDebugger(config=self.args.msprobe_config, model=model)
510+
debugger.start()
511+
origin_train_step_out = self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
507512
config, *args, **kwargs)
513+
if debugger_on:
514+
debugger.stop()
515+
debugger.step()
516+
return origin_train_step_out
508517

509518
# Code borrowed from NVIDIA/Megatron-LM
510519
def evaluate(

0 commit comments

Comments
 (0)