@@ -1040,15 +1040,20 @@ def __bool__(self):
1040
1040
# will mistakenly use `__len__` to determine the value.
1041
1041
return self .h5_file .__bool__ ()
1042
1042
1043
- def _get_h5_file (self , path_or_io ):
1043
+ def _get_h5_file (self , path_or_io , mode = None ):
1044
+ mode = mode or self .mode
1045
+ if mode not in ("r" , "w" , "a" ):
1046
+ raise ValueError (
1047
+ f"`mode` should be either 'r', 'w' or 'a'. Received: { mode } "
1048
+ )
1044
1049
if self .archive :
1045
- if self . mode == "w" :
1050
+ if mode == "w" :
1046
1051
self .io_file = io .BytesIO ()
1047
1052
else :
1048
1053
self .io_file = self .archive .open (str (path_or_io ), "r" )
1049
- return h5py .File (self .io_file , mode = self . mode )
1054
+ return h5py .File (self .io_file , mode = mode )
1050
1055
else :
1051
- return h5py .File (path_or_io , mode = self . mode )
1056
+ return h5py .File (path_or_io , mode = mode )
1052
1057
1053
1058
def make (self , path , metadata = None ):
1054
1059
"""Make a new H5 entry group.
@@ -1148,10 +1153,16 @@ def __getitem__(self, key):
1148
1153
and value .attrs ["dtype" ] == "bfloat16"
1149
1154
):
1150
1155
value = np .array (value , dtype = ml_dtypes .bfloat16 )
1156
+ elif (
1157
+ hasattr (value , "shape" )
1158
+ and hasattr (value , "dtype" )
1159
+ and not isinstance (value , np .ndarray )
1160
+ ):
1161
+ value = np .array (value )
1151
1162
return value
1152
1163
1153
1164
def __setitem__ (self , key , value ):
1154
- if self .mode != "w" :
1165
+ if self .mode not in ( "w" , "a" ) :
1155
1166
raise ValueError ("Setting a value is only allowed in write mode." )
1156
1167
if not self ._h5_entry_initialized :
1157
1168
self ._create_h5_group (self ._h5_entry_path )
@@ -1164,7 +1175,7 @@ def __setitem__(self, key, value):
1164
1175
self ._h5_entry_group [key ] = value
1165
1176
1166
1177
def __delitem__ (self , key ):
1167
- if self .mode != "w" :
1178
+ if self .mode not in ( "w" , "a" ) :
1168
1179
raise ValueError ("Deleting a value is only allowed in write mode." )
1169
1180
del self ._h5_entry_group [key ]
1170
1181
@@ -1202,7 +1213,7 @@ def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"):
1202
1213
self .archive = archive
1203
1214
self .io_file = None
1204
1215
1205
- self .max_shard_size = float (max_shard_size )
1216
+ self .max_shard_size = float (max_shard_size ) * 1024 ** 3 # To bytes.
1206
1217
self .base_name = self .path .stem .replace (".weights" , "" )
1207
1218
1208
1219
if self .path .suffix != ".json" :
@@ -1226,6 +1237,7 @@ def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"):
1226
1237
self .current_shard_size = 0
1227
1238
self .total_shard_size = 0 # In bytes.
1228
1239
self .current_shard_path = None
1240
+ self .current_shard_filenames = []
1229
1241
if self .mode == "w" :
1230
1242
self .sharding_config = {
1231
1243
"metadata" : {
@@ -1243,6 +1255,27 @@ def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"):
1243
1255
self .sharding_config = json .load (map_file )
1244
1256
self .h5_file = self ._create_new_shard_file ()
1245
1257
1258
def make(self, path, metadata=None):
    """Start a new H5 entry group, resetting the per-entry shard list.

    `self.current_shard_filenames` is cleared and, when an H5 file is
    currently open, re-seeded with that file's name. The rest of the work
    is delegated to the parent class's `make`.

    Args:
        path: `str`. The variable path.
        metadata: Optional `dict`. The metadata to save with the H5 entry
            group. Defaults to `None`.

    Returns:
        Whatever the parent class's `make` returns.
    """
    shard_names = []
    self.current_shard_filenames = shard_names
    active_file = self.h5_file
    if active_file is not None:
        # Track the open shard by its bare file name.
        shard_names.append(pathlib.Path(active_file.filename).name)
    return super().make(path, metadata)
1278
+
1246
1279
def get (self , path ):
1247
1280
"""Get the H5 entry group.
1248
1281
@@ -1259,17 +1292,27 @@ def get(self, path):
1259
1292
1260
1293
# If not found, check shard map and switch files.
1261
1294
weight_map = self .sharding_config ["weight_map" ]
1262
- filename = weight_map .get (parsed_path ) or weight_map .get (
1295
+ filenames = weight_map .get (parsed_path ) or weight_map .get (
1263
1296
"/" + parsed_path + "/vars"
1264
1297
)
1298
+ if filenames is not None :
1299
+ if not isinstance (filenames , list ):
1300
+ filenames = [filenames ]
1301
+ self .current_shard_filenames = filenames
1302
+ filename = filenames [0 ]
1303
+ else :
1304
+ self .current_shard_filenames = []
1305
+ filename = None
1265
1306
1266
1307
if filename is not None and filename != self .current_shard_path .name :
1267
1308
self .close ()
1268
1309
self .h5_file = self ._get_h5_file (self .path .with_name (filename ))
1269
1310
return super ().get (path )
1270
1311
1271
1312
def close (self ):
1272
- self .h5_file .close ()
1313
+ if self .h5_file is not None :
1314
+ self .h5_file .close ()
1315
+ self .h5_file = None
1273
1316
if self .mode == "w" :
1274
1317
self .sharding_config ["metadata" ]["total_size" ] = (
1275
1318
self .total_shard_size
@@ -1289,28 +1332,128 @@ def close(self):
1289
1332
# Shard-specific methods.
1290
1333
1291
1334
def _create_new_shard_file (self ):
1335
+ """Create a new shard file and return the H5 file object."""
1292
1336
new_shard_path = (
1293
1337
f"{ self .base_name } _{ self .current_shard_index :05} .weights.h5"
1294
1338
)
1295
1339
self .current_shard_index += 1
1296
1340
self .current_shard_path = self .path .with_name (new_shard_path )
1297
- return self ._get_h5_file (self .current_shard_path )
1341
+ h5_file = self ._get_h5_file (self .current_shard_path )
1342
+ self .current_shard_filenames .append (pathlib .Path (h5_file .filename ).name )
1343
+ self ._h5_entry_initialized = False
1344
+ return h5_file
1345
+
1346
+ def _switch_h5_file (self , filename , mode ):
1347
+ """Switch to a different H5 file with the specified mode.
1348
+
1349
+ This is useful for retrieving information from all shards, such as the
1350
+ total length, keys, and items.
1351
+ """
1352
+ if mode not in ("r" , "a" ):
1353
+ raise ValueError (
1354
+ f"`mode` should be either 'r' or 'a'. Received: { mode } "
1355
+ )
1356
+ self .close ()
1357
+ self .h5_file = self ._get_h5_file (
1358
+ self .path .with_name (filename ), mode = mode
1359
+ )
1360
+ self ._get_h5_group (self ._h5_entry_path )
1361
+
1362
+ def _restore_h5_file (self ):
1363
+ """Ensure the current shard is the last one created.
1364
+
1365
+ We use mode="a" to avoid truncating the file during the switching.
1366
+ """
1367
+ if (
1368
+ pathlib .Path (self .h5_file .filename ).name
1369
+ != self .current_shard_path .name
1370
+ ):
1371
+ self ._switch_h5_file (self .current_shard_path .name , mode = "a" )
1298
1372
1299
1373
# H5 entry level methods.
1300
1374
1375
+ def _get_h5_group (self , path ):
1376
+ """Get the H5 entry group. If it doesn't exist, return an empty dict."""
1377
+ try :
1378
+ if not path :
1379
+ self ._h5_entry_group = self .h5_file ["vars" ]
1380
+ else :
1381
+ self ._h5_entry_group = self .h5_file [path ]["vars" ]
1382
+ self ._h5_entry_initialized = True
1383
+ except KeyError :
1384
+ self ._h5_entry_group = {}
1385
+ self ._h5_entry_initialized = False
1386
+
1387
+ # Dict methods.
1388
+
1389
+ def __len__ (self ):
1390
+ total_len = self ._h5_entry_group .__len__ ()
1391
+ for filename in self .current_shard_filenames :
1392
+ if filename == self .current_shard_path .name :
1393
+ continue
1394
+ self ._switch_h5_file (filename , mode = "r" )
1395
+ total_len += self ._h5_entry_group .__len__ ()
1396
+ self ._restore_h5_file ()
1397
+ return total_len
1398
+
1399
def keys(self):
    """Return the set of keys of the current entry across its shards.

    Keys of the loaded entry group are merged with those found in every
    other file listed in `self.current_shard_filenames` (opened
    read-only). The current shard is restored before returning.
    """
    merged = {*self._h5_entry_group.keys()}
    for filename in self.current_shard_filenames:
        if filename != self.current_shard_path.name:
            self._switch_h5_file(filename, mode="r")
            merged.update(self._h5_entry_group.keys())
    self._restore_h5_file()
    return merged
1408
+
1409
def items(self):
    """Yield `(key, value)` pairs of the current entry from every shard.

    Pairs from the loaded entry group come first, followed by those of
    every other file listed in `self.current_shard_filenames`, each opened
    read-only. The current shard is restored once iteration finishes, and
    also when iteration is abandoned or fails while another shard is open.
    """
    yield from self._h5_entry_group.items()
    for filename in self.current_shard_filenames:
        if filename == self.current_shard_path.name:
            continue
        self._switch_h5_file(filename, mode="r")
        try:
            yield from self._h5_entry_group.items()
        except BaseException:
            # The consumer stopped early (`GeneratorExit` from `close()`
            # or garbage collection) or iteration raised. Switch back so
            # the store is not left pointing at another shard.
            self._restore_h5_file()
            raise
    self._restore_h5_file()
1417
+
1418
def values(self):
    """Yield the values of the current entry from every shard.

    Values from the loaded entry group come first, followed by those of
    every other file listed in `self.current_shard_filenames`, each opened
    read-only. The current shard is restored once iteration finishes, and
    also when iteration is abandoned or fails while another shard is open.
    """
    yield from self._h5_entry_group.values()
    for filename in self.current_shard_filenames:
        if filename == self.current_shard_path.name:
            continue
        self._switch_h5_file(filename, mode="r")
        try:
            yield from self._h5_entry_group.values()
        except BaseException:
            # The consumer stopped early (`GeneratorExit` from `close()`
            # or garbage collection) or iteration raised. Switch back so
            # the store is not left pointing at another shard.
            self._restore_h5_file()
            raise
    self._restore_h5_file()
1426
+
1427
+ def __getitem__ (self , key ):
1428
+ if key in self ._h5_entry_group :
1429
+ return super ().__getitem__ (key )
1430
+
1431
+ for filename in self .current_shard_filenames :
1432
+ if filename == self .current_shard_path .name :
1433
+ continue
1434
+ self ._switch_h5_file (filename , mode = "r" )
1435
+ if key in self ._h5_entry_group :
1436
+ item = super ().__getitem__ (key )
1437
+ self ._restore_h5_file ()
1438
+ return item
1439
+ raise KeyError (
1440
+ f"Key '{ key } ' not found in any of the shards: "
1441
+ f"{ self .current_shard_filenames } "
1442
+ )
1443
+
1301
1444
def __setitem__ (self , key , value ):
1445
+ self ._restore_h5_file ()
1446
+
1302
1447
# Accumulate `current_shard_size`.
1303
1448
value = backend .convert_to_numpy (value )
1304
1449
dtype = backend .standardize_dtype (value .dtype )
1305
1450
weight_counts = math .prod (value .shape )
1306
1451
per_param_size = dtype_utils .dtype_size (dtype )
1307
- value_size = weight_counts * per_param_size / ( 8.0 * 1024 ** 3 ) # To GB .
1308
- self .total_shard_size += weight_counts * per_param_size / 8 # In bytes.
1452
+ value_size = weight_counts * per_param_size / 8 # In bytes .
1453
+ self .total_shard_size += value_size
1309
1454
if value_size > self .max_shard_size :
1310
- value_size_str = readable_memory_size (value_size * 1024 ** 3 )
1311
- max_shard_size_str = readable_memory_size (
1312
- self .max_shard_size * 1024 ** 3
1313
- )
1455
+ value_size_str = readable_memory_size (value_size )
1456
+ max_shard_size_str = readable_memory_size (self .max_shard_size )
1314
1457
raise ValueError (
1315
1458
f"The size of { key } is { value_size_str } which "
1316
1459
f"exceeds the maximum shard size { max_shard_size_str } . You "
@@ -1323,16 +1466,53 @@ def __setitem__(self, key, value):
1323
1466
if self .current_shard_size > self .max_shard_size :
1324
1467
self .close ()
1325
1468
self .h5_file = self ._create_new_shard_file ()
1326
- self .make (self ._h5_entry_path )
1327
1469
self .current_shard_size = value_size
1328
1470
1329
1471
super ().__setitem__ (key , value )
1330
1472
1473
+ # Update the weight map.
1331
1474
variable_path = self ._h5_entry_group .name
1332
- if variable_path not in self .sharding_config ["weight_map" ]:
1333
- self .sharding_config ["weight_map" ][variable_path ] = (
1334
- self .current_shard_path .name
1335
- )
1475
+ shard_filename = self .current_shard_path .name
1476
+ weight_map = self .sharding_config ["weight_map" ]
1477
+ if variable_path not in weight_map :
1478
+ weight_map [variable_path ] = shard_filename
1479
+ else :
1480
+ if not isinstance (weight_map [variable_path ], list ):
1481
+ weight_map [variable_path ] = [weight_map [variable_path ]]
1482
+ if shard_filename not in weight_map [variable_path ]:
1483
+ weight_map [variable_path ].append (shard_filename )
1484
+
1485
+ def __delitem__ (self , key ):
1486
+ if key in self ._h5_entry_group :
1487
+ super ().__delitem__ (key )
1488
+ return
1489
+
1490
+ for filename in self .current_shard_filenames :
1491
+ if filename == self .current_shard_path .name :
1492
+ continue
1493
+ self ._switch_h5_file (filename , mode = "a" )
1494
+ if key in self ._h5_entry_group :
1495
+ super ().__delitem__ (key )
1496
+ self ._restore_h5_file ()
1497
+ return
1498
+ raise KeyError (
1499
+ f"Key '{ key } ' not found in any of the shards: "
1500
+ f"{ self .current_shard_filenames } "
1501
+ )
1502
+
1503
+ def __contains__ (self , item ):
1504
+ if item in self ._h5_entry_group :
1505
+ return True
1506
+
1507
+ for filename in self .current_shard_filenames :
1508
+ if filename == self .current_shard_path .name :
1509
+ continue
1510
+ self ._switch_h5_file (filename , mode = "r" )
1511
+ if item in self ._h5_entry_group :
1512
+ self ._restore_h5_file ()
1513
+ return True
1514
+ self ._restore_h5_file ()
1515
+ return False
1336
1516
1337
1517
1338
1518
class NpzIOStore :
0 commit comments