File tree Expand file tree Collapse file tree 3 files changed +21
-8
lines changed
bioimageio/core/weight_converter/torch Expand file tree Collapse file tree 3 files changed +21
-8
lines changed Original file line number Diff line number Diff line change 44from typing import Any , List , Sequence , cast
55
66import numpy as np
7- import torch
87from numpy .testing import assert_array_almost_equal
98
109from bioimageio .spec import load_description
1413from ...digest_spec import get_member_id , get_test_inputs
1514from ...weight_converter .torch ._utils import load_torch_model
1615
16+ try :
17+ import torch
18+ except ImportError :
19+ torch = None
20+
1721
1822def add_onnx_weights (
1923 model_spec : "str | Path | v0_4.ModelDescr | v0_5.ModelDescr" ,
@@ -48,6 +52,7 @@ def add_onnx_weights(
4852 "The provided model does not have weights in the pytorch state dict format"
4953 )
5054
55+ assert torch is not None
5156 with torch .no_grad ():
5257
5358 sample = get_test_inputs (model_spec )
Original file line number Diff line number Diff line change 33from typing import List , Sequence , Union
44
55import numpy as np
6- import torch
76from numpy .testing import assert_array_almost_equal
87from typing_extensions import Any , assert_never
98
1211
1312from ._utils import load_torch_model
1413
14+ try :
15+ import torch
16+ except ImportError :
17+ torch = None
18+
1519
1620# FIXME: remove Any
1721def _check_predictions (
1822 model : Any ,
1923 scripted_model : Any ,
2024 model_spec : "v0_4.ModelDescr | v0_5.ModelDescr" ,
21- input_data : Sequence [torch .Tensor ],
25+ input_data : Sequence [" torch.Tensor" ],
2226):
27+ assert torch is not None
28+
2329 def _check (input_ : Sequence [torch .Tensor ]) -> None :
2430 expected_tensors = model (* input_ )
2531 if isinstance (expected_tensors , torch .Tensor ):
Original file line number Diff line number Diff line change 11from typing import Union
22
3- import torch
4-
53from bioimageio .core .model_adapters ._pytorch_model_adapter import PytorchModelAdapter
64from bioimageio .spec .model import v0_4 , v0_5
75from bioimageio .spec .utils import download
86
7+ try :
8+ import torch
9+ except ImportError :
10+ torch = None
11+
912
1013# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
1114# and for each weight format
1215def load_torch_model ( # pyright: ignore[reportUnknownParameterType]
1316 node : Union [v0_4 .PytorchStateDictWeightsDescr , v0_5 .PytorchStateDictWeightsDescr ],
1417):
18+ assert torch is not None
1519 model = ( # pyright: ignore[reportUnknownVariableType]
1620 PytorchModelAdapter .get_network (node )
1721 )
18- state = torch .load ( # pyright: ignore[reportUnknownVariableType]
19- download (node .source ).path , map_location = "cpu"
20- )
22+ state = torch .load (download (node .source ).path , map_location = "cpu" )
2123 model .load_state_dict (state ) # FIXME: check incompatible keys?
2224 return model .eval () # pyright: ignore[reportUnknownVariableType]
You can’t perform that action at this time.
0 commit comments