22
33import importlib .util
44from itertools import chain
5+ from pathlib import Path
56from typing import (
67 Any ,
78 Callable ,
1617 Union ,
1718)
1819
20+ import numpy as np
21+ import xarray as xr
22+ from loguru import logger
1923from numpy .typing import NDArray
2024from typing_extensions import Unpack , assert_never
2125
26+ from bioimageio .core .common import MemberId , PerMember , SampleId
27+ from bioimageio .core .io import load_tensor
28+ from bioimageio .core .sample import Sample
2229from bioimageio .spec ._internal .io_utils import HashKwargs , download
2330from bioimageio .spec .common import FileSource
2431from bioimageio .spec .model import AnyModelDescr , v0_4 , v0_5
3037)
3138from bioimageio .spec .utils import load_array
3239
33- from .axis import AxisId , AxisInfo , PerAxis
40+ from .axis import AxisId , AxisInfo , AxisLike , PerAxis
3441from .block_meta import split_multiple_shapes_into_blocks
3542from .common import Halo , MemberId , PerMember , SampleId , TotalNumberOfBlocks
3643from .sample import (
@@ -329,12 +336,35 @@ def get_io_sample_block_metas(
329336 )
330337
331338
339+ def get_tensor (
340+ src : Union [Tensor , xr .DataArray , NDArray [Any ], Path ],
341+ ipt : Union [v0_4 .InputTensorDescr , v0_5 .InputTensorDescr ],
342+ ):
343+ """helper to cast/load various tensor sources"""
344+
345+ if isinstance (src , Tensor ):
346+ return src
347+
348+ if isinstance (src , xr .DataArray ):
349+ return Tensor .from_xarray (src )
350+
351+ if isinstance (src , np .ndarray ):
352+ return Tensor .from_numpy (src , dims = get_axes_infos (ipt ))
353+
354+ if isinstance (src , Path ):
355+ return load_tensor (src , axes = get_axes_infos (ipt ))
356+
357+ assert_never (src )
358+
359+
332360def create_sample_for_model (
333361 model : AnyModelDescr ,
334362 * ,
335363 stat : Optional [Stat ] = None ,
336364 sample_id : SampleId = None ,
337- inputs : Optional [PerMember [NDArray [Any ]]] = None , # TODO: make non-optional
365+ inputs : Optional [
366+ PerMember [Union [Tensor , xr .DataArray , NDArray [Any ], Path ]]
367+ ] = None , # TODO: make non-optional
338368 ** kwargs : NDArray [Any ], # TODO: deprecate in favor of `inputs`
339369) -> Sample :
340370 """Create a sample from a single set of input(s) for a specific bioimage.io model
@@ -359,10 +389,54 @@ def create_sample_for_model(
359389
360390 return Sample (
361391 members = {
362- m : Tensor . from_numpy (inputs [m ], dims = get_axes_infos ( ipt ) )
392+ m : get_tensor (inputs [m ], ipt )
363393 for m , ipt in model_inputs .items ()
364394 if m in inputs
365395 },
366396 stat = {} if stat is None else stat ,
367397 id = sample_id ,
368398 )
399+
400+
401+ def load_sample_for_model (
402+ * ,
403+ model : AnyModelDescr ,
404+ paths : PerMember [Path ],
405+ axes : Optional [PerMember [Sequence [AxisLike ]]] = None ,
406+ stat : Optional [Stat ] = None ,
407+ sample_id : Optional [SampleId ] = None ,
408+ ):
409+ """load a single sample from `paths` that can be processed by `model`"""
410+
411+ if axes is None :
412+ axes = {}
413+
414+ # make sure members are keyed by MemberId, not string
415+ paths = {MemberId (k ): v for k , v in paths .items ()}
416+ axes = {MemberId (k ): v for k , v in axes .items ()}
417+
418+ model_inputs = {get_member_id (d ): d for d in model .inputs }
419+
420+ if unknown := {k for k in paths if k not in model_inputs }:
421+ raise ValueError (f"Got unexpected paths for { unknown } " )
422+
423+ if unknown := {k for k in axes if k not in model_inputs }:
424+ raise ValueError (f"Got unexpected axes hints for: { unknown } " )
425+
426+ members : Dict [MemberId , Tensor ] = {}
427+ for m , p in paths .items ():
428+ if m not in axes :
429+ axes [m ] = get_axes_infos (model_inputs [m ])
430+ logger .warning (
431+ "loading paths with {}'s default input axes {} for input '{}'" ,
432+ axes [m ],
433+ model .id or model .name ,
434+ m ,
435+ )
436+ members [m ] = load_tensor (p , axes [m ])
437+
438+ return Sample (
439+ members = members ,
440+ stat = {} if stat is None else stat ,
441+ id = sample_id or tuple (sorted (paths .values ())),
442+ )
0 commit comments