1515import os
1616from dataclasses import dataclass
1717from time import sleep
18- from typing import Any , Dict , List , Optional , Tuple
18+ from typing import Any , Dict , List , Optional , Tuple , Union
1919
2020import numpy as np
2121import torch
2929 from torch .utils ._pytree import PyTree , tree_flatten , treespec_dumps
3030
3131
32- def _get_data_optimizer_node_rank () -> Optional [int ]:
33- node_rank = os .getenv ("DATA_OPTIMIZER_NODE_RANK" , None )
34- if node_rank is not None :
35- return int (node_rank )
36- return node_rank
32+ _FORMAT_TO_RATIO = {
33+ "kb" : 1024 ,
34+ "mb" : 1024 ** 2 ,
35+ "gb" : 1024 ** 3 ,
36+ "tb" : 1024 ** 4 ,
37+ "pb" : 1024 ** 5 ,
38+ "eb" : 1024 ** 6 ,
39+ "zb" : 1024 ** 7 ,
40+ "yb" : 1024 ** 8 ,
41+ }
42+
43+
44+ def _convert_bytes_to_int (bytes_str : str ) -> int :
45+ """Convert human readable byte format to an integer."""
46+ for suffix in _FORMAT_TO_RATIO :
47+ bytes_str = bytes_str .lower ().strip ()
48+ if bytes_str .lower ().endswith (suffix ):
49+ try :
50+ return int (float (bytes_str [0 : - len (suffix )]) * _FORMAT_TO_RATIO [suffix ])
51+ except ValueError :
52+ raise ValueError (
53+ "" .join (
54+ [
55+ f"Unsupported value/suffix { bytes_str } . Supported suffix are " ,
56+ f'{ ["b" ] + list (_FORMAT_TO_RATIO .keys ())} .' ,
57+ ]
58+ )
59+ )
60+ raise ValueError (f"The supported units are { _FORMAT_TO_RATIO .keys ()} " )
3761
3862
3963@dataclass
@@ -52,7 +76,7 @@ def __init__(
5276 self ,
5377 cache_dir : str ,
5478 chunk_size : Optional [int ] = None ,
55- chunk_bytes : Optional [int ] = None ,
79+ chunk_bytes : Optional [Union [ int , str ] ] = None ,
5680 compression : Optional [str ] = None ,
5781 follow_tensor_dimension : bool = True ,
5882 ):
@@ -75,7 +99,7 @@ def __init__(
7599
76100 self ._serializers : Dict [str , Serializer ] = _SERIALIZERS
77101 self ._chunk_size = chunk_size
78- self ._chunk_bytes = chunk_bytes
102+ self ._chunk_bytes = _convert_bytes_to_int ( chunk_bytes ) if isinstance ( chunk_bytes , str ) else chunk_bytes
79103 self ._compression = compression
80104
81105 self ._data_format : Optional [List [str ]] = None
0 commit comments