Skip to content

Commit 0a2247e

Browse files
committed
fix device_map & ddp rank0 (#4650)
1 parent 3317016 commit 0a2247e

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

swift/llm/model/patcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from swift.llm import deep_getattr, to_device, to_float_dtype
2020
from swift.utils import get_dist_setting, get_logger, is_mp_ddp, safe_ddp_context, use_torchacc
2121
from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
22-
from .model_arch import get_model_arch
2322
from .utils import HfConfigFactory, get_llm_model
2423

2524
logger = get_logger()
@@ -354,6 +353,7 @@ def patch_tp_plan(load_model: bool):
354353
transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
355354
yield
356355
return
356+
logger.info('Patch tp_plan.')
357357
WORLD_SIZE = os.environ.get('WORLD_SIZE')
358358
os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE
359359
os.environ.pop('WORLD_SIZE')

swift/utils/torch_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,8 @@ def from_tensor(obj):
368368

369369
def set_default_ddp_config():
370370
# It runs normally with Python as well.
371-
rank = int(os.getenv('RANK', -1))
372-
if rank == -1:
371+
rank, local_rank, _, _ = get_dist_setting()
372+
if rank == -1 or local_rank == -1:
373373
os.environ['NPROC_PER_NODE'] = '1'
374374
os.environ['RANK'] = '0'
375375
os.environ['LOCAL_RANK'] = '0'

0 commit comments

Comments
 (0)