Skip to content

Commit 9286df3

Browse files
Add support for all dict methods to ShardedH5IOStore. (#21365)
* Add support for all dict methods to `ShardedH5IOStore`. * Increase test coverage. * Update.
1 parent 24f104e commit 9286df3

File tree

2 files changed

+251
-27
lines changed

2 files changed

+251
-27
lines changed

keras/src/saving/saving_lib.py

Lines changed: 201 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,15 +1040,20 @@ def __bool__(self):
10401040
# will mistakenly use `__len__` to determine the value.
10411041
return self.h5_file.__bool__()
10421042

1043-
def _get_h5_file(self, path_or_io):
1043+
def _get_h5_file(self, path_or_io, mode=None):
1044+
mode = mode or self.mode
1045+
if mode not in ("r", "w", "a"):
1046+
raise ValueError(
1047+
f"`mode` should be either 'r', 'w' or 'a'. Received: {mode}"
1048+
)
10441049
if self.archive:
1045-
if self.mode == "w":
1050+
if mode == "w":
10461051
self.io_file = io.BytesIO()
10471052
else:
10481053
self.io_file = self.archive.open(str(path_or_io), "r")
1049-
return h5py.File(self.io_file, mode=self.mode)
1054+
return h5py.File(self.io_file, mode=mode)
10501055
else:
1051-
return h5py.File(path_or_io, mode=self.mode)
1056+
return h5py.File(path_or_io, mode=mode)
10521057

10531058
def make(self, path, metadata=None):
10541059
"""Make a new H5 entry group.
@@ -1148,10 +1153,16 @@ def __getitem__(self, key):
11481153
and value.attrs["dtype"] == "bfloat16"
11491154
):
11501155
value = np.array(value, dtype=ml_dtypes.bfloat16)
1156+
elif (
1157+
hasattr(value, "shape")
1158+
and hasattr(value, "dtype")
1159+
and not isinstance(value, np.ndarray)
1160+
):
1161+
value = np.array(value)
11511162
return value
11521163

11531164
def __setitem__(self, key, value):
1154-
if self.mode != "w":
1165+
if self.mode not in ("w", "a"):
11551166
raise ValueError("Setting a value is only allowed in write mode.")
11561167
if not self._h5_entry_initialized:
11571168
self._create_h5_group(self._h5_entry_path)
@@ -1164,7 +1175,7 @@ def __setitem__(self, key, value):
11641175
self._h5_entry_group[key] = value
11651176

11661177
def __delitem__(self, key):
1167-
if self.mode != "w":
1178+
if self.mode not in ("w", "a"):
11681179
raise ValueError("Deleting a value is only allowed in write mode.")
11691180
del self._h5_entry_group[key]
11701181

@@ -1202,7 +1213,7 @@ def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"):
12021213
self.archive = archive
12031214
self.io_file = None
12041215

1205-
self.max_shard_size = float(max_shard_size)
1216+
self.max_shard_size = float(max_shard_size) * 1024**3 # To bytes.
12061217
self.base_name = self.path.stem.replace(".weights", "")
12071218

12081219
if self.path.suffix != ".json":
@@ -1226,6 +1237,7 @@ def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"):
12261237
self.current_shard_size = 0
12271238
self.total_shard_size = 0 # In bytes.
12281239
self.current_shard_path = None
1240+
self.current_shard_filenames = []
12291241
if self.mode == "w":
12301242
self.sharding_config = {
12311243
"metadata": {
@@ -1243,6 +1255,27 @@ def __init__(self, path_or_io, max_shard_size=5, archive=None, mode="r"):
12431255
self.sharding_config = json.load(map_file)
12441256
self.h5_file = self._create_new_shard_file()
12451257

1258+
def make(self, path, metadata=None):
1259+
"""Make a new H5 entry group.
1260+
1261+
This method is only available in write mode. It defers the creation of
1262+
the H5 entry group until `__setitem__` is called, preventing the
1263+
creation of empty groups.
1264+
1265+
The information about the current shard is reset.
1266+
1267+
Args:
1268+
path: `str`. The variable path.
1269+
metadata: Optional `dict`. The metadata to save with the H5 entry
1270+
group. Defaults to `None`.
1271+
"""
1272+
self.current_shard_filenames = []
1273+
if self.h5_file is not None:
1274+
self.current_shard_filenames.append(
1275+
pathlib.Path(self.h5_file.filename).name
1276+
)
1277+
return super().make(path, metadata)
1278+
12461279
def get(self, path):
12471280
"""Get the H5 entry group.
12481281
@@ -1259,17 +1292,27 @@ def get(self, path):
12591292

12601293
# If not found, check shard map and switch files.
12611294
weight_map = self.sharding_config["weight_map"]
1262-
filename = weight_map.get(parsed_path) or weight_map.get(
1295+
filenames = weight_map.get(parsed_path) or weight_map.get(
12631296
"/" + parsed_path + "/vars"
12641297
)
1298+
if filenames is not None:
1299+
if not isinstance(filenames, list):
1300+
filenames = [filenames]
1301+
self.current_shard_filenames = filenames
1302+
filename = filenames[0]
1303+
else:
1304+
self.current_shard_filenames = []
1305+
filename = None
12651306

12661307
if filename is not None and filename != self.current_shard_path.name:
12671308
self.close()
12681309
self.h5_file = self._get_h5_file(self.path.with_name(filename))
12691310
return super().get(path)
12701311

12711312
def close(self):
1272-
self.h5_file.close()
1313+
if self.h5_file is not None:
1314+
self.h5_file.close()
1315+
self.h5_file = None
12731316
if self.mode == "w":
12741317
self.sharding_config["metadata"]["total_size"] = (
12751318
self.total_shard_size
@@ -1289,28 +1332,128 @@ def close(self):
12891332
# Shard-specific methods.
12901333

12911334
def _create_new_shard_file(self):
1335+
"""Create a new shard file and return the H5 file object."""
12921336
new_shard_path = (
12931337
f"{self.base_name}_{self.current_shard_index:05}.weights.h5"
12941338
)
12951339
self.current_shard_index += 1
12961340
self.current_shard_path = self.path.with_name(new_shard_path)
1297-
return self._get_h5_file(self.current_shard_path)
1341+
h5_file = self._get_h5_file(self.current_shard_path)
1342+
self.current_shard_filenames.append(pathlib.Path(h5_file.filename).name)
1343+
self._h5_entry_initialized = False
1344+
return h5_file
1345+
1346+
def _switch_h5_file(self, filename, mode):
1347+
"""Switch to a different H5 file with the specified mode.
1348+
1349+
This is useful for retrieving information from all shards, such as the
1350+
total length, keys, and items.
1351+
"""
1352+
if mode not in ("r", "a"):
1353+
raise ValueError(
1354+
f"`mode` should be either 'r' or 'a'. Received: {mode}"
1355+
)
1356+
self.close()
1357+
self.h5_file = self._get_h5_file(
1358+
self.path.with_name(filename), mode=mode
1359+
)
1360+
self._get_h5_group(self._h5_entry_path)
1361+
1362+
def _restore_h5_file(self):
1363+
"""Ensure the current shard is the last one created.
1364+
1365+
We use mode="a" to avoid truncating the file when switching back.
1366+
"""
1367+
if (
1368+
pathlib.Path(self.h5_file.filename).name
1369+
!= self.current_shard_path.name
1370+
):
1371+
self._switch_h5_file(self.current_shard_path.name, mode="a")
12981372

12991373
# H5 entry level methods.
13001374

1375+
def _get_h5_group(self, path):
1376+
"""Set the H5 entry group. If it doesn't exist, use an empty dict."""
1377+
try:
1378+
if not path:
1379+
self._h5_entry_group = self.h5_file["vars"]
1380+
else:
1381+
self._h5_entry_group = self.h5_file[path]["vars"]
1382+
self._h5_entry_initialized = True
1383+
except KeyError:
1384+
self._h5_entry_group = {}
1385+
self._h5_entry_initialized = False
1386+
1387+
# Dict methods.
1388+
1389+
def __len__(self):
1390+
total_len = self._h5_entry_group.__len__()
1391+
for filename in self.current_shard_filenames:
1392+
if filename == self.current_shard_path.name:
1393+
continue
1394+
self._switch_h5_file(filename, mode="r")
1395+
total_len += self._h5_entry_group.__len__()
1396+
self._restore_h5_file()
1397+
return total_len
1398+
1399+
def keys(self):
1400+
keys = set(self._h5_entry_group.keys())
1401+
for filename in self.current_shard_filenames:
1402+
if filename == self.current_shard_path.name:
1403+
continue
1404+
self._switch_h5_file(filename, mode="r")
1405+
keys.update(self._h5_entry_group.keys())
1406+
self._restore_h5_file()
1407+
return keys
1408+
1409+
def items(self):
1410+
yield from self._h5_entry_group.items()
1411+
for filename in self.current_shard_filenames:
1412+
if filename == self.current_shard_path.name:
1413+
continue
1414+
self._switch_h5_file(filename, mode="r")
1415+
yield from self._h5_entry_group.items()
1416+
self._restore_h5_file()
1417+
1418+
def values(self):
1419+
yield from self._h5_entry_group.values()
1420+
for filename in self.current_shard_filenames:
1421+
if filename == self.current_shard_path.name:
1422+
continue
1423+
self._switch_h5_file(filename, mode="r")
1424+
yield from self._h5_entry_group.values()
1425+
self._restore_h5_file()
1426+
1427+
def __getitem__(self, key):
1428+
if key in self._h5_entry_group:
1429+
return super().__getitem__(key)
1430+
1431+
for filename in self.current_shard_filenames:
1432+
if filename == self.current_shard_path.name:
1433+
continue
1434+
self._switch_h5_file(filename, mode="r")
1435+
if key in self._h5_entry_group:
1436+
item = super().__getitem__(key)
1437+
self._restore_h5_file()
1438+
return item
1439+
raise KeyError(
1440+
f"Key '{key}' not found in any of the shards: "
1441+
f"{self.current_shard_filenames}"
1442+
)
1443+
13011444
def __setitem__(self, key, value):
1445+
self._restore_h5_file()
1446+
13021447
# Accumulate `current_shard_size`.
13031448
value = backend.convert_to_numpy(value)
13041449
dtype = backend.standardize_dtype(value.dtype)
13051450
weight_counts = math.prod(value.shape)
13061451
per_param_size = dtype_utils.dtype_size(dtype)
1307-
value_size = weight_counts * per_param_size / (8.0 * 1024**3) # To GB.
1308-
self.total_shard_size += weight_counts * per_param_size / 8 # In bytes.
1452+
value_size = weight_counts * per_param_size / 8 # In bytes.
1453+
self.total_shard_size += value_size
13091454
if value_size > self.max_shard_size:
1310-
value_size_str = readable_memory_size(value_size * 1024**3)
1311-
max_shard_size_str = readable_memory_size(
1312-
self.max_shard_size * 1024**3
1313-
)
1455+
value_size_str = readable_memory_size(value_size)
1456+
max_shard_size_str = readable_memory_size(self.max_shard_size)
13141457
raise ValueError(
13151458
f"The size of {key} is {value_size_str} which "
13161459
f"exceeds the maximum shard size {max_shard_size_str}. You "
@@ -1323,16 +1466,53 @@ def __setitem__(self, key, value):
13231466
if self.current_shard_size > self.max_shard_size:
13241467
self.close()
13251468
self.h5_file = self._create_new_shard_file()
1326-
self.make(self._h5_entry_path)
13271469
self.current_shard_size = value_size
13281470

13291471
super().__setitem__(key, value)
13301472

1473+
# Update the weight map.
13311474
variable_path = self._h5_entry_group.name
1332-
if variable_path not in self.sharding_config["weight_map"]:
1333-
self.sharding_config["weight_map"][variable_path] = (
1334-
self.current_shard_path.name
1335-
)
1475+
shard_filename = self.current_shard_path.name
1476+
weight_map = self.sharding_config["weight_map"]
1477+
if variable_path not in weight_map:
1478+
weight_map[variable_path] = shard_filename
1479+
else:
1480+
if not isinstance(weight_map[variable_path], list):
1481+
weight_map[variable_path] = [weight_map[variable_path]]
1482+
if shard_filename not in weight_map[variable_path]:
1483+
weight_map[variable_path].append(shard_filename)
1484+
1485+
def __delitem__(self, key):
1486+
if key in self._h5_entry_group:
1487+
super().__delitem__(key)
1488+
return
1489+
1490+
for filename in self.current_shard_filenames:
1491+
if filename == self.current_shard_path.name:
1492+
continue
1493+
self._switch_h5_file(filename, mode="a")
1494+
if key in self._h5_entry_group:
1495+
super().__delitem__(key)
1496+
self._restore_h5_file()
1497+
return
1498+
raise KeyError(
1499+
f"Key '{key}' not found in any of the shards: "
1500+
f"{self.current_shard_filenames}"
1501+
)
1502+
1503+
def __contains__(self, item):
1504+
if item in self._h5_entry_group:
1505+
return True
1506+
1507+
for filename in self.current_shard_filenames:
1508+
if filename == self.current_shard_path.name:
1509+
continue
1510+
self._switch_h5_file(filename, mode="r")
1511+
if item in self._h5_entry_group:
1512+
self._restore_h5_file()
1513+
return True
1514+
self._restore_h5_file()
1515+
return False
13361516

13371517

13381518
class NpzIOStore:

0 commit comments

Comments
 (0)