|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | import copy
|
| 18 | +import json |
18 | 19 | import math
|
19 | 20 | import os
|
20 | 21 | from collections import defaultdict
|
21 | 22 | from dataclasses import dataclass
|
22 | 23 | from typing import TYPE_CHECKING
|
23 | 24 |
|
| 25 | +import numpy as np |
| 26 | + |
24 | 27 | import paddle
|
25 | 28 | from paddle.base.framework import (
|
26 | 29 | _current_expected_place,
|
@@ -1016,7 +1019,7 @@ def _load_state_dict(
|
1016 | 1019 | ) or all(isinstance(k, tuple) for k in copied_target_state_dict), (
|
1017 | 1020 | "target_state_dict contains a mix of tuple and non-tuple keys. Please ensure key types are consistent."
|
1018 | 1021 | )
|
1019 |
| - |
| 1022 | + logger.info(f"readitem num: {len(read_items)}.") |
1020 | 1023 | for item in read_items:
|
1021 | 1024 | if any(isinstance(k, tuple) for k in copied_target_state_dict):
|
1022 | 1025 | key = (item.local_tensor_index.tensor_key, item.global_offset)
|
@@ -1247,3 +1250,217 @@ def load_merged_state_dict(
|
1247 | 1250 | key
|
1248 | 1251 | ) # Add new key and remove the old one
|
1249 | 1252 | return state_dict_to_save
|
| 1253 | + |
| 1254 | + |
| 1255 | +def divide_positions(m, n): |
| 1256 | + ''' |
| 1257 | + Divide positions evenly among n processors with a base value and remainder handling. |
| 1258 | +
|
| 1259 | + Parameters: |
| 1260 | + m (int): Total number of tensor positions. |
| 1261 | + n (int): Number of processors. |
| 1262 | +
|
| 1263 | + Returns: |
| 1264 | + list: A list of positions indicating where to split the tensors among processors. |
| 1265 | +
|
| 1266 | + Raises: |
| 1267 | + ValueError: If n is zero or if m is less than n. |
| 1268 | + ''' |
| 1269 | + if n == 0: |
| 1270 | + raise ValueError("n should be greater than zero") |
| 1271 | + if m < n: |
| 1272 | + raise ValueError( |
| 1273 | + "tensor number should be greater than or equal to processor number" |
| 1274 | + ) |
| 1275 | + base_value = m // n |
| 1276 | + remainder = m % n |
| 1277 | + positions = [0] |
| 1278 | + for i in range(1, n): |
| 1279 | + if remainder > 0: |
| 1280 | + positions.append(positions[-1] + base_value + 1) |
| 1281 | + remainder -= 1 |
| 1282 | + else: |
| 1283 | + positions.append(positions[-1] + base_value) |
| 1284 | + positions.append(m) |
| 1285 | + return positions |
| 1286 | + |
| 1287 | + |
| 1288 | +def merge_sharded_state_dict( |
| 1289 | + load_path: str, |
| 1290 | + save_path: str, |
| 1291 | + prefix: str | None = None, |
| 1292 | + safetensor_prefix: str = 'model', |
| 1293 | + unique_id: int | None = None, |
| 1294 | + offload: bool = False, |
| 1295 | + aoa_config: dict[str, list[str]] | None = None, |
| 1296 | + safetensors: bool = False, |
| 1297 | + file_num: int = 1, |
| 1298 | +) -> None: |
| 1299 | + """ |
| 1300 | + Load the distributed checkpoint and merge it to unsharded state_dict then save as safetensors. |
| 1301 | +
|
| 1302 | + Note: |
| 1303 | + save files are: |
| 1304 | + model-00001-of-00008.safetensors |
| 1305 | + model-00002-of-00008.safetensors |
| 1306 | + ... |
| 1307 | + model-00008-of-00008.safetensors |
| 1308 | + model.safetensors.index.json |
| 1309 | + model is safetensor_prefix; 00008 is file_num. |
| 1310 | +
|
| 1311 | + Args: |
| 1312 | + load_path(str): The directory to load checkpoint files. |
| 1313 | + save_path(str): The directory to save merged_checkpoint files. |
| 1314 | + prefix(str): The flat_mapping prefix of state_dict key. e.g., 'model', Default None. |
| 1315 | + safetensor_prefix(str): The safetensors file prefix e.g., Default 'model'. |
| 1316 | + unique_id(int): The unique id of checkpoint, used to distinguish between different checkpoint versions. Default is None, in which case the id the max id of given path, and the newest version checkpoint is loaded. |
| 1317 | + offload(bool): Whether to offload the checkpoint data from GPU to CPU, set to True if GPU memory is not enough. |
| 1318 | + aoa_config(dict[str, list[str]]): AOA config to change parameters. Default is None. |
| 1319 | + safetensors(bool): Whether to use safetensors format. Default is False. |
| 1320 | + file_num(int): The number of files to split the merged_checkpoint into. |
| 1321 | + Returns: |
| 1322 | + None. |
| 1323 | +
|
| 1324 | + Example: |
| 1325 | + .. code-block:: python |
| 1326 | +
|
| 1327 | + >>> # doctest: +SKIP('run in distributed mode.') |
| 1328 | + >>> import paddle |
| 1329 | + >>> import paddle.distributed as dist |
| 1330 | + >>> ckpt_path = "./checkpoint" |
| 1331 | + >>> w1 = paddle.arange(32).reshape([4, 8]) |
| 1332 | + >>> mesh = dist.ProcessMesh([0, 1]) |
| 1333 | + >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) |
| 1334 | + >>> state_dict = {"w1": sharded_w1} |
| 1335 | + >>> dist.save_state_dict(state_dict, ckpt_path) # save sharded checkpoint |
| 1336 | +
|
| 1337 | + >>> # doctest: +SKIP('run in single-card mode.') |
| 1338 | + >>> import paddle |
| 1339 | + >>> import paddle.distributed as dist |
| 1340 | + >>> ckpt_path = "./checkpoint" |
| 1341 | + >>> save_path = "./merged_checkpoint" |
| 1342 | + >>> dist.merge_sharded_state_dict(ckpt_path, save_path) # load unsharded and save to safetensors |
| 1343 | + >>> # doctest: -SKIP |
| 1344 | + """ |
| 1345 | + if unique_id is None: |
| 1346 | + unique_id = get_max_id(load_path) |
| 1347 | + else: |
| 1348 | + assert unique_id >= 0, f'{unique_id} should be >= 0' |
| 1349 | + |
| 1350 | + metadata_files, local_data_files = get_checkpoint_files( |
| 1351 | + load_path, unique_id=unique_id |
| 1352 | + ) |
| 1353 | + |
| 1354 | + metadata_list = [] |
| 1355 | + for file in metadata_files: |
| 1356 | + metadata_list.append(paddle.load(os.path.join(load_path, file))) |
| 1357 | + |
| 1358 | + # create target state_dict by local_tensor_meta |
| 1359 | + |
| 1360 | + all_state_dict = [] |
| 1361 | + state_dict_to_save = {} |
| 1362 | + for metadata in metadata_list: |
| 1363 | + for ( |
| 1364 | + tensor_key, |
| 1365 | + local_tensor_meta, |
| 1366 | + ) in metadata.state_dict_metadata.items(): |
| 1367 | + if prefix is None or tensor_key.startswith(prefix): |
| 1368 | + global_shape = compute_global_shape(local_tensor_meta) |
| 1369 | + t = paddle.zeros(global_shape, dtype=local_tensor_meta[0].dtype) |
| 1370 | + if offload: |
| 1371 | + t = t.cpu() |
| 1372 | + state_dict_to_save[tensor_key] = t |
| 1373 | + else: |
| 1374 | + continue |
| 1375 | + |
| 1376 | + def slice_dict(d, start, end): |
| 1377 | + """Slice the dictionary keys and return the corresponding sub-dictionary""" |
| 1378 | + keys = list(d.keys())[start:end] |
| 1379 | + return {k: d[k] for k in keys} |
| 1380 | + |
| 1381 | + positions = divide_positions(len(state_dict_to_save), file_num) |
| 1382 | + all_state_dict = [ |
| 1383 | + slice_dict(state_dict_to_save, positions[i], positions[i + 1]) |
| 1384 | + for i in range(file_num) |
| 1385 | + ] |
| 1386 | + |
| 1387 | + total = sum(len(dict_) for dict_ in all_state_dict) |
| 1388 | + assert len(state_dict_to_save) == total, ( |
| 1389 | + f'split state dict filed :{len(state_dict_to_save)} should seem as {sum}' |
| 1390 | + ) |
| 1391 | + |
| 1392 | + SaveSafetensor = SavePartialSafetensors( |
| 1393 | + save_path, len(all_state_dict), safetensor_prefix |
| 1394 | + ) |
| 1395 | + idx = 0 |
| 1396 | + for state_dict_to_save in all_state_dict: |
| 1397 | + load_state_dict( |
| 1398 | + state_dict_to_save, |
| 1399 | + load_path, |
| 1400 | + offload=offload, |
| 1401 | + aoa_config=aoa_config, |
| 1402 | + safetensors=safetensors, |
| 1403 | + ) |
| 1404 | + |
| 1405 | + # Update dictionary keys in place |
| 1406 | + for key in list( |
| 1407 | + state_dict_to_save.keys() |
| 1408 | + ): # Use list(data.keys()) to avoid runtime error |
| 1409 | + if prefix and key.startswith(prefix): |
| 1410 | + new_key = key[len(prefix) + 1 :] # Remove the "str" prefix |
| 1411 | + state_dict_to_save[new_key] = state_dict_to_save.pop( |
| 1412 | + key |
| 1413 | + ) # Add new key and remove the old one |
| 1414 | + |
| 1415 | + if paddle.distributed.get_rank() == 0: |
| 1416 | + SaveSafetensor.save_single_safetenors(state_dict_to_save, idx) |
| 1417 | + idx += 1 |
| 1418 | + |
| 1419 | + SaveSafetensor.save_index_json() |
| 1420 | + |
| 1421 | + |
| 1422 | +class SavePartialSafetensors: |
| 1423 | + def __init__(self, output_path, total_files_size, prefix="model"): |
| 1424 | + self.output_path = output_path |
| 1425 | + self.prefix = prefix |
| 1426 | + self.paddle_dtype_map = { |
| 1427 | + "paddle.float64": 8, |
| 1428 | + "paddle.float32": 4, |
| 1429 | + "paddle.float16": 2, |
| 1430 | + "paddle.uint16": 2, |
| 1431 | + "paddle.bfloat16": 2, |
| 1432 | + "paddle.uint8": 1, |
| 1433 | + "paddle.float8_e4m3fn": 1, |
| 1434 | + "paddle.float8_e5m2": 1, |
| 1435 | + } |
| 1436 | + self.index = {"metadata": {"total_size": 0}, "weight_map": {}} |
| 1437 | + self.safe_index_name = prefix + ".safetensors.index.json" |
| 1438 | + self.total_files_size = total_files_size |
| 1439 | + |
| 1440 | + def save_single_safetenors(self, state_dict, rank): |
| 1441 | + key_list = state_dict.keys() |
| 1442 | + |
| 1443 | + shard_file = f"{self.prefix}-{rank + 1:05d}-of-{self.total_files_size:05d}.safetensors" |
| 1444 | + for key in key_list: |
| 1445 | + self.index["weight_map"][key] = shard_file |
| 1446 | + self.index["metadata"]["total_size"] += int( |
| 1447 | + np.prod(state_dict[key].shape) |
| 1448 | + * self.paddle_dtype_map[str(state_dict[key].dtype)] |
| 1449 | + ) |
| 1450 | + |
| 1451 | + save_file_name = os.path.join( |
| 1452 | + self.output_path, |
| 1453 | + f"{self.prefix}-{rank + 1:05d}-of-{self.total_files_size:05d}.safetensors", |
| 1454 | + ) |
| 1455 | + logger.info(f"save_file_name = {save_file_name}") |
| 1456 | + paddle.framework.io._safe_save( |
| 1457 | + state_dict, |
| 1458 | + save_file_name, |
| 1459 | + ) |
| 1460 | + |
| 1461 | + def save_index_json(self): |
| 1462 | + save_index_file = os.path.join(self.output_path, self.safe_index_name) |
| 1463 | + os.makedirs(os.path.dirname(save_index_file), exist_ok=True) |
| 1464 | + with open(save_index_file, "w", encoding="utf-8") as f: |
| 1465 | + f.write(json.dumps(self.index, indent=2) + "\n") |
| 1466 | + logger.info(f"Model index file saved in {save_index_file}.") |
0 commit comments