|
1 | 1 | import os |
2 | 2 | import shutil |
3 | | -from tempfile import NamedTemporaryFile |
| 3 | +from pathlib import Path |
| 4 | +from tempfile import TemporaryDirectory |
4 | 5 | from typing import Any, Optional, Sequence, Union |
5 | 6 |
|
| 7 | +import h5py # pyright: ignore[reportMissingTypeStubs] |
| 8 | +from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs] |
| 9 | + legacy_h5_format, |
| 10 | +) |
6 | 11 | from loguru import logger |
7 | 12 | from numpy.typing import NDArray |
8 | 13 |
|
@@ -70,10 +75,16 @@ def __init__( |
70 | 75 | ) |
71 | 76 |
|
72 | 77 | weight_reader = model_description.weights.keras_hdf5.get_reader() |
73 | | - # TODO: do we need to load keras model from disk? |
74 | | - with NamedTemporaryFile(mode="wb") as temp_file: |
75 | | - shutil.copyfileobj(weight_reader, temp_file) |
76 | | - self._network = keras.models.load_model(temp_file.name) |
| 78 | + if weight_reader.suffix in (".h5", "hdf5"): |
| 79 | + h5_file = h5py.File(weight_reader, mode="r") |
| 80 | + self._network = legacy_h5_format.load_model_from_hdf5(h5_file) |
| 81 | + else: |
| 82 | + with TemporaryDirectory() as temp_dir: |
| 83 | + temp_path = Path(temp_dir) / weight_reader.original_file_name |
| 84 | + with temp_path.open("wb") as f: |
| 85 | + shutil.copyfileobj(weight_reader, f) |
| 86 | + |
| 87 | + self._network = keras.models.load_model(temp_path) |
77 | 88 |
|
78 | 89 | self._output_axes = [ |
79 | 90 | tuple(a.id for a in get_axes_infos(out)) |
|
0 commit comments