Skip to content

Commit c12b990

Browse files
committed
load h5 from memory
1 parent 0877470 commit c12b990

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

bioimageio/core/backends/keras_backend.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import os
22
import shutil
3-
from tempfile import NamedTemporaryFile
3+
from pathlib import Path
4+
from tempfile import TemporaryDirectory
45
from typing import Any, Optional, Sequence, Union
56

7+
import h5py # pyright: ignore[reportMissingTypeStubs]
8+
from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
9+
legacy_h5_format,
10+
)
611
from loguru import logger
712
from numpy.typing import NDArray
813

@@ -70,10 +75,16 @@ def __init__(
7075
)
7176

7277
weight_reader = model_description.weights.keras_hdf5.get_reader()
73-
# TODO: do we need to load keras model from disk?
74-
with NamedTemporaryFile(mode="wb") as temp_file:
75-
shutil.copyfileobj(weight_reader, temp_file)
76-
self._network = keras.models.load_model(temp_file.name)
78+
if weight_reader.suffix in (".h5", "hdf5"):
79+
h5_file = h5py.File(weight_reader, mode="r")
80+
self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
81+
else:
82+
with TemporaryDirectory() as temp_dir:
83+
temp_path = Path(temp_dir) / weight_reader.original_file_name
84+
with temp_path.open("wb") as f:
85+
shutil.copyfileobj(weight_reader, f)
86+
87+
self._network = keras.models.load_model(temp_path)
7788

7889
self._output_axes = [
7990
tuple(a.id for a in get_axes_infos(out))

0 commit comments

Comments
 (0)