
Commit dcbdeaa

change_save_meta_in_zcc_crossponding_to_paddle and fix flags (#11198)
* change_save_meta_in_zcc_crossponding_to_paddle
* delete
* fix
* change_save_meta_in_zcc_crossponding_to_paddle and fix args
* fix deadcode

Co-authored-by: liufengwei02 <liufengwei02@baidu.com>
1 parent 7ab35ce commit dcbdeaa

5 files changed: 341 additions, 292 deletions


paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -713,7 +713,7 @@ def get_metadata_file_name(path):
                 offload=self.args.load_via_cpu,
                 safetensors=True,
                 process_group=None,
-                comm_method=self.args.comm_method,
+                comm_method=self.args.flex_ckpt_comm_method,
             )
         else:
             try:
@@ -755,7 +755,7 @@ def get_metadata_file_name(path):
                 offload=self.args.load_via_cpu,
                 safetensors=True,
                 process_group=process_group,
-                comm_method=self.args.comm_method,
+                comm_method=self.args.flex_ckpt_comm_method,
             )
 
             dist.barrier()
@@ -801,7 +801,7 @@ def get_metadata_file_name(path):
                 opt_states_path,
                 aoa_config=self.args.aoa_config,
                 offload=self.args.load_via_cpu,
-                comm_method=self.args.comm_method,
+                comm_method=self.args.flex_ckpt_comm_method,
             )
 
             if not self.args.sharded_model_from_ema:
@@ -810,7 +810,7 @@ def get_metadata_file_name(path):
                 master_weights_path,
                 aoa_config=self.args.aoa_config,
                 offload=self.args.load_via_cpu,
-                comm_method=self.args.comm_method,
+                comm_method=self.args.flex_ckpt_comm_method,
             )
 
         self._load_scheduler(resume_from_checkpoint)
@@ -851,7 +851,7 @@ def bf16_filtered_sharded_state_dict(sharded_state_dict):
                 model_states_path,
                 aoa_config=self.args.aoa_config,
                 offload=self.args.load_via_cpu,
-                comm_method=self.args.comm_method,
+                comm_method=self.args.flex_ckpt_comm_method,
             )
 
         if self.args.bf16 and (not self.args.ignore_load_lr_and_optim) and should_load_stage1:
```
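
For reference, a minimal sketch (not part of this commit) of how the renamed flag is set when constructing `TrainingArguments`; the `output_dir` path and the `"broadcast"` value shown here are illustrative assumptions:

```python
# Minimal sketch, assuming a standard PaddleNLP setup; values are illustrative only.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    # Renamed in this commit: formerly `comm_method`.
    flex_ckpt_comm_method="broadcast",
)
# The trainer code above now reads `self.args.flex_ckpt_comm_method`
# wherever it reshards checkpoints on load.
```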

paddlenlp/trainer/trainer_utils.py

Lines changed: 1 addition & 267 deletions
```diff
@@ -26,18 +26,16 @@
 import math
 import os
 import random
-import re
 import threading
 import time
 from contextlib import contextmanager
 from enum import Enum
 from pathlib import Path
-from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 import paddle
 import paddle.distributed as dist
-from paddle import Tensor
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
@@ -46,8 +44,6 @@
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay
-from safetensors import safe_open
-from safetensors.paddle import save_file
 
 from paddlenlp.ops import Topology
 
@@ -1449,265 +1445,3 @@ def buffer_params():
             continue
         param_list.append(param)
     optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list)
-
-
-def _parse_size(size_str: str) -> int:
-    """Parses a size string like '100MB', '2GB' into the number of bytes."""
-    size_str = size_str.upper().strip()
-    match = re.match(r"^(\d+\.?\d*)\s*(B|KB|MB|GB|TB)?$", size_str)
-    if not match:
-        raise ValueError(f"Could not parse size string: '{size_str}'")
-
-    num_str, unit = match.groups()
-    num = float(num_str)
-
-    if unit == "B" or unit is None:
-        return int(num)
-    elif unit == "KB":
-        return int(num * 1024)
-    elif unit == "MB":
-        return int(num * 1024**2)
-    elif unit == "GB":
-        return int(num * 1024**3)
-    elif unit == "TB":
-        return int(num * 1024**4)
-    else:
-        # This case should not be reached due to regex
-        raise ValueError(f"Unknown unit: '{unit}'")
-
-
-def save_full_param(
-    itr: Iterator[tuple[str, Tensor]],
-    save_dir: str,
-    rank: int,
-    moe_sharding_world_size: int,
-    max_shard_size: str = "2GB",
-    num_saver_ranks: int = 8,
-) -> None:
-    """
-    Saves model weights from an iterator into shards, supporting max shard size
-    and a limited number of saver ranks.
-
-    Only ranks less than `num_saver_ranks` will perform disk I/O. All other ranks
-    will iterate through the data to maintain synchronization but will not save.
-    The parameter distribution logic is based on `num_saver_ranks`, ensuring all
-    parameters are handled by a designated saver rank.
-
-    Args:
-        itr (Iterator): An iterator that yields (param_key, param_tensor).
-        save_dir (str): The directory where shard files will be saved.
-        rank (int): The rank of the current process.
-        moe_sharding_world_size (int): The total number of processes.
-        max_shard_size (str): The maximum size for each shard file, e.g., "500MB", "2GB".
-        num_saver_ranks (int): The number of ranks (starting from 0) that will save files.
-    """
-
-    # 1. Non-saver ranks simply consume the iterator to stay in sync.
-    if rank >= num_saver_ranks:
-        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Consuming iterator for synchronization...")
-        for _ in itr:
-            pass
-        logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Iterator consumption complete.")
-        return
-
-    max_shard_size_bytes = _parse_size(max_shard_size)
-    logger.info(
-        f"[Rank {rank}/{moe_sharding_world_size}] (Saver) Initializing save. "
-        f"Max shard size set to: {max_shard_size_bytes / 1024**3:.2f} GB"
-    )
-
-    os.makedirs(save_dir, exist_ok=True)
-
-    current_shard_state_dict = {}
-    current_shard_size_bytes = 0
-    sub_shard_index = 0
-
-    def _save_current_shard():
-        nonlocal sub_shard_index, current_shard_state_dict, current_shard_size_bytes
-        if not current_shard_state_dict:
-            return
-
-        # Filename includes the main shard number (rank) and the sub-shard index
-        cur_rank = paddle.distributed.get_rank()
-        shard_filename = f"shard_{cur_rank}-{sub_shard_index}.safetensors"
-        save_path = os.path.join(save_dir, shard_filename)
-
-        logger.info(
-            f"[Rank {rank}/{moe_sharding_world_size}] Saving sub-shard {sub_shard_index}... "
-            f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, "
-            f"Params: {len(current_shard_state_dict)}, "
-            f"Path: {save_path}"
-        )
-
-        save_file(current_shard_state_dict, save_path)
-
-        # Reset for the next shard
-        sub_shard_index += 1
-        current_shard_state_dict = {}
-        current_shard_size_bytes = 0
-
-    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...")
-
-    total_size = 0
-
-    for i, (param_key, param) in enumerate(itr):
-        param_size_bytes = param.numel() * param.element_size()
-        total_size += param_size_bytes.item()
-        if i % num_saver_ranks == rank:
-            if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes):
-                _save_current_shard()
-
-            current_shard_state_dict[param_key] = param
-            current_shard_size_bytes += param_size_bytes
-
-            if current_shard_size_bytes >= max_shard_size_bytes:
-                _save_current_shard()
-    _save_current_shard()
-    logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.")
-    return total_size
-
-
-def replace_name_and_gen_index(path, total_size):
-    index_mapping = {}
-    cur_rank = paddle.distributed.get_rank()
-    safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")]
-    files_num = len(safetensor_files)
-    all_files_num = []
-    if paddle.distributed.get_world_size() > 1:
-        paddle.distributed.all_gather_object(all_files_num, files_num)
-    else:
-        all_files_num.append(files_num)
-    total_files_num = sum(all_files_num)
-
-    start_idx = []
-    acc = 1
-    for files_num in all_files_num:
-        start_idx.append(acc)
-        acc += files_num
-
-    env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8))
-    env_local_rank = dist.get_rank() % env_local_size
-    assert env_local_rank >= 0, f"expected positive local rank, got {env_local_rank}"
-
-    cur_file_index = start_idx[cur_rank] // env_local_size
-    total_files_num = total_files_num // env_local_size
-
-    index_mapping = {}
-    if env_local_rank == 0:
-        for file in safetensor_files:
-            cur_file_index += 1
-            file_path = os.path.join(path, file)
-            new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors"
-            with safe_open(file_path, framework="np") as f:
-                for key in f.keys():
-                    index_mapping[key] = new_file_name
-            new_file_path = os.path.join(path, new_file_name)
-            os.rename(file_path, new_file_path)
-
-    index_mapping_list = []
-    if paddle.distributed.get_world_size() > 1:
-        paddle.distributed.all_gather_object(index_mapping_list, index_mapping)
-    else:
-        index_mapping_list.append(index_mapping)
-    index_mapping = {}
-    for mapping in index_mapping_list:
-        index_mapping.update(mapping)
-
-    # Save signal file for each card
-    saved_signal_path = os.path.join(path, f"saved_signal_{dist.get_rank()}")
-    with open(saved_signal_path, mode="w+") as f:
-        f.write("1")
-
-    if env_local_rank == 0:
-        index_file_name = "model.safetensors.index.json"
-        index_infos = {}
-        index_infos["metadata"] = {}
-        index_infos["metadata"]["total_size"] = total_size
-        index_infos["weight_map"] = dict(sorted(index_mapping.items()))
-        with open(os.path.join(path, index_file_name), "w") as f:
-            json.dump(index_infos, f, indent=4)
-
-    # For PDC signal
-    if strtobool(os.getenv("FLAG_LLM_PDC", "False")):
-        for i in range(paddle.distributed.get_world_size()):
-            saved_signal_path = os.path.join(path, f".model_weights.done.{i}")
-            paddle.save(i, saved_signal_path)
-
-
-class HFFormatFullParamSaver:
-    def __init__(
-        self,
-        model,
-        aoa_config,
-        h_group=None,
-        v_group=None,
-        num_splits=None,
-        shard_idx=None,
-        saved_in_one_node=False,
-        memory_growth_threshold=8 * (2**30),
-    ):
-        self.model = model
-        self.aoa_config = aoa_config
-        self.h_group = h_group
-        self.v_group = v_group
-        self.num_splits = num_splits
-        self.shard_idx = shard_idx
-        self.saved_in_one_node = saved_in_one_node
-        self.memory_growth_threshold = memory_growth_threshold
-        self.determin_saver_based_group()
-
-    def get_full_param_iter(self):
-        assert (self.v_group and self.h_group) or not (
-            self.v_group or self.h_group
-        ), f"both h_group and v_group are provided or none of them, but got {self.v_group} and {self.h_group}"
-        if self.v_group and self.h_group:
-            assert self.shard_idx is not None, "expected shard_idx is not None"
-            assert self.num_splits is not None, "expected num_splits is not None"
-
-            param_iter = self.model.full(
-                aoa_config=self.aoa_config,
-                h_group=self.h_group,
-                v_group=self.v_group,
-                num_splits=self.num_splits,
-                shard_idx=self.shard_idx,
-                memory_growth_threshold=self.memory_growth_threshold,
-            )
-        else:
-            param_iter = self.model.full(
-                aoa_config=self.aoa_config, memory_growth_threshold=self.memory_growth_threshold
-            )
-        return param_iter
-
-    def determin_saver_based_group(self):
-        self.num_saver_ranks = paddle.distributed.get_world_size()
-        self.rank = paddle.distributed.get_rank()
-
-        if self.h_group and self.v_group:
-            self.num_saver_ranks = self.h_group.nranks * self.v_group.nranks
-            self.rank = self.h_group.rank + self.v_group.rank * self.h_group.nranks
-
-        if self.saved_in_one_node:
-            local_world_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8))
-            self.num_saver_ranks = min(local_world_size, self.num_saver_ranks)
-
-    def save_checkpoint(self, path, max_shard_size="16GB"):
-        total_saved_size = save_full_param(
-            itr=self.get_full_param_iter(),
-            save_dir=path,
-            rank=self.rank,
-            moe_sharding_world_size=self.num_saver_ranks,
-            max_shard_size=max_shard_size,
-            num_saver_ranks=self.num_saver_ranks,
-        )
-        if paddle.distributed.get_world_size() > 1:
-            paddle.distributed.barrier()
-
-        # TODO(): fix total size
-        all_sizes = []
-        if paddle.distributed.get_world_size() > 1:
-            paddle.distributed.all_gather_object(all_sizes, total_saved_size)
-        else:
-            all_sizes.append(total_saved_size)
-        total_size = sum(all_sizes)
-        replace_name_and_gen_index(path, total_size)
-        return total_saved_size
```

paddlenlp/trainer/training_args.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -422,10 +422,14 @@ class TrainingArguments:
         load_from_hf (bool, optional):
             Whether to load a checkpoint in the HuggingFace format.
             Defaults to False.
-        comm_method (str, optional):
+        flex_ckpt_comm_method (str, optional):
             Communication method used for checkpoint resharding.
             Choices are "send_recv", "broadcast", "multi_group_broadcast", and "grouped_send_recv".
             Defaults to "broadcast".
+        replicate_saved_into_local (bool, optional):
+            Whether to save checkpoint replicas into local files in a distributed save/load system.
+            If set to True, replicas will be stored locally on each node/machine.
+            Defaults to False.
     """
 
     output_dir: str = field(
@@ -1169,17 +1173,22 @@ class TrainingArguments:
         metadata={"help": "Whether to load a checkpoint in the HuggingFace format."},
     )
 
-    comm_method: Optional[str] = field(
+    flex_ckpt_comm_method: Optional[str] = field(
         default="broadcast",
         metadata={
            "help": (
-                "Communication method for checkpoint resharding. "
+                "Communication method used by FlexCheckpoint for checkpoint resharding. "
                'Choices are "send_recv", "broadcast", "multi_group_broadcast", and "grouped_send_recv". '
                'Default is "broadcast".'
            )
        },
    )
 
+    replicate_saved_into_local: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to save replicas cross files in distributed save load system."},
+    )
+
    def __post_init__(self):
        world_size = paddle.distributed.get_world_size()
        if in_auto_parallel_align_mode():
```
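
As a usage note, here is a hedged sketch of passing both the renamed and the new flag on the command line; it assumes `PdArgumentParser` maps `TrainingArguments` dataclass fields to `--flag` options as in existing PaddleNLP releases, and the paths and values shown are illustrative only:

```python
# Hedged sketch: assumes PdArgumentParser turns dataclass fields into CLI flags,
# as in current PaddleNLP. The output dir and flag values are illustrative.
from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    [
        "--output_dir", "./checkpoints",
        "--flex_ckpt_comm_method", "grouped_send_recv",  # renamed from --comm_method
        "--replicate_saved_into_local", "True",          # flag added in this commit
    ]
)
print(training_args.flex_ckpt_comm_method, training_args.replicate_saved_into_local)
```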
