|
14 | 14 | """Contains helpers to split tensors into shards.""" |
15 | 15 |
|
16 | 16 | from dataclasses import dataclass, field |
17 | | -from typing import Any, Callable, Dict, List, Optional, TypeVar |
| 17 | +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union |
18 | 18 |
|
19 | 19 | from .. import logging |
20 | 20 |
|
@@ -46,7 +46,7 @@ def split_state_dict_into_shards_factory( |
46 | 46 | get_tensor_size: TensorSizeFn_T, |
47 | 47 | get_storage_id: StorageIDFn_T = lambda tensor: None, |
48 | 48 | filename_pattern: str = FILENAME_PATTERN, |
49 | | - max_shard_size: int = MAX_SHARD_SIZE, |
| 49 | + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, |
50 | 50 | ) -> StateDictSplit: |
51 | 51 | """ |
52 | 52 | Split a model state dictionary in shards so that each shard is smaller than a given size. |
@@ -89,6 +89,9 @@ def split_state_dict_into_shards_factory( |
89 | 89 | current_shard_size = 0 |
90 | 90 | total_size = 0 |
91 | 91 |
|
| 92 | + if isinstance(max_shard_size, str): |
| 93 | + max_shard_size = parse_size_to_int(max_shard_size) |
| 94 | + |
92 | 95 | for key, tensor in state_dict.items(): |
93 | 96 | # when bnb serialization is used the weights in the state dict can be strings |
94 | 97 | # check: https://github.com/huggingface/transformers/pull/24416 for more details |
@@ -167,3 +170,44 @@ def split_state_dict_into_shards_factory( |
167 | 170 | filename_to_tensors=filename_to_tensors, |
168 | 171 | tensor_to_filename=tensor_name_to_filename, |
169 | 172 | ) |
| 173 | + |
| 174 | + |
| 175 | +SIZE_UNITS = { |
| 176 | + "TB": 10**12, |
| 177 | + "GB": 10**9, |
| 178 | + "MB": 10**6, |
| 179 | + "KB": 10**3, |
| 180 | +} |
| 181 | + |
| 182 | + |
| 183 | +def parse_size_to_int(size_as_str: str) -> int: |
| 184 | + """ |
| 185 | + Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). |
| 186 | +
|
| 187 | + Supported units are "TB", "GB", "MB", "KB". |
| 188 | +
|
| 189 | + Args: |
| 190 | + size_as_str (`str`): The size to convert. Will be directly returned if an `int`. |
| 191 | +
|
| 192 | + Example: |
| 193 | +
|
| 194 | + ```py |
| 195 | + >>> parse_size_to_int("5MB") |
| 196 | + 5000000 |
| 197 | + ``` |
| 198 | + """ |
| 199 | + size_as_str = size_as_str.strip() |
| 200 | + |
| 201 | + # Parse unit |
| 202 | + unit = size_as_str[-2:].upper() |
| 203 | + if unit not in SIZE_UNITS: |
| 204 | + raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.") |
| 205 | + multiplier = SIZE_UNITS[unit] |
| 206 | + |
| 207 | + # Parse value |
| 208 | + try: |
| 209 | + value = float(size_as_str[:-2].strip()) |
| 210 | + except ValueError as e: |
| 211 | + raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e |
| 212 | + |
| 213 | + return int(value * multiplier) |
0 commit comments