11import collections .abc
22import warnings
33import zipfile
4- from io import TextIOWrapper
54from pathlib import Path , PurePosixPath
65from shutil import copyfileobj
76from typing import (
1514)
1615
1716import h5py # pyright: ignore[reportMissingTypeStubs]
18- import numpy as np
1917from imageio .v3 import imread , imwrite # type: ignore
2018from loguru import logger
2119from numpy .typing import NDArray
2220from pydantic import BaseModel , ConfigDict , TypeAdapter
2321from typing_extensions import assert_never
2422
25- from bioimageio .spec ._internal .io import interprete_file_source
23+ from bioimageio .spec ._internal .io import get_reader , interprete_file_source
24+ from bioimageio .spec ._internal .type_guards import is_ndarray
2625from bioimageio .spec .common import (
26+ FileSource ,
2727 HttpUrl ,
2828 PermissiveFileSource ,
2929 RelativeFilePath ,
@@ -65,51 +65,51 @@ def load_image(
6565 else :
6666 src = parsed_source
6767
68- # FIXME: why is pyright complaining about giving the union to _split_dataset_path?
6968 if isinstance (src , Path ):
70- file_source , subpath = _split_dataset_path (src )
69+ file_source , suffix , subpath = _split_dataset_path (src )
7170 elif isinstance (src , HttpUrl ):
72- file_source , subpath = _split_dataset_path (src )
71+ file_source , suffix , subpath = _split_dataset_path (src )
7372 elif isinstance (src , ZipPath ):
74- file_source , subpath = _split_dataset_path (src )
73+ file_source , suffix , subpath = _split_dataset_path (src )
7574 else :
7675 assert_never (src )
7776
78- path = download (file_source ).path
79-
80- if path .suffix == ".npy" :
77+ if suffix == ".npy" :
8178 if subpath is not None :
82- raise ValueError (f"Unexpected subpath { subpath } for .npy path { path } " )
83- return load_array (path )
84- elif path .suffix in SUFFIXES_WITH_DATAPATH :
79+ logger .warning (
80+ "Unexpected subpath {} for .npy source {}" , subpath , file_source
81+ )
82+
83+ image = load_array (file_source )
84+ elif suffix in SUFFIXES_WITH_DATAPATH :
8585 if subpath is None :
8686 dataset_path = DEFAULT_H5_DATASET_PATH
8787 else :
8888 dataset_path = str (subpath )
8989
90- with h5py .File (path , "r" ) as f :
90+ reader = download (file_source )
91+
92+ with h5py .File (reader , "r" ) as f :
9193 h5_dataset = f .get ( # pyright: ignore[reportUnknownVariableType]
9294 dataset_path
9395 )
9496 if not isinstance (h5_dataset , h5py .Dataset ):
9597 raise ValueError (
96- f"{ path } is not of type { h5py .Dataset } , but has type "
98+ f"{ file_source } did not load as { h5py .Dataset } , but has type "
9799 + str (
98100 type (h5_dataset ) # pyright: ignore[reportUnknownArgumentType]
99101 )
100102 )
101103 image : NDArray [Any ]
102104 image = h5_dataset [:] # pyright: ignore[reportUnknownVariableType]
103- assert isinstance (image , np .ndarray ), type (
104- image # pyright: ignore[reportUnknownArgumentType]
105- )
106- return image # pyright: ignore[reportUnknownVariableType]
107- elif isinstance (path , ZipPath ):
108- return imread (
109- path .read_bytes (), extension = path .suffix
110- ) # pyright: ignore[reportUnknownVariableType]
111105 else :
112- return imread (path ) # pyright: ignore[reportUnknownVariableType]
106+ reader = download (file_source )
107+ image = imread ( # pyright: ignore[reportUnknownVariableType]
108+ reader .read (), extension = suffix
109+ )
110+
111+ assert is_ndarray (image )
112+ return image
113113
114114
115115def load_tensor (
@@ -123,19 +123,21 @@ def load_tensor(
123123
124124_SourceT = TypeVar ("_SourceT" , Path , HttpUrl , ZipPath )
125125
126+ Suffix = str
127+
126128
127129def _split_dataset_path (
128130 source : _SourceT ,
129- ) -> Tuple [_SourceT , Optional [PurePosixPath ]]:
131+ ) -> Tuple [_SourceT , Suffix , Optional [PurePosixPath ]]:
130132 """Split off subpath (e.g. internal h5 dataset path)
131133 from a file path following a file extension.
132134
133135 Examples:
134136 >>> _split_dataset_path(Path("my_file.h5/dataset"))
135- (...Path('my_file.h5'), PurePosixPath('dataset'))
137+ (...Path('my_file.h5'), '.h5', PurePosixPath('dataset'))
136138
137139 >>> _split_dataset_path(Path("my_plain_file"))
138- (...Path('my_plain_file'), None)
140+ (...Path('my_plain_file'), '', None)
139141
140142 """
141143 if isinstance (source , RelativeFilePath ):
@@ -148,50 +150,55 @@ def _split_dataset_path(
148150 def separate_pure_path (path : PurePosixPath ):
149151 for p in path .parents :
150152 if p .suffix in SUFFIXES_WITH_DATAPATH :
151- return p , PurePosixPath (path .relative_to (p ))
153+ return p , p . suffix , PurePosixPath (path .relative_to (p ))
152154
153- return path , None
155+ return path , path . suffix , None
154156
155157 if isinstance (src , HttpUrl ):
156- file_path , data_path = separate_pure_path (PurePosixPath (src .path or "" ))
158+ file_path , suffix , data_path = separate_pure_path (PurePosixPath (src .path or "" ))
157159
158160 if data_path is None :
159- return src , None
161+ return src , suffix , None
160162
161163 return (
162164 HttpUrl (str (file_path ).replace (f"/{ data_path } " , "" )),
165+ suffix ,
163166 data_path ,
164167 )
165168
166169 if isinstance (src , ZipPath ):
167- file_path , data_path = separate_pure_path (PurePosixPath (str (src )))
170+ file_path , suffix , data_path = separate_pure_path (PurePosixPath (str (src )))
168171
169172 if data_path is None :
170- return src , None
173+ return src , suffix , None
171174
172175 return (
173176 ZipPath (str (file_path ).replace (f"/{ data_path } " , "" )),
177+ suffix ,
174178 data_path ,
175179 )
176180
177- file_path , data_path = separate_pure_path (PurePosixPath (src ))
178- return Path (file_path ), data_path
181+ file_path , suffix , data_path = separate_pure_path (PurePosixPath (src ))
182+ return Path (file_path ), suffix , data_path
179183
180184
181185def save_tensor (path : Union [Path , str ], tensor : Tensor ) -> None :
182186 # TODO: save axis meta data
183187
184- data : NDArray [Any ] = tensor .data .to_numpy ()
185- file_path , subpath = _split_dataset_path (Path (path ))
186- if not file_path .suffix :
188+ data : NDArray [Any ] = ( # pyright: ignore[reportUnknownVariableType]
189+ tensor .data .to_numpy ()
190+ )
191+ assert is_ndarray (data )
192+ file_path , suffix , subpath = _split_dataset_path (Path (path ))
193+ if not suffix :
187194 raise ValueError (f"No suffix (needed to decide file format) found in { path } " )
188195
189196 file_path .parent .mkdir (exist_ok = True , parents = True )
190197 if file_path .suffix == ".npy" :
191198 if subpath is not None :
192199 raise ValueError (f"Unexpected subpath { subpath } found in .npy path { path } " )
193200 save_array (file_path , data )
194- elif file_path . suffix in (".h5" , ".hdf" , ".hdf5" ):
201+ elif suffix in (".h5" , ".hdf" , ".hdf5" ):
195202 if subpath is None :
196203 dataset_path = DEFAULT_H5_DATASET_PATH
197204 else :
@@ -275,22 +282,39 @@ def load_dataset_stat(path: Path):
def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
    """Unzip a (downloaded) **source** to a file in **folder** if source is a zip archive.

    Always returns the path to the unzipped source (maybe source itself).
    """
    weights_reader = get_reader(source)
    # fall back to a generic name when the reader carries no original file name
    out_path = folder / (
        weights_reader.original_file_name or f"file{weights_reader.suffix}"
    )

    if zipfile.is_zipfile(weights_reader):
        # source itself is a zipfile: extract it into a sibling directory
        out_path = out_path.with_name(out_path.name + ".unzipped")
        out_path.parent.mkdir(exist_ok=True, parents=True)
        with zipfile.ZipFile(weights_reader, "r") as f:
            f.extractall(out_path)
    else:
        out_path.parent.mkdir(exist_ok=True, parents=True)
        # zipfile.is_zipfile reads from the stream while probing and does not
        # restore the position; rewind before copying so the full file is
        # written (assumes the reader implements IOBase.seekable — TODO confirm)
        if weights_reader.seekable():
            weights_reader.seek(0)
        with out_path.open("wb") as f:
            copyfileobj(weights_reader, f)

    return out_path
303+
304+
def get_suffix(source: Union[ZipPath, FileSource]) -> str:
    """Return the file extension (including the leading dot) of **source**.

    Returns an empty string for an `HttpUrl` without a path component
    (or for a source with no extension at all).
    """
    # single isinstance chain; the original had a second, disjoint chain with
    # an unreachable duplicate ZipPath branch
    if isinstance(source, Path):
        return source.suffix
    elif isinstance(source, ZipPath):
        return source.suffix
    elif isinstance(source, RelativeFilePath):
        return source.path.suffix
    elif isinstance(source, HttpUrl):
        # a URL may have no path at all, e.g. "https://example.com"
        if source.path is None:
            return ""
        else:
            return PurePosixPath(source.path).suffix
    else:
        assert_never(source)
0 commit comments