26 | 26 | import math |
27 | 27 | import os |
28 | 28 | import random |
29 | | -import re |
30 | 29 | import threading |
31 | 30 | import time |
32 | 31 | from contextlib import contextmanager |
33 | 32 | from enum import Enum |
34 | 33 | from pathlib import Path |
35 | | -from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Union |
| 34 | +from typing import Dict, List, NamedTuple, Optional, Tuple, Union |
36 | 35 |
37 | 36 | import numpy as np |
38 | 37 | import paddle |
39 | 38 | import paddle.distributed as dist |
40 | | -from paddle import Tensor |
41 | 39 | from paddle.distributed import fleet |
42 | 40 | from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( |
43 | 41 | DygraphShardingOptimizer, |
44 | 42 | )
46 | 44 | from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker |
47 | 45 | from paddle.io import IterableDataset |
48 | 46 | from paddle.optimizer.lr import LambdaDecay |
49 | | -from safetensors import safe_open |
50 | | -from safetensors.paddle import save_file |
51 | 47 |
52 | 48 | from paddlenlp.ops import Topology |
53 | 49 |
@@ -1449,265 +1445,3 @@ def buffer_params(): |
1449 | 1445 | continue |
1450 | 1446 | param_list.append(param) |
1451 | 1447 | optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), param_list) |
1452 | | - |
1453 | | - |
1454 | | -def _parse_size(size_str: str) -> int: |
1455 | | - """Parses a size string like '100MB', '2GB' into the number of bytes.""" |
1456 | | - size_str = size_str.upper().strip() |
1457 | | - match = re.match(r"^(\d+\.?\d*)\s*(B|KB|MB|GB|TB)?$", size_str) |
1458 | | - if not match: |
1459 | | - raise ValueError(f"Could not parse size string: '{size_str}'") |
1460 | | - |
1461 | | - num_str, unit = match.groups() |
1462 | | - num = float(num_str) |
1463 | | - |
1464 | | - if unit == "B" or unit is None: |
1465 | | - return int(num) |
1466 | | - elif unit == "KB": |
1467 | | - return int(num * 1024) |
1468 | | - elif unit == "MB": |
1469 | | - return int(num * 1024**2) |
1470 | | - elif unit == "GB": |
1471 | | - return int(num * 1024**3) |
1472 | | - elif unit == "TB": |
1473 | | - return int(num * 1024**4) |
1474 | | - else: |
1475 | | - # This case should not be reached due to regex |
1476 | | - raise ValueError(f"Unknown unit: '{unit}'") |
1477 | | - |
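
For reference, the removed `_parse_size` helper treats units as binary multiples (1 KB = 1024 B) and bare numbers as bytes; a minimal sketch of its expected behaviour, derived from the regex and branches above:

```python
# Expected behaviour of _parse_size (binary multiples, case-insensitive units).
assert _parse_size("512") == 512                 # no unit -> bytes
assert _parse_size("1.5 KB") == int(1.5 * 1024)  # whitespace and decimals allowed
assert _parse_size("100MB") == 100 * 1024**2
assert _parse_size("2gb") == 2 * 1024**3         # upper-cased before matching
```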
1478 | | - |
1479 | | -def save_full_param( |
1480 | | - itr: Iterator[tuple[str, Tensor]], |
1481 | | - save_dir: str, |
1482 | | - rank: int, |
1483 | | - moe_sharding_world_size: int, |
1484 | | - max_shard_size: str = "2GB", |
1485 | | - num_saver_ranks: int = 8, |
1486 | | -) -> None: |
1487 | | - """ |
1488 | | - Saves model weights from an iterator into shards, supporting max shard size |
1489 | | - and a limited number of saver ranks. |
1490 | | -
1491 | | - Only ranks less than `num_saver_ranks` will perform disk I/O. All other ranks |
1492 | | - will iterate through the data to maintain synchronization but will not save. |
1493 | | - The parameter distribution logic is based on `num_saver_ranks`, ensuring all |
1494 | | - parameters are handled by a designated saver rank. |
1495 | | -
1496 | | - Args: |
1497 | | - itr (Iterator): An iterator that yields (param_key, param_tensor). |
1498 | | - save_dir (str): The directory where shard files will be saved. |
1499 | | - rank (int): The rank of the current process. |
1500 | | - moe_sharding_world_size (int): The total number of processes. |
1501 | | - max_shard_size (str): The maximum size for each shard file, e.g., "500MB", "2GB". |
1502 | | - num_saver_ranks (int): The number of ranks (starting from 0) that will save files. |
1503 | | - """ |
1504 | | - |
1505 | | - # 1. Non-saver ranks simply consume the iterator to stay in sync. |
1506 | | - if rank >= num_saver_ranks: |
1507 | | - logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Consuming iterator for synchronization...") |
1508 | | - for _ in itr: |
1509 | | - pass |
1510 | | - logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Non-saver) Iterator consumption complete.") |
1511 | | - return |
1512 | | - |
1513 | | - max_shard_size_bytes = _parse_size(max_shard_size) |
1514 | | - logger.info( |
1515 | | - f"[Rank {rank}/{moe_sharding_world_size}] (Saver) Initializing save. " |
1516 | | - f"Max shard size set to: {max_shard_size_bytes / 1024**3:.2f} GB" |
1517 | | - ) |
1518 | | - |
1519 | | - os.makedirs(save_dir, exist_ok=True) |
1520 | | - |
1521 | | - current_shard_state_dict = {} |
1522 | | - current_shard_size_bytes = 0 |
1523 | | - sub_shard_index = 0 |
1524 | | - |
1525 | | - def _save_current_shard(): |
1526 | | - nonlocal sub_shard_index, current_shard_state_dict, current_shard_size_bytes |
1527 | | - if not current_shard_state_dict: |
1528 | | - return |
1529 | | - |
1530 | | - # Filename includes the main shard number (rank) and the sub-shard index |
1531 | | - cur_rank = paddle.distributed.get_rank() |
1532 | | - shard_filename = f"shard_{cur_rank}-{sub_shard_index}.safetensors" |
1533 | | - save_path = os.path.join(save_dir, shard_filename) |
1534 | | - |
1535 | | - logger.info( |
1536 | | - f"[Rank {rank}/{moe_sharding_world_size}] Saving sub-shard {sub_shard_index}... " |
1537 | | - f"Size: {current_shard_size_bytes / 1024**2:.2f} MB, " |
1538 | | - f"Params: {len(current_shard_state_dict)}, " |
1539 | | - f"Path: {save_path}" |
1540 | | - ) |
1541 | | - |
1542 | | - save_file(current_shard_state_dict, save_path) |
1543 | | - |
1544 | | - # Reset for the next shard |
1545 | | - sub_shard_index += 1 |
1546 | | - current_shard_state_dict = {} |
1547 | | - current_shard_size_bytes = 0 |
1548 | | - |
1549 | | - logger.info(f"[Rank {rank}/{moe_sharding_world_size}] Starting to process the weight iterator...") |
1550 | | - |
1551 | | - total_size = 0 |
1552 | | - |
1553 | | - for i, (param_key, param) in enumerate(itr): |
1554 | | - param_size_bytes = param.numel() * param.element_size() |
1555 | | - total_size += param_size_bytes.item() |
1556 | | - if i % num_saver_ranks == rank: |
1557 | | - if current_shard_size_bytes > 0 and (current_shard_size_bytes + param_size_bytes > max_shard_size_bytes): |
1558 | | - _save_current_shard() |
1559 | | - |
1560 | | - current_shard_state_dict[param_key] = param |
1561 | | - current_shard_size_bytes += param_size_bytes |
1562 | | - |
1563 | | - if current_shard_size_bytes >= max_shard_size_bytes: |
1564 | | - _save_current_shard() |
1565 | | - _save_current_shard() |
1566 | | - logger.info(f"[Rank {rank}/{moe_sharding_world_size}] (Saver) All shards saved successfully.") |
1567 | | - return total_size |
1568 | | - |
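
The distribution rule in the loop above is a plain round robin over saver ranks: parameter `i` from the iterator is owned by rank `i % num_saver_ranks`, while ranks at or above `num_saver_ranks` only drain the iterator to stay in sync. A small illustration with made-up parameter names:

```python
# Round-robin ownership used by save_full_param (parameter names are made up).
num_saver_ranks = 2
param_keys = ["embed_tokens.weight", "layers.0.q_proj.weight",
              "layers.0.k_proj.weight", "lm_head.weight"]
owners = {key: i % num_saver_ranks for i, key in enumerate(param_keys)}
# {'embed_tokens.weight': 0, 'layers.0.q_proj.weight': 1,
#  'layers.0.k_proj.weight': 0, 'lm_head.weight': 1}
```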
1569 | | - |
1570 | | -def replace_name_and_gen_index(path, total_size): |
1571 | | - index_mapping = {} |
1572 | | - cur_rank = paddle.distributed.get_rank() |
1573 | | - safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".safetensors")] |
1574 | | - files_num = len(safetensor_files) |
1575 | | - all_files_num = [] |
1576 | | - if paddle.distributed.get_world_size() > 1: |
1577 | | - paddle.distributed.all_gather_object(all_files_num, files_num) |
1578 | | - else: |
1579 | | - all_files_num.append(files_num) |
1580 | | - total_files_num = sum(all_files_num) |
1581 | | - |
1582 | | - start_idx = [] |
1583 | | - acc = 1 |
1584 | | - for files_num in all_files_num: |
1585 | | - start_idx.append(acc) |
1586 | | - acc += files_num |
1587 | | - |
1588 | | - env_local_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8)) |
1589 | | - env_local_rank = dist.get_rank() % env_local_size |
1590 | | - assert env_local_rank >= 0, f"expected positive local rank, got {env_local_rank}" |
1591 | | - |
1592 | | - cur_file_index = start_idx[cur_rank] // env_local_size |
1593 | | - total_files_num = total_files_num // env_local_size |
1594 | | - |
1595 | | - index_mapping = {} |
1596 | | - if env_local_rank == 0: |
1597 | | - for file in safetensor_files: |
1598 | | - cur_file_index += 1 |
1599 | | - file_path = os.path.join(path, file) |
1600 | | - new_file_name = f"model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors" |
1601 | | - with safe_open(file_path, framework="np") as f: |
1602 | | - for key in f.keys(): |
1603 | | - index_mapping[key] = new_file_name |
1604 | | - new_file_path = os.path.join(path, new_file_name) |
1605 | | - os.rename(file_path, new_file_path) |
1606 | | - |
1607 | | - index_mapping_list = [] |
1608 | | - if paddle.distributed.get_world_size() > 1: |
1609 | | - paddle.distributed.all_gather_object(index_mapping_list, index_mapping) |
1610 | | - else: |
1611 | | - index_mapping_list.append(index_mapping) |
1612 | | - index_mapping = {} |
1613 | | - for mapping in index_mapping_list: |
1614 | | - index_mapping.update(mapping) |
1615 | | - |
1616 | | - # Save signal file for each card |
1617 | | - saved_signal_path = os.path.join(path, f"saved_signal_{dist.get_rank()}") |
1618 | | - with open(saved_signal_path, mode="w+") as f: |
1619 | | - f.write("1") |
1620 | | - |
1621 | | - if env_local_rank == 0: |
1622 | | - index_file_name = "model.safetensors.index.json" |
1623 | | - index_infos = {} |
1624 | | - index_infos["metadata"] = {} |
1625 | | - index_infos["metadata"]["total_size"] = total_size |
1626 | | - index_infos["weight_map"] = dict(sorted(index_mapping.items())) |
1627 | | - with open(os.path.join(path, index_file_name), "w") as f: |
1628 | | - json.dump(index_infos, f, indent=4) |
1629 | | - |
1630 | | - # For PDC signal |
1631 | | - if strtobool(os.getenv("FLAG_LLM_PDC", "False")): |
1632 | | - for i in range(paddle.distributed.get_world_size()): |
1633 | | - saved_signal_path = os.path.join(path, f".model_weights.done.{i}") |
1634 | | - paddle.save(i, saved_signal_path) |
1635 | | - |
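
For context, the index file written above (by local rank 0) follows the standard `model.safetensors.index.json` layout. A hypothetical example of its shape; weight names, shard count, and `total_size` are made up for illustration:

```python
# Hypothetical shape of the model.safetensors.index.json produced above.
index_infos = {
    "metadata": {"total_size": 4831838208},
    "weight_map": {
        "embed_tokens.weight": "model-00001-of-00004.safetensors",
        "layers.0.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
    },
}
```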
1636 | | - |
1637 | | -class HFFormatFullParamSaver: |
1638 | | - def __init__( |
1639 | | - self, |
1640 | | - model, |
1641 | | - aoa_config, |
1642 | | - h_group=None, |
1643 | | - v_group=None, |
1644 | | - num_splits=None, |
1645 | | - shard_idx=None, |
1646 | | - saved_in_one_node=False, |
1647 | | - memory_growth_threshold=8 * (2**30), |
1648 | | - ): |
1649 | | - self.model = model |
1650 | | - self.aoa_config = aoa_config |
1651 | | - self.h_group = h_group |
1652 | | - self.v_group = v_group |
1653 | | - self.num_splits = num_splits |
1654 | | - self.shard_idx = shard_idx |
1655 | | - self.saved_in_one_node = saved_in_one_node |
1656 | | - self.memory_growth_threshold = memory_growth_threshold |
1657 | | - self.determin_saver_based_group() |
1658 | | - |
1659 | | - def get_full_param_iter(self): |
1660 | | - assert (self.v_group and self.h_group) or not ( |
1661 | | - self.v_group or self.h_group |
1662 | | - ), f"both h_group and v_group are provided or none of them, but got {self.v_group} and {self.h_group}" |
1663 | | - if self.v_group and self.h_group: |
1664 | | - assert self.shard_idx is not None, "expected shard_idx is not None" |
1665 | | - assert self.num_splits is not None, "expected num_splits is not None" |
1666 | | - |
1667 | | - param_iter = self.model.full( |
1668 | | - aoa_config=self.aoa_config, |
1669 | | - h_group=self.h_group, |
1670 | | - v_group=self.v_group, |
1671 | | - num_splits=self.num_splits, |
1672 | | - shard_idx=self.shard_idx, |
1673 | | - memory_growth_threshold=self.memory_growth_threshold, |
1674 | | - ) |
1675 | | - else: |
1676 | | - param_iter = self.model.full( |
1677 | | - aoa_config=self.aoa_config, memory_growth_threshold=self.memory_growth_threshold |
1678 | | - ) |
1679 | | - return param_iter |
1680 | | - |
1681 | | - def determin_saver_based_group(self): |
1682 | | - self.num_saver_ranks = paddle.distributed.get_world_size() |
1683 | | - self.rank = paddle.distributed.get_rank() |
1684 | | - |
1685 | | - if self.h_group and self.v_group: |
1686 | | - self.num_saver_ranks = self.h_group.nranks * self.v_group.nranks |
1687 | | - self.rank = self.h_group.rank + self.v_group.rank * self.h_group.nranks |
1688 | | - |
1689 | | - if self.saved_in_one_node: |
1690 | | - local_world_size = int(os.environ.get("PADDLE_LOCAL_SIZE", 8)) |
1691 | | - self.num_saver_ranks = min(local_world_size, self.num_saver_ranks) |
1692 | | - |
1693 | | - def save_checkpoint(self, path, max_shard_size="16GB"): |
1694 | | - total_saved_size = save_full_param( |
1695 | | - itr=self.get_full_param_iter(), |
1696 | | - save_dir=path, |
1697 | | - rank=self.rank, |
1698 | | - moe_sharding_world_size=self.num_saver_ranks, |
1699 | | - max_shard_size=max_shard_size, |
1700 | | - num_saver_ranks=self.num_saver_ranks, |
1701 | | - ) |
1702 | | - if paddle.distributed.get_world_size() > 1: |
1703 | | - paddle.distributed.barrier() |
1704 | | - |
1705 | | - # TODO(): fix total size |
1706 | | - all_sizes = [] |
1707 | | - if paddle.distributed.get_world_size() > 1: |
1708 | | - paddle.distributed.all_gather_object(all_sizes, total_saved_size) |
1709 | | - else: |
1710 | | - all_sizes.append(total_saved_size) |
1711 | | - total_size = sum(all_sizes) |
1712 | | - replace_name_and_gen_index(path, total_size) |
1713 | | - return total_saved_size |
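
A hedged usage sketch of the removed saver class, assuming `model` exposes the `full(...)` iterator that `get_full_param_iter` relies on and that `aoa_config` is supplied by the caller:

```python
# Usage sketch (assumes model.full(...) yields (param_key, tensor) pairs
# and that the distributed environment was already initialized via fleet).
saver = HFFormatFullParamSaver(
    model,
    aoa_config,
    saved_in_one_node=True,  # cap savers at one node's local world size
)
saver.save_checkpoint("./hf_checkpoint", max_shard_size="16GB")
```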