Skip to content

Commit 683a798

Browse files
committed
improve typing
1 parent 63f37a8 commit 683a798

28 files changed

+151
-72
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ repos:
44
hooks:
55
- id: black
66
- repo: https://github.com/pre-commit/mirrors-mypy
7-
rev: v0.961
7+
rev: v0.991
88
hooks:
99
- id: mypy
1010
additional_dependencies: [types-requests]
11+
args: [--install-types, --non-interactive, --explicit-package-bases, --check-untyped-defs]
1112
- repo: local
1213
hooks:
1314
- id: generate rdf docs

bioimageio/spec/__main__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import sys
22
from pathlib import Path
33
from pprint import pprint
4+
from typing import Any, Callable, Dict, Optional, Union
45

56
import typer
67

78
from bioimageio.spec import __version__, collection, commands, model, rdf
9+
from spec.shared.raw_nodes import URI
810

11+
enrich_partial_rdf_with_imjoy_plugin: Optional[Callable[[Dict[str, Any], Union[URI, Path]], Dict[str, Any]]]
912
try:
1013
from bioimageio.spec.partner.utils import enrich_partial_rdf_with_imjoy_plugin
1114
except ImportError:
@@ -81,6 +84,7 @@ def validate_partner_collection(
8184
),
8285
verbose: bool = typer.Option(False, help="show traceback of unexpected (no ValidationError) exceptions"),
8386
):
87+
assert enrich_partial_rdf_with_imjoy_plugin is not None
8488
summary = commands.validate(
8589
rdf_source, update_format, update_format_inner, enrich_partial_rdf=enrich_partial_rdf_with_imjoy_plugin
8690
)
@@ -101,10 +105,12 @@ def validate_partner_collection(
101105

102106
sys.exit(ret_code)
103107

108+
cmd_doc = commands.validate.__doc__
109+
assert cmd_doc is not None
104110
validate_partner_collection.__doc__ = (
105111
"A special version of the bioimageio validate command that enriches the RDFs defined in collections by parsing any "
106112
"associated imjoy plugins. This is implemented using the 'enrich_partial_rdf' of the regular validate command:\n"
107-
+ commands.validate.__doc__
113+
+ cmd_doc
108114
)
109115

110116

bioimageio/spec/collection/v0_2/schema.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
from marshmallow import missing, validates
1+
from types import ModuleType
2+
from typing import ClassVar, List, Union
3+
4+
from marshmallow import INCLUDE, missing, validates
25

36
from bioimageio.spec.rdf.v0_2.schema import RDF
47
from bioimageio.spec.shared import fields
58
from bioimageio.spec.shared.schema import SharedBioImageIOSchema, WithUnknown
69
from . import raw_nodes
710

811
try:
9-
from typing import List, Union, get_args
12+
from typing import get_args
1013
except ImportError:
1114
from typing_extensions import get_args # type: ignore
1215

1316

1417
class _BioImageIOSchema(SharedBioImageIOSchema):
15-
raw_nodes = raw_nodes
18+
raw_nodes: ClassVar[ModuleType] = raw_nodes
1619

1720

1821
class CollectionEntry(_BioImageIOSchema, WithUnknown):
@@ -21,6 +24,9 @@ class CollectionEntry(_BioImageIOSchema, WithUnknown):
2124

2225

2326
class Collection(_BioImageIOSchema, WithUnknown, RDF):
27+
class Meta:
28+
unknown = INCLUDE
29+
2430
bioimageio_description = f"""# BioImage.IO Collection Resource Description File Specification {get_args(raw_nodes.FormatVersion)[-1]}
2531
This specification defines the fields used in a BioImage.IO-compliant resource description file (`RDF`) for describing collections of other resources.
2632
These fields are typically stored in a YAML file which we call Collection Resource Description File or `collection RDF`.

bioimageio/spec/commands.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def update_format(
3232
def validate(
3333
rdf_source: Union[RawResourceDescription, dict, os.PathLike, IO, str, bytes],
3434
update_format: bool = False,
35-
update_format_inner: bool = None,
35+
update_format_inner: Optional[bool] = None,
3636
verbose: bool = "deprecated", # type: ignore
3737
enrich_partial_rdf: Callable[[dict, Union[URI, Path]], dict] = default_enrich_partial_rdf,
3838
) -> ValidationSummary:
@@ -148,7 +148,7 @@ def validate(
148148
def update_rdf(
149149
source: Union[RawResourceDescription, dict, os.PathLike, IO, str, bytes],
150150
update: Union[RawResourceDescription, dict, os.PathLike, IO, str, bytes],
151-
output: Union[dict, os.PathLike] = None,
151+
output: Union[None, dict, os.PathLike] = None,
152152
validate_output: bool = True,
153153
) -> Union[dict, Path, RawResourceDescription]:
154154
"""
@@ -215,6 +215,7 @@ def update_rdf(
215215
f"Failed to convert paths in updated rdf to relative paths with root {output}; error: {e}"
216216
)
217217
warnings.warn(f"updated rdf at {output} contains absolute paths and is thus invalid!")
218+
assert isinstance(out_data, RawResourceDescription)
218219
out_data = serialize_raw_resource_description_to_dict(out_data, convert_absolute_paths=False)
219220

220221
yaml.dump(out_data, output)

bioimageio/spec/dataset/v0_2/schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from types import ModuleType
2+
from typing import ClassVar
3+
14
from bioimageio.spec.rdf.v0_2.schema import RDF
25
from bioimageio.spec.shared.schema import SharedBioImageIOSchema
36
from . import raw_nodes
@@ -9,7 +12,7 @@
912

1013

1114
class _BioImageIOSchema(SharedBioImageIOSchema):
12-
raw_nodes = raw_nodes
15+
raw_nodes: ClassVar[ModuleType] = raw_nodes
1316

1417

1518
class Dataset(_BioImageIOSchema, RDF):

bioimageio/spec/io_.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def extract_resource_package(
8181
source: Union[os.PathLike, IO, str, bytes, raw_nodes.URI]
8282
) -> Tuple[dict, str, pathlib.Path]:
8383
"""extract a zip source to BIOIMAGEIO_CACHE_PATH"""
84-
source, source_name, root = resolve_rdf_source(source)
84+
src, source_name, root = resolve_rdf_source(source)
8585
if isinstance(root, bytes):
8686
raise NotImplementedError("package source was bytes")
8787

@@ -123,7 +123,7 @@ def extract_resource_package(
123123
warnings.warn(f"Could not remove download {download} due to {e}")
124124

125125
assert isinstance(package_path, pathlib.Path)
126-
return source, source_name, package_path
126+
return src, source_name, package_path
127127

128128

129129
def load_raw_resource_description(
@@ -292,7 +292,7 @@ def get_resource_package_content_wo_rdf(
292292
raw_rd: Union[GenericRawRD, raw_nodes.URI, str, pathlib.Path],
293293
*,
294294
weights_priority_order: Optional[Sequence[str]] = None, # model only
295-
) -> Tuple[GenericRawNode, Dict[str, Union[pathlib.PurePath, raw_nodes.URI]]]:
295+
) -> Tuple[raw_nodes.ResourceDescription, Dict[str, Union[pathlib.PurePath, raw_nodes.URI]]]:
296296
"""
297297
Args:
298298
raw_rd: raw resource description
@@ -305,25 +305,27 @@ def get_resource_package_content_wo_rdf(
305305
keyed by file names.
306306
Important note: the serialized rdf.yaml is not included.
307307
"""
308-
if not isinstance(raw_rd, raw_nodes.ResourceDescription):
309-
raw_rd = load_raw_resource_description(raw_rd)
308+
if isinstance(raw_rd, raw_nodes.ResourceDescription):
309+
r_rd = raw_rd
310+
else:
311+
r_rd = load_raw_resource_description(raw_rd)
310312

311-
sub_spec = _get_spec_submodule(raw_rd.type, raw_rd.format_version)
312-
if raw_rd.type == "model":
313+
sub_spec = _get_spec_submodule(r_rd.type, r_rd.format_version)
314+
if r_rd.type == "model":
313315
filter_kwargs = dict(weights_priority_order=weights_priority_order)
314316
else:
315317
filter_kwargs = {}
316318

317-
raw_rd = sub_spec.utils.filter_resource_description(raw_rd, **filter_kwargs)
319+
r_rd = sub_spec.utils.filter_resource_description(r_rd, **filter_kwargs)
318320

319-
content: Dict[str, Union[pathlib.PurePath, raw_nodes.URI, str]] = {}
320-
raw_rd = RawNodePackageTransformer(content, raw_rd.root_path).transform(raw_rd)
321+
content: Dict[str, Union[pathlib.PurePath, raw_nodes.URI]] = {}
322+
r_rd = RawNodePackageTransformer(content, r_rd.root_path).transform(r_rd)
321323
assert "rdf.yaml" not in content
322-
return raw_rd, content
324+
return r_rd, content
323325

324326

325327
def get_resource_package_content(
326-
raw_rd: Union[GenericRawNode, raw_nodes.URI, str, pathlib.Path],
328+
raw_rd: Union[raw_nodes.ResourceDescription, raw_nodes.URI, str, pathlib.Path],
327329
*,
328330
weights_priority_order: Optional[Sequence[str]] = None, # model only
329331
) -> Dict[str, Union[str, pathlib.PurePath, raw_nodes.URI]]:
@@ -343,7 +345,5 @@ def get_resource_package_content(
343345
"without yaml"
344346
)
345347

346-
content: Dict[str, Union[str, pathlib.PurePath, raw_nodes.URI]]
347-
raw_rd, content = get_resource_package_content_wo_rdf(raw_rd, weights_priority_order=weights_priority_order)
348-
content["rdf.yaml"] = serialize_raw_resource_description(raw_rd)
349-
return content
348+
r_rd, content = get_resource_package_content_wo_rdf(raw_rd, weights_priority_order=weights_priority_order)
349+
return {**content, **{"rdf.yaml": serialize_raw_resource_description(r_rd)}}

bioimageio/spec/model/v0_3/converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import pathlib
3-
from typing import Any, Dict
3+
from typing import Any, Dict, Union
44

55
from marshmallow import Schema
66

@@ -62,7 +62,7 @@ class DocSchema(Schema):
6262
assert isinstance(orig_doc, str)
6363
if orig_doc.startswith("http"):
6464
if orig_doc.endswith(".md"):
65-
doc = raw_nodes.URI(orig_doc)
65+
doc: Union[raw_nodes.URI, str, pathlib.Path] = raw_nodes.URI(orig_doc)
6666
else:
6767
doc = f"Find documentation at {orig_doc}"
6868
else:

bioimageio/spec/model/v0_3/schema.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing
22
import warnings
33
from copy import deepcopy
4+
from types import ModuleType
45

56
from marshmallow import RAISE, ValidationError, missing as missing_, post_load, pre_dump, pre_load, validates_schema
67

@@ -21,7 +22,7 @@
2122

2223

2324
class _BioImageIOSchema(SharedBioImageIOSchema):
24-
raw_nodes = raw_nodes
25+
raw_nodes: typing.ClassVar[ModuleType] = raw_nodes
2526

2627

2728
class RunMode(_BioImageIOSchema):
@@ -474,7 +475,7 @@ class ModelParent(_BioImageIOSchema):
474475

475476

476477
class Model(rdf.schema.RDF):
477-
raw_nodes = raw_nodes
478+
raw_nodes: typing.ClassVar[ModuleType] = raw_nodes
478479

479480
class Meta:
480481
unknown = RAISE
@@ -758,7 +759,7 @@ def validate_reference_tensor_names(self, data, **kwargs):
758759
raise ValidationError(f"{ref_tensor} not found in inputs")
759760

760761
@validates_schema
761-
def weights_entries_match_weights_formats(self, data, **kwargs):
762+
def weights_entries_match_weights_formats(self, data, **kwargs) -> None:
762763
weights: typing.Dict[str, _WeightsEntryBase] = data["weights"]
763764
for weights_format, weights_entry in weights.items():
764765
if weights_format in ["keras_hdf5", "tensorflow_js", "tensorflow_saved_model_bundle"]:

bioimageio/spec/model/v0_4/raw_nodes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040

4141
# reassign to use imported classes
4242
ImplicitOutputShape = ImplicitOutputShape
43+
InputTensor = InputTensor
4344
Maintainer = Maintainer
45+
OutputTensor = OutputTensor
4446
ParametrizedInputShape = ParametrizedInputShape
4547
Postprocessing = Postprocessing
4648
PostprocessingName = PostprocessingName

bioimageio/spec/model/v0_4/schema.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing
22
from copy import deepcopy
3+
from types import ModuleType
34

45
import numpy
56
from marshmallow import RAISE, ValidationError, missing, pre_load, validates, validates_schema
@@ -23,7 +24,7 @@
2324

2425

2526
class _BioImageIOSchema(SharedBioImageIOSchema):
26-
raw_nodes = raw_nodes
27+
raw_nodes: typing.ClassVar[ModuleType] = raw_nodes
2728

2829

2930
class _TensorBase(_BioImageIOSchema):
@@ -184,7 +185,7 @@ def matching_halo_length(self, data, **kwargs):
184185

185186

186187
class _WeightsEntryBase(_WeightsEntryBase03):
187-
raw_nodes = raw_nodes
188+
raw_nodes: typing.ClassVar[ModuleType] = raw_nodes
188189
dependencies = fields.Dependencies(
189190
bioimageio_description="Dependency manager and dependency file, specified as `<dependency manager>:<relative "
190191
"path to file>`. For example: 'conda:./environment.yaml', 'maven:./pom.xml', or 'pip:./requirements.txt'. "
@@ -249,7 +250,7 @@ class Dataset(_Dataset):
249250

250251

251252
class TorchscriptWeightsEntry(_WeightsEntryBase):
252-
raw_nodes = raw_nodes
253+
raw_nodes: typing.ClassVar[ModuleType] = raw_nodes
253254

254255
bioimageio_description = "Torchscript weights format"
255256
weights_format = fields.String(validate=field_validators.Equal("torchscript"), required=True, load_only=True)
@@ -294,7 +295,7 @@ def id_xor_uri(self, data, **kwargs):
294295

295296

296297
class Model(rdf.schema.RDF):
297-
raw_nodes = raw_nodes
298+
raw_nodes: typing.ClassVar = raw_nodes
298299

299300
class Meta:
300301
unknown = RAISE
@@ -433,7 +434,7 @@ def no_duplicate_output_tensor_names(self, value: typing.List[raw_nodes.OutputTe
433434
raise ValidationError("Duplicate output tensor names are not allowed.")
434435

435436
@validates_schema
436-
def inputs_and_outputs(self, data, **kwargs):
437+
def inputs_and_outputs(self, data, **kwargs) -> None:
437438
ipts: typing.List[raw_nodes.InputTensor] = data.get("inputs")
438439
outs: typing.List[raw_nodes.OutputTensor] = data.get("outputs")
439440
if any(
@@ -603,7 +604,7 @@ def add_weights_format_key_to_weights_entry_value(self, data: dict, many=False,
603604
return data
604605

605606
@validates_schema
606-
def validate_reference_tensor_names(self, data, **kwargs):
607+
def validate_reference_tensor_names(self, data, **kwargs) -> None:
607608
def get_tnames(tname: str):
608609
return [t.get("name") if isinstance(t, dict) else t.name for t in data.get(tname, [])]
609610

@@ -649,7 +650,7 @@ def get_tnames(tname: str):
649650
raise ValidationError(f"invalid self reference for preprocessing of tensor {t.name}")
650651

651652
@validates_schema
652-
def weights_entries_match_weights_formats(self, data, **kwargs):
653+
def weights_entries_match_weights_formats(self, data, **kwargs) -> None:
653654
weights: typing.Dict[str, WeightsEntry] = data.get("weights", {})
654655
for weights_format, weights_entry in weights.items():
655656
if not isinstance(weights_entry, get_args(raw_nodes.WeightsEntry)):

0 commit comments

Comments
 (0)