11# pyright: reportUnknownVariableType=false
2+ import shutil
3+ import tempfile
24import warnings
5+ from pathlib import Path
36from typing import Any , List , Optional , Sequence , Union
47
58import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
6- from numpy .typing import NDArray
7-
89from bioimageio .spec .model import v0_4 , v0_5
10+ from loguru import logger
11+ from numpy .typing import NDArray
912
1013from ..model_adapters import ModelAdapter
1114from ..utils ._type_guards import is_list , is_tuple
@@ -20,11 +23,63 @@ def __init__(
2023 ):
2124 super ().__init__ (model_description = model_description )
2225
23- if model_description .weights .onnx is None :
26+ onnx_descr = model_description .weights .onnx
27+ if onnx_descr is None :
2428 raise ValueError ("No ONNX weights specified for {model_description.name}" )
2529
26- reader = model_description .weights .onnx .get_reader ()
27- self ._session = rt .InferenceSession (reader .read ())
30+ providers = None
31+ if hasattr (rt , "get_available_providers" ):
32+ providers = rt .get_available_providers ()
33+
34+ if (
35+ isinstance (onnx_descr , v0_5 .OnnxWeightsDescr )
36+ and onnx_descr .external_data is not None
37+ ):
38+ src = onnx_descr .source .absolute ()
39+ src_data = onnx_descr .external_data .source .absolute ()
40+ if (
41+ isinstance (src , Path )
42+ and isinstance (src_data , Path )
43+ and src .parent == src_data .parent
44+ ):
45+ logger .debug (
46+ "Loading ONNX model with external data from {}" ,
47+ src .parent ,
48+ )
49+ self ._session = rt .InferenceSession (
50+ src ,
51+ providers = providers , # pyright: ignore[reportUnknownArgumentType]
52+ )
53+ else :
54+ src_reader = onnx_descr .get_reader ()
55+ src_data_reader = onnx_descr .external_data .get_reader ()
56+ with tempfile .TemporaryDirectory () as tmpdir :
57+ logger .debug (
58+ "Loading ONNX model with external data from {}" ,
59+ tmpdir ,
60+ )
61+ src = Path (tmpdir ) / src_reader .original_file_name
62+ src_data = Path (tmpdir ) / src_data_reader .original_file_name
63+ with src .open ("wb" ) as f :
64+ shutil .copyfileobj (src_reader , f )
65+ with src_data .open ("wb" ) as f :
66+ shutil .copyfileobj (src_data_reader , f )
67+
68+ self ._session = rt .InferenceSession (
69+ src ,
70+ providers = providers , # pyright: ignore[reportUnknownArgumentType]
71+ )
72+ else :
73+ # load single source file from bytes (without external data, so probably <2GB)
74+ logger .debug (
75+ "Loading ONNX model from bytes (read from {})" , onnx_descr .source
76+ )
77+ reader = onnx_descr .get_reader ()
78+ self ._session = rt .InferenceSession (
79+ reader .read (),
80+ providers = providers , # pyright: ignore[reportUnknownArgumentType]
81+ )
82+
2883 onnx_inputs = self ._session .get_inputs ()
2984 self ._input_names : List [str ] = [ipt .name for ipt in onnx_inputs ]
3085
0 commit comments