Skip to content

Commit 1eb7b86

Browse files
Merge pull request #177 from bioimage-io/torchscript
Rename pytorch_script to torchscript
2 parents 06ec7d5 + d3dd1a3 commit 1eb7b86

File tree

12 files changed

+21
-24
lines changed

12 files changed

+21
-24
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ jobs:
8080
uses: conda-incubator/setup-miniconda@v2
8181
with:
8282
auto-update-conda: true
83+
# we need mamba to resolve environment-tf
8384
mamba-version: "*"
8485
channel-priority: strict
8586
activate-environment: bio-core-tf

bioimageio/core/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,10 @@ def convert_torch_weights_to_torchscript(
249249
output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."),
250250
use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."),
251251
) -> int:
252-
ret_code = torch_converter.convert_weights_to_pytorch_script(model_rdf, output_path, use_tracing)
252+
ret_code = torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing)
253253
sys.exit(ret_code)
254254

255-
convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_pytorch_script.__doc__
255+
convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__
256256

257257

258258
if keras_converter is not None:

bioimageio/core/build_spec/build_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,8 @@ def _get_weights(
113113
source=weight_source, sha256=weight_hash, opset_version=opset_version, **attachments
114114
)
115115

116-
elif weight_type == "pytorch_script":
117-
weights = model_spec.raw_nodes.PytorchScriptWeightsEntry(
118-
source=weight_source, sha256=weight_hash, **attachments
119-
)
116+
elif weight_type == "torchscript":
117+
weights = model_spec.raw_nodes.TorchscriptWeightsEntry(source=weight_source, sha256=weight_hash, **attachments)
120118

121119
elif weight_type == "keras_hdf5":
122120
if tensorflow_version is None:

bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
#: Known weight formats in order of priority
99
#: First match wins
10-
_WEIGHT_FORMATS = ["pytorch_state_dict", "tensorflow_saved_model_bundle", "pytorch_script", "onnx", "keras_hdf5"]
10+
_WEIGHT_FORMATS = ["pytorch_state_dict", "tensorflow_saved_model_bundle", "torchscript", "onnx", "keras_hdf5"]
1111

1212

1313
class ModelAdapter(abc.ABC):
@@ -114,7 +114,7 @@ def create_model_adapter(
114114
return adapter_cls(bioimageio_model=bioimageio_model, devices=devices)
115115

116116
raise RuntimeError(
117-
f"weight format {weight_format} not among weight formats listed in model: {list(bioimageio_model.weights.keys())}"
117+
f"weight format {weight_format} not among formats listed in model: {list(bioimageio_model.weights.keys())}"
118118
)
119119

120120

@@ -139,7 +139,7 @@ def _get_model_adapter(weight_format: str) -> Type[ModelAdapter]:
139139

140140
return ONNXModelAdapter
141141

142-
elif weight_format == "pytorch_script":
142+
elif weight_format == "torchscript":
143143
from ._torchscript_model_adapter import TorchscriptModelAdapter
144144

145145
return TorchscriptModelAdapter

bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
class TorchscriptModelAdapter(ModelAdapter):
1313
def _load(self, *, devices: Optional[List[str]] = None):
14-
weight_path = str(self.bioimageio_model.weights["pytorch_script"].source.resolve())
14+
weight_path = str(self.bioimageio_model.weights["torchscript"].source.resolve())
1515
if devices is None:
1616
self.devices = ["cuda" if torch.cuda.is_available() else "cpu"]
1717
else:

bioimageio/core/resource_io/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class PytorchStateDictWeightsEntry(Node, model_raw_nodes.PytorchStateDictWeights
161161

162162

163163
@dataclass
164-
class PytorchScriptWeightsEntry(Node, model_raw_nodes.PytorchScriptWeightsEntry):
164+
class TorchscriptWeightsEntry(Node, model_raw_nodes.TorchscriptWeightsEntry):
165165
source: Path = missing
166166

167167

@@ -184,10 +184,10 @@ class Attachments(Node, model_raw_nodes.Attachments):
184184
WeightsEntry = Union[
185185
KerasHdf5WeightsEntry,
186186
OnnxWeightsEntry,
187-
PytorchScriptWeightsEntry,
188187
PytorchStateDictWeightsEntry,
189188
TensorflowJsWeightsEntry,
190189
TensorflowSavedModelBundleWeightsEntry,
190+
TorchscriptWeightsEntry,
191191
]
192192

193193

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .onnx import convert_weights_to_onnx
2-
from .torchscript import convert_weights_to_pytorch_script
2+
from .torchscript import convert_weights_to_torchscript

bioimageio/core/weight_converter/torch/torchscript.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _check(input_):
7272
return ret
7373

7474

75-
def convert_weights_to_pytorch_script(
75+
def convert_weights_to_torchscript(
7676
model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path], use_tracing: bool = True
7777
):
7878
"""Convert model weights from format 'pytorch_state_dict' to 'torchscript'.

tests/build_spec/test_build_spec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ def _test_build_spec(
2828
model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs
2929
architecture = str(weight_spec.architecture)
3030
weight_type_ = None # the weight type can be auto-detected
31-
elif weight_type == "pytorch_script":
31+
elif weight_type == "torchscript":
3232
architecture = None
3333
model_kwargs = None
34-
weight_type_ = "pytorch_script" # the weight type CANNOT be auto-detcted
34+
weight_type_ = "torchscript" # the weight type CANNOT be auto-detected
3535
else:
3636
architecture = None
3737
model_kwargs = None
@@ -109,7 +109,7 @@ def test_build_spec_implicit_output_shape(unet2d_nuclei_broad_model, tmp_path):
109109

110110

111111
def test_build_spec_torchscript(any_torchscript_model, tmp_path):
112-
_test_build_spec(any_torchscript_model, tmp_path / "model.zip", "pytorch_script")
112+
_test_build_spec(any_torchscript_model, tmp_path / "model.zip", "torchscript")
113113

114114

115115
def test_build_spec_onnx(any_onnx_model, tmp_path):
@@ -133,7 +133,7 @@ def test_build_spec_tfjs(any_tensorflow_js_model, tmp_path):
133133

134134

135135
def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path):
136-
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "pytorch_script", add_deepimagej_config=True)
136+
_test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", add_deepimagej_config=True)
137137

138138

139139
def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path):

tests/prediction_pipeline/test_device_management.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from functools import wraps
2-
31
import numpy as np
42
import pytest
53
import xarray as xr
@@ -60,7 +58,7 @@ def test_device_management_torch(any_torch_model):
6058

6159
@skip_on(TooFewDevicesException, reason="Too few devices")
6260
def test_device_management_torchscript(any_torchscript_model):
63-
_test_device_management(any_torchscript_model, "pytorch_script")
61+
_test_device_management(any_torchscript_model, "torchscript")
6462

6563

6664
@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery")

0 commit comments

Comments (0)