Commit 5b1da6b

[megatron] write use last rank (#5324)
1 parent: fc2bc3d

File tree: 6 files changed (+35, -16 lines)

swift/llm/template/template/utils.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def _swift_prepare_inputs(self, inputs):
     messages = inputs.messages
     # Only during inference or training, and only if the loss_scale is set to 'last_round',
     # will the previous 'think' entries be deleted.
-    if not self.is_training or self.loss_scale.name == 'last_round':
+    if not self.is_training or self.loss_scale.name in {'last_round', 'last_round_with_ignore_empty_think'}:
         for i, message in enumerate(messages):
            # Delete the content before '</think>' in all assistant turns except the last round.
            if message['role'] == 'assistant' and isinstance(message['content'], str) and i != len(messages) - 1:
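
For context, a minimal runnable sketch of the behavior this branch gates, with a hypothetical helper name and made-up messages (the real logic lives inside _swift_prepare_inputs): every assistant turn except the last has its content up to and including '</think>' removed, and the new loss-scale name 'last_round_with_ignore_empty_think' now triggers the same cleanup.

def strip_earlier_think(messages):
    # Hypothetical stand-in for the gated block above.
    # Assumption: messages is a list of {'role': str, 'content': str} dicts.
    for i, message in enumerate(messages):
        if message['role'] == 'assistant' and isinstance(message['content'], str) and i != len(messages) - 1:
            # Keep only the text after '</think>' in non-final assistant turns.
            message['content'] = message['content'].split('</think>', 1)[-1]
    return messages

messages = [
    {'role': 'user', 'content': 'hi'},
    {'role': 'assistant', 'content': '<think>step 1</think>hello'},
    {'role': 'user', 'content': 'and again?'},
    {'role': 'assistant', 'content': '<think>step 2</think>hello again'},
]
strip_earlier_think(messages)
# The first assistant turn becomes 'hello'; the final turn keeps its think block.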

swift/megatron/init.py

Lines changed: 8 additions & 6 deletions
@@ -8,7 +8,6 @@
 from datetime import datetime
 from typing import List, Optional, Tuple

-import numpy as np
 import peft
 import torch
 import torch.nn as nn

@@ -17,8 +16,8 @@
 from tqdm import tqdm

 from swift.llm import git_clone_github
-from swift.utils import (JsonlWriter, format_time, get_logger, is_flash_attn_3_available, is_master,
-                         is_megatron_available, safe_ddp_context, split_list, subprocess_run)
+from swift.utils import (JsonlWriter, format_time, get_logger, is_flash_attn_3_available, is_megatron_available,
+                         safe_ddp_context, split_list, subprocess_run)

 logger = get_logger()

@@ -75,10 +74,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
     """Log training information such as losses, timing, ...."""
     nonlocal jsonl_writer
     args = get_args()
-    if is_master() and jsonl_writer is None:
+    if jsonl_writer is None:
         logging_path = os.path.join(args.save, 'logging.jsonl')
         logger.info(f'logging_path: {logging_path}')
-        jsonl_writer = JsonlWriter(logging_path, enable_async=True)
+        jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last')
     timers = get_timers()
     writer = get_tensorboard_writer()
     wandb_writer = get_wandb_writer()

@@ -300,7 +299,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
     report_memory_flag = False
     timers.log(timers_to_log, normalizer=args.log_interval)

-    if is_master():
+    if is_last_rank():
         logs = {}
         for key in origin_total_loss_dict:
             if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]:

@@ -819,6 +818,9 @@ def _patch_megatron():
     except Exception:
         pass

+    import megatron.core
+    logger.info(f'megatron.core.__version__: {megatron.core.__version__}')
+

 def init_megatron_env() -> None:
     if 'MEGATRON_LM_PATH' not in os.environ:
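
The move from is_master() to is_last_rank() follows Megatron-LM's own logging convention: under pipeline parallelism the loss is computed on the last pipeline stage, so rank world_size - 1 is the rank holding finished metrics. A rough sketch of the resulting pattern, with a hypothetical function name and args standing in for Megatron's parsed arguments (JsonlWriter is the class extended below in swift/utils/io_utils.py):

import os

from swift.utils import JsonlWriter

jsonl_writer = None

def log_metrics(args, logs):
    # Every rank may construct the writer; with write_on_rank='last'
    # only rank world_size - 1 resolves a path and touches disk.
    global jsonl_writer
    if jsonl_writer is None:  # the old is_master() gate is gone
        jsonl_writer = JsonlWriter(
            os.path.join(args.save, 'logging.jsonl'),
            enable_async=True,
            write_on_rank='last')
    jsonl_writer.append(logs)  # effectively a no-op on non-write ranks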

swift/megatron/trainers/base.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 from megatron.training.checkpointing import load_checkpoint
 from packaging import version

-from swift.utils import JsonlWriter, deep_getattr, get_logger, is_master
+from swift.utils import JsonlWriter, deep_getattr, get_logger
 from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model
 from .utils import get_swift_datasets_provider

@@ -34,7 +34,7 @@ def __init__(self, args):
         self.stimer = StragglerDetector()
         logging_path = os.path.join(args.save, 'logging.jsonl')
         logger.info(f'logging_path: {logging_path}')
-        self.jsonl_writer = JsonlWriter(logging_path, enable_async=True)
+        self.jsonl_writer = JsonlWriter(logging_path, enable_async=True, write_on_rank='last')  # for evaluate
         self._patch_megatron()

     @contextmanager

@@ -372,7 +372,7 @@ def evaluate(self,
         timers.log(['evaluate'])

         rerun_state_machine.set_mode(rerun_mode)
-        if is_master():
+        if is_last_rank():
             logs = {}
             for key, val in total_loss_dict.items():
                 logs[f'eval_{key}'] = round(val.item(), 8)
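
On the evaluation side the same gate applies. A tiny sketch of what the gated block builds, with illustrative keys and values:

import torch

# Illustrative stand-in for Megatron's reduced eval losses.
total_loss_dict = {'lm loss': torch.tensor(2.3456789123, dtype=torch.float64)}
logs = {}
for key, val in total_loss_dict.items():
    logs[f'eval_{key}'] = round(val.item(), 8)
# logs == {'eval_lm loss': 2.34567891}; it is appended via self.jsonl_writer,
# which now writes only on the last rank.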

swift/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 from .env import (get_dist_setting, get_hf_endpoint, get_node_setting, get_pai_tensorboard_dir, is_deepspeed_enabled,
-                  is_dist, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
+                  is_dist, is_last_rank, is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, use_hf_hub)
 from .import_utils import (is_flash_attn_2_available, is_flash_attn_3_available, is_liger_available,
                            is_lmdeploy_available, is_megatron_available, is_swanlab_available, is_trl_available,
                            is_unsloth_available, is_vllm_ascend_available, is_vllm_available, is_wandb_available)

swift/utils/env.py

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,11 @@ def is_master():
     return rank in {-1, 0}


+def is_last_rank():
+    rank, _, world_size, _ = get_dist_setting()
+    return rank in {-1, world_size - 1}
+
+
 def is_dist():
     """Determine if the training is distributed"""
     rank, local_rank, _, _ = get_dist_setting()
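
A small sketch of how the new helper behaves, under the assumption that get_dist_setting() reads the standard torchrun environment variables and returns (rank, local_rank, world_size, local_world_size) with rank -1 when no launcher is in use (the stand-in below is not swift's actual implementation):

import os

def get_dist_setting_sketch():
    # Assumed shape of swift.utils.env.get_dist_setting.
    rank = int(os.getenv('RANK', -1))
    local_rank = int(os.getenv('LOCAL_RANK', -1))
    world_size = int(os.getenv('WORLD_SIZE', 1))
    local_world_size = int(os.getenv('LOCAL_WORLD_SIZE', 1))
    return rank, local_rank, world_size, local_world_size

def is_last_rank_sketch():
    rank, _, world_size, _ = get_dist_setting_sketch()
    return rank in {-1, world_size - 1}

# Single-process run (RANK unset): rank == -1, so the helper returns True,
# mirroring is_master(). Under `torchrun --nproc_per_node 4`, only the
# process with RANK=3 returns True.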

swift/utils/io_utils.py

Lines changed: 17 additions & 5 deletions
@@ -11,7 +11,7 @@
 from modelscope.hub.api import ModelScopeConfig
 from tqdm import tqdm

-from .env import is_master
+from .env import is_last_rank, is_master
 from .logger import get_logger
 from .utils import check_json_format

@@ -46,8 +46,20 @@ def write_to_jsonl(fpath: str, obj_list: List[Any], encoding: str = 'utf-8') ->

 class JsonlWriter:

-    def __init__(self, fpath: str, *, encoding: str = 'utf-8', strict: bool = True, enable_async: bool = False):
-        self.fpath = os.path.abspath(os.path.expanduser(fpath)) if is_master() else None
+    def __init__(self,
+                 fpath: str,
+                 *,
+                 encoding: str = 'utf-8',
+                 strict: bool = True,
+                 enable_async: bool = False,
+                 write_on_rank: Literal['master', 'last'] = 'master'):
+        if write_on_rank == 'master':
+            self.is_write_rank = is_master()
+        elif write_on_rank == 'last':
+            self.is_write_rank = is_last_rank()
+        else:
+            raise ValueError(f"Invalid `write_on_rank`: {write_on_rank}, should be 'master' or 'last'")
+        self.fpath = os.path.abspath(os.path.expanduser(fpath)) if self.is_write_rank else None
         self.encoding = encoding
         self.strict = strict
         self.enable_async = enable_async

@@ -66,7 +78,7 @@ def _append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
             obj_list = [obj]
         if gather_obj and dist.is_initialized():
             obj_list = gather_object(obj_list)
-        if not is_master():
+        if not self.is_write_rank:
             return
         obj_list = check_json_format(obj_list)
         for i, _obj in enumerate(obj_list):

@@ -85,7 +97,7 @@ def append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
     def _write_buffer(self, text: str):
         if not text:
             return
-        assert is_master(), f'is_master(): {is_master()}'
+        assert self.is_write_rank, f'self.is_write_rank: {self.is_write_rank}'
         try:
             os.makedirs(os.path.dirname(self.fpath), exist_ok=True)
             with open(self.fpath, 'a', encoding=self.encoding) as f:
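
A short usage sketch of the extended constructor (the path and payloads are illustrative): with write_on_rank='last', only rank world_size - 1 resolves fpath and writes; on every other rank append() returns early at the is_write_rank check above.

from swift.utils import JsonlWriter

writer = JsonlWriter('output/logging.jsonl', enable_async=True, write_on_rank='last')
writer.append({'iteration': 100, 'loss': 0.42})  # written only on the last rank
# gather_obj=True first collects entries from all ranks, then the write
# rank persists the gathered list.
writer.append({'iteration': 200, 'loss': 0.40}, gather_obj=True)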
