11# type: ignore # TODO: type
2- import warnings
2+ from __future__ import annotations
33from pathlib import Path
4- from typing import Any , List , Sequence , cast
4+ from typing import Any , List , Sequence , cast , Union
55
66import numpy as np
77from numpy .testing import assert_array_almost_equal
88
9- from bioimageio .spec import load_description
10- from bioimageio .spec .common import InvalidDescr
119from bioimageio .spec .model import v0_4 , v0_5
1210
1311from ...digest_spec import get_member_id , get_test_inputs
1917 torch = None
2018
2119
22- def add_onnx_weights (
23- model_spec : "str | Path | v0_4.ModelDescr | v0_5.ModelDescr" ,
20+ def convert_weights_to_onnx (
21+ model_spec : Union [ v0_4 .ModelDescr , v0_5 .ModelDescr ] ,
2422 * ,
2523 output_path : Path ,
2624 use_tracing : bool = True ,
2725 test_decimal : int = 4 ,
2826 verbose : bool = False ,
29- opset_version : " int | None" = None ,
30- ):
27+ opset_version : int = 15 ,
28+ ) -> v0_5 . OnnxWeightsDescr :
3129 """Convert model weights from format 'pytorch_state_dict' to 'onnx'.
3230
3331 Args:
@@ -36,16 +34,6 @@ def add_onnx_weights(
3634 use_tracing: whether to use tracing or scripting to export the onnx format
3735 test_decimal: precision for testing whether the results agree
3836 """
39- if isinstance (model_spec , (str , Path )):
40- loaded_spec = load_description (Path (model_spec ))
41- if isinstance (loaded_spec , InvalidDescr ):
42- raise ValueError (f"Bad resource description: { loaded_spec } " )
43- if not isinstance (loaded_spec , (v0_4 .ModelDescr , v0_5 .ModelDescr )):
44- raise TypeError (
45- f"Path { model_spec } is a { loaded_spec .__class__ .__name__ } , expected a v0_4.ModelDescr or v0_5.ModelDescr"
46- )
47- model_spec = loaded_spec
48-
4937 state_dict_weights_descr = model_spec .weights .pytorch_state_dict
5038 if state_dict_weights_descr is None :
5139 raise ValueError (
@@ -54,9 +42,10 @@ def add_onnx_weights(
5442
5543 assert torch is not None
5644 with torch .no_grad ():
57-
5845 sample = get_test_inputs (model_spec )
59- input_data = [sample [get_member_id (ipt )].data .data for ipt in model_spec .inputs ]
46+ input_data = [
47+ sample .members [get_member_id (ipt )].data .data for ipt in model_spec .inputs
48+ ]
6049 input_tensors = [torch .from_numpy (ipt ) for ipt in input_data ]
6150 model = load_torch_model (state_dict_weights_descr )
6251
@@ -81,9 +70,9 @@ def add_onnx_weights(
8170 try :
8271 import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs]
8372 except ImportError :
84- msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
85- warnings . warn ( msg )
86- return
73+ raise ImportError (
74+ "The onnx weights were exported, but onnx rt is not available and weights cannot be checked."
75+ )
8776
8877 # check the onnx model
8978 sess = rt .InferenceSession (str (output_path ))
@@ -101,8 +90,11 @@ def add_onnx_weights(
10190 try :
10291 for exp , out in zip (expected_outputs , outputs ):
10392 assert_array_almost_equal (exp , out , decimal = test_decimal )
104- return 0
10593 except AssertionError as e :
106- msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n { str (e )} "
107- warnings .warn (msg )
108- return 1
94+ raise ValueError (
95+ f"Results before and after weights conversion do not agree:\n { str (e )} "
96+ )
97+
98+ return v0_5 .OnnxWeightsDescr (
99+ source = output_path , parent = "pytorch_state_dict" , opset_version = opset_version
100+ )
0 commit comments