File tree Expand file tree Collapse file tree 1 file changed +5
-8
lines changed
bioimageio/core/weight_converter/torch Expand file tree Collapse file tree 1 file changed +5
-8
lines changed Original file line number Diff line number Diff line change 1- # type: ignore # TODO: type
21from pathlib import Path
32from typing import List , Sequence , Union
43
54import numpy as np
5+ import torch
66from numpy .testing import assert_array_almost_equal
77from typing_extensions import Any , assert_never
88
99from bioimageio .spec .model import v0_4 , v0_5
1010from bioimageio .spec .model .v0_5 import Version
1111
12- from ._utils import load_torch_model
13-
14- try :
15- import torch
16- except ImportError :
17- torch = None
12+ from ...model_adapters ._pytorch_model_adapter import PytorchModelAdapter
1813
1914
2015# FIXME: remove Any
@@ -119,7 +114,9 @@ def convert_weights_to_torchscript(
119114 with torch .no_grad ():
120115 input_data = [torch .from_numpy (inp .astype ("float32" )) for inp in input_data ]
121116
122- model = load_torch_model (state_dict_weights_descr )
117+ model = PytorchModelAdapter .get_network (
118+ state_dict_weights_descr , load_state = True
119+ )
123120
124121 # FIXME: remove Any
125122 if use_tracing :
You can’t perform that action at this time.
0 commit comments