|
8 | 8 | import shutil |
9 | 9 | import subprocess |
10 | 10 | import sys |
| 11 | +from abc import ABC |
11 | 12 | from argparse import RawTextHelpFormatter |
12 | 13 | from difflib import SequenceMatcher |
13 | 14 | from functools import cached_property |
|
42 | 43 | SettingsConfigDict, |
43 | 44 | YamlConfigSettingsSource, |
44 | 45 | ) |
45 | | -from ruyaml import YAML |
46 | 46 | from tqdm import tqdm |
47 | 47 | from typing_extensions import assert_never |
48 | 48 |
|
|
53 | 53 | load_description, |
54 | 54 | save_bioimageio_yaml_only, |
55 | 55 | settings, |
| 56 | + update_format, |
| 57 | + update_hashes, |
56 | 58 | ) |
57 | 59 | from bioimageio.spec._internal.io_basics import ZipPath |
| 60 | +from bioimageio.spec._internal.io_utils import yaml |
58 | 61 | from bioimageio.spec._internal.types import NotEmpty |
59 | 62 | from bioimageio.spec.dataset import DatasetDescr |
60 | 63 | from bioimageio.spec.model import ModelDescr, v0_4, v0_5 |
|
66 | 69 | WeightFormatArgAny, |
67 | 70 | package, |
68 | 71 | test, |
69 | | - update_format, |
70 | 72 | ) |
71 | 73 | from .common import MemberId, SampleId, SupportedWeightsFormat |
72 | 74 | from .digest_spec import get_member_ids, load_sample_for_model |
|
84 | 86 | from .utils import VERSION |
85 | 87 | from .weight_converters._add_weights import add_weights |
86 | 88 |
|
87 | | -yaml = YAML(typ="safe") |
88 | | - |
89 | 89 |
|
90 | 90 | class CmdBase(BaseModel, use_attribute_docstrings=True, cli_implicit_flags=True): |
91 | 91 | pass |
@@ -254,31 +254,68 @@ def _get_stat( |
254 | 254 | return stat |
255 | 255 |
|
256 | 256 |
|
257 | | -class UpdateFormatCmd(CmdBase, WithSource): |
258 | | - """Update the metadata format""" |
| 257 | +class UpdateCmdBase(CmdBase, WithSource, ABC): |
| 258 | + output: Union[Literal["render", "stdout"], Path] = "render" |
| 259 | + """Output updated bioimageio.yaml to the terminal or write to a file.""" |
259 | 260 |
|
260 | | - output: Optional[Path] = None |
261 | | - """Save updated bioimageio.yaml to this file. |
| 261 | + exclude_unset: bool = Field(True, alias="exclude-unset") |
| 262 | + """Exclude fields that have not explicitly be set.""" |
262 | 263 |
|
263 | | - Updated bioimageio.yaml is rendered to the terminal if the output is None. |
264 | | - """ |
265 | | - |
266 | | - exclude_defaults: bool = Field(True, alias="exclude-defaults") |
| 264 | + exclude_defaults: bool = Field(False, alias="exclude-defaults") |
267 | 265 | """Exclude fields that have the default value (even if set explicitly).""" |
268 | 266 |
|
| 267 | + @cached_property |
| 268 | + def updated(self) -> Union[ResourceDescr, InvalidDescr]: |
| 269 | + raise NotImplementedError |
| 270 | + |
269 | 271 | def run(self): |
270 | | - updated = update_format( |
271 | | - self.source, output=self.output, exclude_defaults=self.exclude_defaults |
272 | | - ) |
273 | | - updated_stream = StringIO() |
| 272 | + if self.output == "render": |
| 273 | + out = StringIO() |
| 274 | + elif self.output == "stdout": |
| 275 | + out = sys.stdout |
| 276 | + else: |
| 277 | + out = self.output |
| 278 | + |
274 | 279 | save_bioimageio_yaml_only( |
275 | | - updated, updated_stream, exclude_defaults=self.exclude_defaults |
| 280 | + self.updated, |
| 281 | + out, |
| 282 | + exclude_unset=self.exclude_unset, |
| 283 | + exclude_defaults=self.exclude_defaults, |
| 284 | + ) |
| 285 | + |
| 286 | + if self.output == "render": |
| 287 | + assert isinstance(out, StringIO) |
| 288 | + updated_md = f"```yaml\n{out.getvalue()}\n```" |
| 289 | + |
| 290 | + rich_markdown = rich.markdown.Markdown(updated_md) |
| 291 | + console = rich.console.Console() |
| 292 | + console.print(rich_markdown) |
| 293 | + |
| 294 | + |
| 295 | +class UpdateFormatCmd(UpdateCmdBase): |
| 296 | + """Update the metadata format to the latest format version.""" |
| 297 | + |
| 298 | + perform_io_checks: bool = Field( |
| 299 | + settings.perform_io_checks, alias="perform-io-checks" |
| 300 | + ) |
| 301 | + """Wether or not to attempt validation that may require file download. |
| 302 | + If `True` file hash values are added if not present.""" |
| 303 | + |
| 304 | + @cached_property |
| 305 | + def updated(self): |
| 306 | + return update_format( |
| 307 | + self.source, |
| 308 | + exclude_defaults=self.exclude_defaults, |
| 309 | + perform_io_checks=self.perform_io_checks, |
276 | 310 | ) |
277 | | - updated_md = f"```yaml\n{updated_stream.getvalue()}\n```" |
278 | 311 |
|
279 | | - rich_markdown = rich.markdown.Markdown(updated_md) |
280 | | - console = rich.console.Console() |
281 | | - console.print(rich_markdown) |
| 312 | + |
| 313 | +class UpdateHashesCmd(UpdateCmdBase): |
| 314 | + """Create a bioimageio.yaml description with updated file hashes.""" |
| 315 | + |
| 316 | + @cached_property |
| 317 | + def updated(self): |
| 318 | + return update_hashes(self.source) |
282 | 319 |
|
283 | 320 |
|
284 | 321 | class PredictCmd(CmdBase, WithSource): |
@@ -690,6 +727,9 @@ class Bioimageio( |
690 | 727 | update_format: CliSubCommand[UpdateFormatCmd] = Field(alias="update-format") |
691 | 728 | """Update the metadata format""" |
692 | 729 |
|
| 730 | + update_hashes: CliSubCommand[UpdateHashesCmd] = Field(alias="update-hashes") |
| 731 | + """Create a bioimageio.yaml description with updated file hashes.""" |
| 732 | + |
693 | 733 | add_weights: CliSubCommand[ConvertWeightsCmd] = Field(alias="add-weights") |
694 | 734 | """Add additional weights to the model descriptions converted from available |
695 | 735 | formats to improve deployability.""" |
@@ -732,12 +772,13 @@ def run(self): |
732 | 772 | pformat({k: v for k, v in self.model_dump().items() if v is not None}), |
733 | 773 | ) |
734 | 774 | cmd = ( |
735 | | - self.validate_format |
736 | | - or self.test |
| 775 | + self.add_weights |
737 | 776 | or self.package |
738 | 777 | or self.predict |
| 778 | + or self.test |
739 | 779 | or self.update_format |
740 | | - or self.add_weights |
| 780 | + or self.update_hashes |
| 781 | + or self.validate_format |
741 | 782 | ) |
742 | 783 | assert cmd is not None |
743 | 784 | cmd.run() |
|
0 commit comments