Skip to content

Commit bb539d4

Browse files
committed
update torchscript adapter
1 parent 40dfe25 commit bb539d4

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

bioimageio/core/weight_converter/torch/_torchscript.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
# type: ignore # TODO: type
21
from pathlib import Path
32
from typing import List, Sequence, Union
43

54
import numpy as np
5+
import torch
66
from numpy.testing import assert_array_almost_equal
77
from typing_extensions import Any, assert_never
88

99
from bioimageio.spec.model import v0_4, v0_5
1010
from bioimageio.spec.model.v0_5 import Version
1111

12-
from ._utils import load_torch_model
13-
14-
try:
15-
import torch
16-
except ImportError:
17-
torch = None
12+
from ...model_adapters._pytorch_model_adapter import PytorchModelAdapter
1813

1914

2015
# FIXME: remove Any
@@ -119,7 +114,9 @@ def convert_weights_to_torchscript(
119114
with torch.no_grad():
120115
input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data]
121116

122-
model = load_torch_model(state_dict_weights_descr)
117+
model = PytorchModelAdapter.get_network(
118+
state_dict_weights_descr, load_state=True
119+
)
123120

124121
# FIXME: remove Any
125122
if use_tracing:

0 commit comments

Comments
 (0)