Commit a0708e0
【FlexCP】add load_merge_save api (#74981)
* add load_merge_save api
* remove print
* remove print
* rename
* fix
* fix
1 parent 91767eb commit a0708e0

2 files changed: +271, -1 lines changed


python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py

Lines changed: 218 additions & 1 deletion
@@ -15,12 +15,15 @@
from __future__ import annotations

import copy
import json
import math
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np

import paddle
from paddle.base.framework import (
    _current_expected_place,
@@ -1016,7 +1019,7 @@ def _load_state_dict(
    ) or all(isinstance(k, tuple) for k in copied_target_state_dict), (
        "target_state_dict contains a mix of tuple and non-tuple keys. Please ensure key types are consistent."
    )
    logger.info(f"readitem num: {len(read_items)}.")
    for item in read_items:
        if any(isinstance(k, tuple) for k in copied_target_state_dict):
            key = (item.local_tensor_index.tensor_key, item.global_offset)
@@ -1247,3 +1250,217 @@ def load_merged_state_dict(
                key
            )  # Add new key and remove the old one
    return state_dict_to_save


def divide_positions(m, n):
    '''
    Divide m positions evenly among n processors, giving the remainder to the leading splits.

    Parameters:
        m (int): Total number of tensors.
        n (int): Number of processors.

    Returns:
        list: Split positions indicating where to partition the tensors among the processors.

    Raises:
        ValueError: If n is zero or if m is less than n.
    '''
    if n == 0:
        raise ValueError("n should be greater than zero")
    if m < n:
        raise ValueError(
            "tensor number should be greater than or equal to processor number"
        )
    base_value = m // n
    remainder = m % n
    positions = [0]
    for i in range(1, n):
        if remainder > 0:
            positions.append(positions[-1] + base_value + 1)
            remainder -= 1
        else:
            positions.append(positions[-1] + base_value)
    positions.append(m)
    return positions
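For a concrete sense of the split, here is a small illustrative run of divide_positions (not part of the diff); the values follow directly from the code above:

    >>> divide_positions(10, 3)  # 10 tensors across 3 files; the remainder goes to the first split
    [0, 4, 7, 10]
    >>> # Consecutive pairs give per-file key ranges: [0, 4), [4, 7), [7, 10).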
def merge_sharded_state_dict(
    load_path: str,
    save_path: str,
    prefix: str | None = None,
    safetensor_prefix: str = 'model',
    unique_id: int | None = None,
    offload: bool = False,
    aoa_config: dict[str, list[str]] | None = None,
    safetensors: bool = False,
    file_num: int = 1,
) -> None:
    """
    Load a distributed checkpoint, merge it into an unsharded state_dict, and save it in the safetensors format.

    Note:
        The saved files are:
            model-00001-of-00008.safetensors
            model-00002-of-00008.safetensors
            ...
            model-00008-of-00008.safetensors
            model.safetensors.index.json
        where ``model`` is ``safetensor_prefix`` and 00008 is ``file_num``.

    Args:
        load_path(str): The directory to load checkpoint files from.
        save_path(str): The directory to save the merged checkpoint files to.
        prefix(str): The flat_mapping prefix of the state_dict keys, e.g. 'model'. Default: None.
        safetensor_prefix(str): The prefix of the saved safetensors files. Default: 'model'.
        unique_id(int): The unique id of the checkpoint, used to distinguish between different checkpoint versions. Default is None, in which case the maximum id found under load_path is used, i.e. the newest checkpoint version is loaded.
        offload(bool): Whether to offload the checkpoint data from GPU to CPU; set it to True if GPU memory is insufficient. Default: False.
        aoa_config(dict[str, list[str]]): AOA config used to transform parameters. Default: None.
        safetensors(bool): Whether to use the safetensors format when loading. Default: False.
        file_num(int): The number of files to split the merged checkpoint into. Default: 1.

    Returns:
        None.

    Example:
        .. code-block:: python

            >>> # doctest: +SKIP('run in distributed mode.')
            >>> import paddle
            >>> import paddle.distributed as dist
            >>> ckpt_path = "./checkpoint"
            >>> w1 = paddle.arange(32).reshape([4, 8])
            >>> mesh = dist.ProcessMesh([0, 1])
            >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)])
            >>> state_dict = {"w1": sharded_w1}
            >>> dist.save_state_dict(state_dict, ckpt_path)  # save sharded checkpoint

            >>> # doctest: +SKIP('run in single-card mode.')
            >>> import paddle
            >>> import paddle.distributed as dist
            >>> ckpt_path = "./checkpoint"
            >>> save_path = "./merged_checkpoint"
            >>> dist.merge_sharded_state_dict(ckpt_path, save_path)  # load unsharded and save to safetensors
            >>> # doctest: -SKIP
    """
    if unique_id is None:
        unique_id = get_max_id(load_path)
    else:
        assert unique_id >= 0, f'{unique_id} should be >= 0'

    metadata_files, local_data_files = get_checkpoint_files(
        load_path, unique_id=unique_id
    )

    metadata_list = []
    for file in metadata_files:
        metadata_list.append(paddle.load(os.path.join(load_path, file)))

    # Create the target state_dict from the local tensor metadata.
    all_state_dict = []
    state_dict_to_save = {}
    for metadata in metadata_list:
        for (
            tensor_key,
            local_tensor_meta,
        ) in metadata.state_dict_metadata.items():
            if prefix is None or tensor_key.startswith(prefix):
                global_shape = compute_global_shape(local_tensor_meta)
                t = paddle.zeros(global_shape, dtype=local_tensor_meta[0].dtype)
                if offload:
                    t = t.cpu()
                state_dict_to_save[tensor_key] = t
            else:
                continue

    def slice_dict(d, start, end):
        """Slice the dictionary keys and return the corresponding sub-dictionary."""
        keys = list(d.keys())[start:end]
        return {k: d[k] for k in keys}

    positions = divide_positions(len(state_dict_to_save), file_num)
    all_state_dict = [
        slice_dict(state_dict_to_save, positions[i], positions[i + 1])
        for i in range(file_num)
    ]

    total = sum(len(dict_) for dict_ in all_state_dict)
    assert len(state_dict_to_save) == total, (
        f'splitting the state dict failed: {len(state_dict_to_save)} should equal {total}'
    )

    SaveSafetensor = SavePartialSafetensors(
        save_path, len(all_state_dict), safetensor_prefix
    )
    idx = 0
    for state_dict_to_save in all_state_dict:
        load_state_dict(
            state_dict_to_save,
            load_path,
            offload=offload,
            aoa_config=aoa_config,
            safetensors=safetensors,
        )

        # Update dictionary keys in place.
        for key in list(
            state_dict_to_save.keys()
        ):  # Use list(data.keys()) to avoid a runtime error
            if prefix and key.startswith(prefix):
                new_key = key[len(prefix) + 1 :]  # Strip the prefix and its '.' separator
                state_dict_to_save[new_key] = state_dict_to_save.pop(
                    key
                )  # Add the new key and remove the old one

        if paddle.distributed.get_rank() == 0:
            SaveSafetensor.save_single_safetenors(state_dict_to_save, idx)
        idx += 1

    SaveSafetensor.save_index_json()
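As a sanity check on the merged output, the index file can be read back and each shard reloaded. This is a minimal sketch, not part of this commit, and it assumes the third-party safetensors package is installed:

    >>> # doctest: +SKIP('illustrative only')
    >>> # Assumes the third-party `safetensors` package; file names follow
    >>> # the layout produced by merge_sharded_state_dict above.
    >>> import json, os
    >>> from safetensors.numpy import load_file
    >>> save_path = "./merged_checkpoint"
    >>> with open(os.path.join(save_path, "model.safetensors.index.json")) as f:
    ...     index = json.load(f)
    >>> for shard in sorted(set(index["weight_map"].values())):
    ...     tensors = load_file(os.path.join(save_path, shard))  # dict of numpy arrays
    ...     print(shard, sorted(tensors.keys()))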
class SavePartialSafetensors:
    def __init__(self, output_path, total_files_size, prefix="model"):
        self.output_path = output_path
        self.prefix = prefix
        # Size in bytes of each supported paddle dtype.
        self.paddle_dtype_map = {
            "paddle.float64": 8,
            "paddle.float32": 4,
            "paddle.float16": 2,
            "paddle.uint16": 2,
            "paddle.bfloat16": 2,
            "paddle.uint8": 1,
            "paddle.float8_e4m3fn": 1,
            "paddle.float8_e5m2": 1,
        }
        self.index = {"metadata": {"total_size": 0}, "weight_map": {}}
        self.safe_index_name = prefix + ".safetensors.index.json"
        self.total_files_size = total_files_size

    def save_single_safetenors(self, state_dict, rank):
        key_list = state_dict.keys()

        shard_file = f"{self.prefix}-{rank + 1:05d}-of-{self.total_files_size:05d}.safetensors"
        for key in key_list:
            self.index["weight_map"][key] = shard_file
            self.index["metadata"]["total_size"] += int(
                np.prod(state_dict[key].shape)
                * self.paddle_dtype_map[str(state_dict[key].dtype)]
            )

        save_file_name = os.path.join(self.output_path, shard_file)
        logger.info(f"save_file_name = {save_file_name}")
        paddle.framework.io._safe_save(
            state_dict,
            save_file_name,
        )

    def save_index_json(self):
        save_index_file = os.path.join(self.output_path, self.safe_index_name)
        os.makedirs(os.path.dirname(save_index_file), exist_ok=True)
        with open(save_index_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.index, indent=2) + "\n")
        logger.info(f"Model index file saved in {save_index_file}.")

test/auto_parallel/semi_flexcheckpoint_merge.py

Lines changed: 53 additions & 0 deletions
@@ -186,6 +186,59 @@ def test_dist_checkpoint(self):
        self.dist_checkpoint(True, False)
        self.dist_checkpoint(False, False)

    def count_files_in_temp_dir(self, single_path):
        if not os.path.exists(single_path):
            return 0
        files = [
            f
            for f in os.listdir(single_path)
            if os.path.isfile(os.path.join(single_path, f))
        ]
        return len(files)

    def test_checkpoint_load_merge_save(self):
        model_path = os.path.join(self.temp_dir.name, 'model')
        single_path = os.path.join(self.temp_dir.name, 'single_model')

        # Test checkpoint saving
        with paddle.LazyGuard():
            model = DistMlpModel(self.mesh)
        for p in model.parameters():
            p.initialize()

        dataset = RandomDataset(128, 1024)
        sampler = BatchSampler(
            dataset,
            batch_size=4,
        )
        dataloader = DataLoader(
            dataset,
            batch_sampler=sampler,
        )
        opt = paddle.optimizer.AdamW(
            learning_rate=0.001, parameters=model.parameters()
        )
        opt = dist.shard_optimizer(opt)

        for step, inputs in enumerate(dataloader):
            data = inputs
            logits = model(data)
            loss = paddle.mean(logits)
            loss.backward()
            opt.step()
            opt.clear_grad()

        dist.save_state_dict(model.state_dict(), model_path, safetensors=False)

        dist.flex_checkpoint.dcp.load_state_dict.merge_sharded_state_dict(
            model_path, single_path, offload=True, safetensors=False, file_num=2
        )
        # file_num=2 yields two safetensors shards plus one index json.
        assert self.count_files_in_temp_dir(single_path) == 3, (
            f"Expected 3 files in temp dir, but got {self.count_files_in_temp_dir(single_path)}"
        )
        self.temp_dir.cleanup()
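To make the assertion concrete: with file_num=2 and the default safetensor_prefix, the merged directory is expected to hold exactly these three files (an illustrative listing, not part of the diff):

    >>> # doctest: +SKIP('illustrative only')
    >>> sorted(os.listdir(single_path))
    ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors', 'model.safetensors.index.json']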
if __name__ == '__main__':
    TestDistCheckpoint().test_dist_checkpoint()
    TestDistCheckpoint().test_checkpoint_load_merge_save()
