Skip to content

Commit 20f821b

Browse files
SunMarc and Wauplin
committed
Support max_shard_size as string in split_state_dict_into_shards_factory (#2286)
* fix max-shard-size
* for torch also
* add tests
* Fix styling + do not support KiB

---------

Co-authored-by: Lucain Pouget <[email protected]>
1 parent 12b34d7 commit 20f821b

File tree

5 files changed

+70
-9
lines changed

5 files changed

+70
-9
lines changed

src/huggingface_hub/serialization/_base.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Contains helpers to split tensors into shards."""
1515

1616
from dataclasses import dataclass, field
17-
from typing import Any, Callable, Dict, List, Optional, TypeVar
17+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
1818

1919
from .. import logging
2020

@@ -46,7 +46,7 @@ def split_state_dict_into_shards_factory(
4646
get_tensor_size: TensorSizeFn_T,
4747
get_storage_id: StorageIDFn_T = lambda tensor: None,
4848
filename_pattern: str = FILENAME_PATTERN,
49-
max_shard_size: int = MAX_SHARD_SIZE,
49+
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
5050
) -> StateDictSplit:
5151
"""
5252
Split a model state dictionary in shards so that each shard is smaller than a given size.
@@ -89,6 +89,9 @@ def split_state_dict_into_shards_factory(
8989
current_shard_size = 0
9090
total_size = 0
9191

92+
if isinstance(max_shard_size, str):
93+
max_shard_size = parse_size_to_int(max_shard_size)
94+
9295
for key, tensor in state_dict.items():
9396
# when bnb serialization is used the weights in the state dict can be strings
9497
# check: https://github.com/huggingface/transformers/pull/24416 for more details
@@ -167,3 +170,44 @@ def split_state_dict_into_shards_factory(
167170
filename_to_tensors=filename_to_tensors,
168171
tensor_to_filename=tensor_name_to_filename,
169172
)
173+
174+
175+
SIZE_UNITS = {
176+
"TB": 10**12,
177+
"GB": 10**9,
178+
"MB": 10**6,
179+
"KB": 10**3,
180+
}
181+
182+
183+
def parse_size_to_int(size_as_str: str) -> int:
184+
"""
185+
Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
186+
187+
Supported units are "TB", "GB", "MB", "KB".
188+
189+
Args:
190+
size_as_str (`str`): The size to convert. Will be directly returned if an `int`.
191+
192+
Example:
193+
194+
```py
195+
>>> parse_size_to_int("5MB")
196+
5000000
197+
```
198+
"""
199+
size_as_str = size_as_str.strip()
200+
201+
# Parse unit
202+
unit = size_as_str[-2:].upper()
203+
if unit not in SIZE_UNITS:
204+
raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
205+
multiplier = SIZE_UNITS[unit]
206+
207+
# Parse value
208+
try:
209+
value = float(size_as_str[:-2].strip())
210+
except ValueError as e:
211+
raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e
212+
213+
return int(value * multiplier)

src/huggingface_hub/serialization/_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Contains numpy-specific helpers."""
1515

16-
from typing import TYPE_CHECKING, Dict
16+
from typing import TYPE_CHECKING, Dict, Union
1717

1818
from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
1919

@@ -26,7 +26,7 @@ def split_numpy_state_dict_into_shards(
2626
state_dict: Dict[str, "np.ndarray"],
2727
*,
2828
filename_pattern: str = FILENAME_PATTERN,
29-
max_shard_size: int = MAX_SHARD_SIZE,
29+
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
3030
) -> StateDictSplit:
3131
"""
3232
Split a model state dictionary in shards so that each shard is smaller than a given size.

src/huggingface_hub/serialization/_tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import math
1717
import re
18-
from typing import TYPE_CHECKING, Dict
18+
from typing import TYPE_CHECKING, Dict, Union
1919

2020
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
2121

@@ -28,7 +28,7 @@ def split_tf_state_dict_into_shards(
2828
state_dict: Dict[str, "tf.Tensor"],
2929
*,
3030
filename_pattern: str = "tf_model{suffix}.h5",
31-
max_shard_size: int = MAX_SHARD_SIZE,
31+
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
3232
) -> StateDictSplit:
3333
"""
3434
Split a model state dictionary in shards so that each shard is smaller than a given size.

src/huggingface_hub/serialization/_torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import importlib
1717
from functools import lru_cache
18-
from typing import TYPE_CHECKING, Dict, Tuple
18+
from typing import TYPE_CHECKING, Dict, Tuple, Union
1919

2020
from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
2121

@@ -28,7 +28,7 @@ def split_torch_state_dict_into_shards(
2828
state_dict: Dict[str, "torch.Tensor"],
2929
*,
3030
filename_pattern: str = FILENAME_PATTERN,
31-
max_shard_size: int = MAX_SHARD_SIZE,
31+
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
3232
) -> StateDictSplit:
3333
"""
3434
Split a model state dictionary in shards so that each shard is smaller than a given size.
@@ -67,7 +67,7 @@ def split_torch_state_dict_into_shards(
6767
6868
>>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
6969
... state_dict_split = split_torch_state_dict_into_shards(state_dict)
70-
... for filename, tensors in state_dict_split.filename_to_tensors.values():
70+
... for filename, tensors in state_dict_split.filename_to_tensors.items():
7171
... shard = {tensor: state_dict[tensor] for tensor in tensors}
7272
... safe_save_file(
7373
... shard,

tests/test_serialization.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import pytest
2+
13
from huggingface_hub.serialization import split_state_dict_into_shards_factory
4+
from huggingface_hub.serialization._base import parse_size_to_int
25
from huggingface_hub.serialization._numpy import get_tensor_size as get_tensor_size_numpy
36
from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow
47
from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch
@@ -123,3 +126,17 @@ def test_get_tensor_size_torch():
123126

124127
assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
125128
assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
129+
130+
131+
def test_parse_size_to_int():
132+
assert parse_size_to_int("1KB") == 1 * 10**3
133+
assert parse_size_to_int("2MB") == 2 * 10**6
134+
assert parse_size_to_int("3GB") == 3 * 10**9
135+
assert parse_size_to_int(" 10 KB ") == 10 * 10**3 # ok with whitespace
136+
assert parse_size_to_int("20mb") == 20 * 10**6 # ok with lowercase
137+
138+
with pytest.raises(ValueError, match="Unit 'IB' not supported"):
139+
parse_size_to_int("1KiB") # not a valid unit
140+
141+
with pytest.raises(ValueError, match="Could not parse the size value"):
142+
parse_size_to_int("1ooKB") # not a float

0 commit comments

Comments (0)