Merged
33 commits
f50b278
bump pyright
FynnBe Mar 31, 2025
767395e
export ValidationSummary
FynnBe Apr 2, 2025
a58b388
bump spec
FynnBe Apr 2, 2025
0b5f619
Merge branch 'dev' into update_examples
FynnBe Apr 2, 2025
aa2a2c1
add pytorch < 1.13 compatibility
FynnBe Apr 2, 2025
4ab00b1
test affable-shark again
FynnBe Apr 2, 2025
096d440
pyright fix private import errors
FynnBe Apr 3, 2025
968317e
bump spec
FynnBe Apr 7, 2025
0f5b61d
run pytest-coverage only in CI as debugging with it is buggy in vscode
FynnBe Apr 11, 2025
009b6c0
WIP update model_usage.ipynb
FynnBe Apr 30, 2025
6f73473
improve io
FynnBe May 22, 2025
d50c0c4
bump pyright
FynnBe May 22, 2025
49081f9
bump spec
FynnBe May 22, 2025
39e3da7
Merge branch 'bump_spec' into dev + black
FynnBe May 22, 2025
85663d0
update _import_from_file_impl
FynnBe Jun 10, 2025
6e7219e
Merge branch 'main' into dev
FynnBe Jun 10, 2025
d79ca05
update DL backends
FynnBe Jun 10, 2025
e161f56
more get_reader updates
FynnBe Jun 10, 2025
aca9186
bump pyright
FynnBe Jun 13, 2025
7b4337c
update tests
FynnBe Jun 24, 2025
0877470
make dynamic import more robust
FynnBe Jun 24, 2025
c12b990
load h5 from memory
FynnBe Jun 24, 2025
c10c677
do not fail on unexpected return value of load_state_dict
FynnBe Jun 25, 2025
57ea3ff
warn about tolerated mismatched elements
FynnBe Jun 25, 2025
565fd3e
update tests
FynnBe Jun 25, 2025
a2c8406
expose format_version arg to CLI test cmd
FynnBe Jun 26, 2025
e48eade
update spec reference
FynnBe Jun 26, 2025
0dbe380
mark min python version=3.8
FynnBe Jun 27, 2025
f08a728
skip cache population if size limit hit
FynnBe Jun 27, 2025
2e11fa8
skip installation of careamics
FynnBe Jul 1, 2025
4b403bd
mark known test failure
FynnBe Jul 1, 2025
9fe2747
fix cli test
FynnBe Jul 1, 2025
2766274
fix respx package name
FynnBe Jul 1, 2025
8 changes: 1 addition & 7 deletions .github/workflows/build.yaml
@@ -69,12 +69,6 @@ jobs:
strategy:
matrix:
include:
- python-version: '3.8'
conda-env: py38
spec: conda
- python-version: '3.8'
conda-env: py38
spec: main
- python-version: '3.9'
conda-env: dev
spec: conda
@@ -174,7 +168,7 @@ jobs:
path: bioimageio_cache
key: ${{matrix.run-expensive-tests && needs.populate-cache.outputs.cache-key || needs.populate-cache.outputs.cache-key-light}}
- name: pytest
run: pytest --disable-pytest-warnings
run: pytest --cov bioimageio --cov-report xml --cov-append --capture no --disable-pytest-warnings
env:
BIOIMAGEIO_CACHE_PATH: bioimageio_cache
RUN_EXPENSIVE_TESTS: ${{ matrix.run-expensive-tests && 'true' || 'false' }}
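
For orientation, a minimal sketch of how a test suite can consume the RUN_EXPENSIVE_TESTS variable set above; the marker and test names here are illustrative assumptions, not taken from this repository.

# Illustrative sketch: gate long-running tests on the RUN_EXPENSIVE_TESTS env var.
import os

import pytest

expensive = pytest.mark.skipif(
    os.environ.get("RUN_EXPENSIVE_TESTS", "false").lower() != "true",
    reason="RUN_EXPENSIVE_TESTS is not 'true'",
)


@expensive
def test_full_model_inference():
    ...  # placeholder for a long-running test body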
4 changes: 4 additions & 0 deletions README.md
@@ -364,6 +364,10 @@ may be controlled with the `LOGURU_LEVEL` environment variable.

## Changelog

### 0.9.0 (coming soon)

- update to [bioimageio.spec 0.5.4.3](https://github.com/bioimage-io/spec-bioimage-io/blob/main/changelog.md#bioimageiospec-0543)

### 0.8.0

- breaking: removed `decimals` argument from bioimageio CLI and `bioimageio.core.commands.test()`
2 changes: 2 additions & 0 deletions bioimageio/core/__init__.py
@@ -3,6 +3,7 @@
"""

from bioimageio.spec import (
ValidationSummary,
build_description,
dump_description,
load_dataset_description,
@@ -112,4 +113,5 @@
"test_model",
"test_resource",
"validate_format",
"ValidationSummary",
]
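
With this re-export in place, the summary type can be imported from the top-level package. A minimal usage sketch (the model source path is a placeholder):

# Sketch: ValidationSummary is now importable directly from bioimageio.core.
from bioimageio.core import ValidationSummary, test_model

summary: ValidationSummary = test_model("path/to/model/rdf.yaml")  # placeholder source
print(summary.status)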
94 changes: 66 additions & 28 deletions bioimageio/core/_resource_tests.py
@@ -21,6 +21,7 @@
overload,
)

import xarray as xr
from loguru import logger
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args

@@ -55,6 +56,7 @@
InstalledPackage,
ValidationDetail,
ValidationSummary,
WarningEntry,
)

from ._prediction_pipeline import create_prediction_pipeline
@@ -510,7 +512,7 @@ def load_description_and_test(

enable_determinism(determinism, weight_formats=weight_formats)
for w in weight_formats:
_test_model_inference(rd, w, devices, **deprecated)
_test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
if stop_early and rd.validation_summary.status == "failed":
break

@@ -587,14 +589,16 @@ def _test_model_inference(
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
weight_format: SupportedWeightsFormat,
devices: Optional[Sequence[str]],
stop_early: bool,
**deprecated: Unpack[DeprecatedKwargs],
) -> None:
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
logger.debug("starting '{}'", test_name)
errors: List[ErrorEntry] = []
error_entries: List[ErrorEntry] = []
warning_entries: List[WarningEntry] = []

def add_error_entry(msg: str, with_traceback: bool = False):
errors.append(
error_entries.append(
ErrorEntry(
loc=("weights", weight_format),
msg=msg,
@@ -603,6 +607,15 @@ def add_error_entry(msg: str, with_traceback: bool = False):
)
)

def add_warning_entry(msg: str):
warning_entries.append(
WarningEntry(
loc=("weights", weight_format),
msg=msg,
type="bioimageio.core",
)
)

try:
inputs = get_test_inputs(model)
expected = get_test_outputs(model)
@@ -622,34 +635,58 @@ def add_error_entry(msg: str, with_traceback: bool = False):
actual = results.members.get(m)
if actual is None:
add_error_entry("Output tensors for test case may not be None")
break
if stop_early:
break
else:
continue

rtol, atol, mismatched_tol = _get_tolerance(
model, wf=weight_format, m=m, **deprecated
)
mismatched = (abs_diff := abs(actual - expected)) > atol + rtol * abs(
expected
)
rtol_value = rtol * abs(expected)
abs_diff = abs(actual - expected)
mismatched = abs_diff > atol + rtol_value
mismatched_elements = mismatched.sum().item()
if mismatched_elements / expected.size > mismatched_tol / 1e6:
r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
r_max = r_diff[r_max_idx].item()
r_actual = actual[r_max_idx].item()
r_expected = expected[r_max_idx].item()
a_max_idx = abs_diff.argmax()
a_max = abs_diff[a_max_idx].item()
a_actual = actual[a_max_idx].item()
a_expected = expected[a_max_idx].item()
add_error_entry(
f"Output '{m}' disagrees with {mismatched_elements} of"
+ f" {expected.size} expected values."
+ f"\n Max relative difference: {r_max:.2e}"
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
+ f" at {r_max_idx}"
+ f"\n Max absolute difference: {a_max:.2e}"
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
)
break
if not mismatched_elements:
continue

mismatched_ppm = mismatched_elements / expected.size * 1e6
abs_diff[~mismatched] = 0 # ignore non-mismatched elements

r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
r_max = r_diff[r_max_idx].item()
r_actual = actual[r_max_idx].item()
r_expected = expected[r_max_idx].item()

# Calculate the max absolute difference with the relative tolerance subtracted
abs_diff_wo_rtol: xr.DataArray = xr.ufuncs.maximum(
(abs_diff - rtol_value).data, 0
)
a_max_idx = {
AxisId(k): int(v) for k, v in abs_diff_wo_rtol.argmax().items()
}

a_max = abs_diff[a_max_idx].item()
a_actual = actual[a_max_idx].item()
a_expected = expected[a_max_idx].item()

msg = (
f"Output '{m}' disagrees with {mismatched_elements} of"
+ f" {expected.size} expected values"
+ f" ({mismatched_ppm:.1f} ppm)."
+ f"\n Max relative difference: {r_max:.2e}"
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
+ f" at {r_max_idx}"
+ f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
)
if mismatched_ppm > mismatched_tol:
add_error_entry(msg)
if stop_early:
break
else:
add_warning_entry(msg)

except Exception as e:
if get_validation_context().raise_errors:
raise e
@@ -660,9 +697,10 @@ def add_error_entry(msg: str, with_traceback: bool = False):
ValidationDetail(
name=test_name,
loc=("weights", weight_format),
status="failed" if errors else "passed",
status="failed" if error_entries else "passed",
recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
errors=errors,
errors=error_entries,
warnings=warning_entries,
)
)
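
As a reading aid, the tolerance logic added above reduces to roughly the following standalone sketch; plain numpy stands in for xarray, and the default thresholds are placeholders rather than the values resolved by _get_tolerance.

# Standalone sketch of the mismatch criterion and ppm threshold used above.
import numpy as np

def check_output(actual, expected, atol=1e-4, rtol=1e-4, mismatched_tol=100.0):
    """Return (passed, message); mismatched_tol is a parts-per-million threshold."""
    abs_diff = np.abs(actual - expected)
    mismatched = abs_diff > atol + rtol * np.abs(expected)
    mismatched_ppm = mismatched.sum() / expected.size * 1e6
    msg = (
        f"{int(mismatched.sum())} of {expected.size} values disagree"
        f" ({mismatched_ppm:.1f} ppm)"
    )
    # below the threshold the mismatch is only warned about; above it the test fails
    return mismatched_ppm <= mismatched_tol, msg

passed, msg = check_output(np.array([1.0, 1.00011]), np.array([1.0, 1.0]))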

3 changes: 0 additions & 3 deletions bioimageio/core/_settings.py
@@ -1,13 +1,10 @@
from typing import Literal

from dotenv import load_dotenv
from pydantic import Field
from typing_extensions import Annotated

from bioimageio.spec._internal._settings import Settings as SpecSettings

_ = load_dotenv()


class Settings(SpecSettings):
"""environment variables for bioimageio.spec and bioimageio.core"""
24 changes: 20 additions & 4 deletions bioimageio/core/backends/keras_backend.py
@@ -1,20 +1,27 @@
import os
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Optional, Sequence, Union

import h5py # pyright: ignore[reportMissingTypeStubs]
from keras.src.legacy.saving import ( # pyright: ignore[reportMissingTypeStubs]
legacy_h5_format,
)
from loguru import logger
from numpy.typing import NDArray

from bioimageio.spec._internal.io import download
from bioimageio.spec._internal.type_guards import is_list, is_tuple
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.model.v0_5 import Version

from .._settings import settings
from ..digest_spec import get_axes_infos
from ..utils._type_guards import is_list, is_tuple
from ._model_adapter import ModelAdapter

os.environ["KERAS_BACKEND"] = settings.keras_backend


# by default, we use the keras integrated with tensorflow
# TODO: check if we should prefer keras
try:
@@ -67,9 +74,18 @@ def __init__(
devices,
)

weight_path = download(model_description.weights.keras_hdf5.source).path
weight_reader = model_description.weights.keras_hdf5.get_reader()
if weight_reader.suffix in (".h5", ".hdf5"):
h5_file = h5py.File(weight_reader, mode="r")
self._network = legacy_h5_format.load_model_from_hdf5(h5_file)
else:
with TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir) / weight_reader.original_file_name
with temp_path.open("wb") as f:
shutil.copyfileobj(weight_reader, f)

self._network = keras.models.load_model(temp_path)

self._network = keras.models.load_model(weight_path)
self._output_axes = [
tuple(a.id for a in get_axes_infos(out))
for out in model_description.outputs
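
The relevant detail in the change above is that h5py.File accepts a file-like object, so .h5/.hdf5 weights can be opened directly from the reader without a temporary file. A minimal sketch (hdf5_bytes is an assumed stand-in for the weight file content):

# Sketch: open HDF5 content from memory via a file-like object.
import io

import h5py

buffer = io.BytesIO(hdf5_bytes)  # hdf5_bytes: raw .h5/.hdf5 content (assumption)
with h5py.File(buffer, mode="r") as h5_file:
    print(list(h5_file.keys()))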
7 changes: 3 additions & 4 deletions bioimageio/core/backends/onnx_backend.py
@@ -5,11 +5,10 @@
import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
from numpy.typing import NDArray

from bioimageio.spec._internal.type_guards import is_list, is_tuple
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.utils import download

from ..model_adapters import ModelAdapter
from ..utils._type_guards import is_list, is_tuple


class ONNXModelAdapter(ModelAdapter):
@@ -24,8 +23,8 @@ def __init__(
if model_description.weights.onnx is None:
raise ValueError(f"No ONNX weights specified for {model_description.name}")

local_path = download(model_description.weights.onnx.source).path
self._session = rt.InferenceSession(local_path.read_bytes())
reader = model_description.weights.onnx.get_reader()
self._session = rt.InferenceSession(reader.read())
onnx_inputs = self._session.get_inputs()
self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]
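
The point of this change is that onnxruntime can construct an InferenceSession from the serialized model bytes rather than a file path. A minimal sketch (onnx_bytes and the input shape are assumptions):

# Sketch: build an ONNX Runtime session from in-memory model bytes.
import numpy as np
import onnxruntime as rt

session = rt.InferenceSession(onnx_bytes)  # onnx_bytes: raw .onnx content (assumption)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: np.zeros((1, 3, 64, 64), dtype=np.float32)})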

50 changes: 34 additions & 16 deletions bioimageio/core/backends/pytorch_backend.py
@@ -1,7 +1,7 @@
import gc
import warnings
from contextlib import nullcontext
from io import TextIOWrapper
from io import BytesIO, TextIOWrapper
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union

@@ -11,12 +11,13 @@
from torch import nn
from typing_extensions import assert_never

from bioimageio.spec._internal.type_guards import is_list, is_ndarray, is_tuple
from bioimageio.spec.common import ZipPath
from bioimageio.spec._internal.version_type import Version
from bioimageio.spec.common import BytesReader, ZipPath
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
from bioimageio.spec.utils import download

from ..digest_spec import import_callable
from ..utils._type_guards import is_list, is_ndarray, is_tuple
from ._model_adapter import ModelAdapter


@@ -73,7 +74,9 @@ def _forward_impl(
if r is None:
result.append(None)
elif isinstance(r, torch.Tensor):
r_np: NDArray[Any] = r.detach().cpu().numpy()
r_np: NDArray[Any] = ( # pyright: ignore[reportUnknownVariableType]
r.detach().cpu().numpy()
)
result.append(r_np)
elif is_ndarray(r):
result.append(r)
@@ -129,34 +132,49 @@ def load_torch_model(
if load_state:
torch_model = load_torch_state_dict(
torch_model,
path=download(weight_spec).path,
path=download(weight_spec),
devices=use_devices,
)
return torch_model


def load_torch_state_dict(
model: nn.Module,
path: Union[Path, ZipPath],
path: Union[Path, ZipPath, BytesReader],
devices: Sequence[torch.device],
) -> nn.Module:
model = model.to(devices[0])
with path.open("rb") as f:
if isinstance(path, (Path, ZipPath)):
ctxt = path.open("rb")
else:
ctxt = nullcontext(BytesIO(path.read()))

with ctxt as f:
assert not isinstance(f, TextIOWrapper)
state = torch.load(f, map_location=devices[0], weights_only=True)
if Version(str(torch.__version__)) < Version("1.13"):
state = torch.load(f, map_location=devices[0])
else:
state = torch.load(f, map_location=devices[0], weights_only=True)

incompatible = model.load_state_dict(state)
if (
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
and incompatible.missing_keys
isinstance(incompatible, tuple)
and hasattr(incompatible, "missing_keys")
and hasattr(incompatible, "unexpected_keys")
):
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)
if incompatible.missing_keys:
logger.warning("Missing state dict keys: {}", incompatible.missing_keys)

if (
incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
and incompatible.unexpected_keys
):
logger.warning("Unexpected state dict keys: {}", incompatible.unexpected_keys)
if hasattr(incompatible, "unexpected_keys") and incompatible.unexpected_keys:
logger.warning(
"Unexpected state dict keys: {}", incompatible.unexpected_keys
)
else:
logger.warning(
"`model.load_state_dict()` unexpectedly returned: {} "
+ "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)",
(s[:20] + "..." if len(s := str(incompatible)) > 20 else s),
)

return model
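
For orientation, the two defensive changes above, the version-gated weights_only argument and tolerating an unexpected load_state_dict return value, reduce to roughly this sketch (packaging.version.Version stands in for the spec's Version type):

# Sketch: version-gated torch.load plus defensive handling of load_state_dict's return.
import torch
from packaging.version import Version  # stand-in for bioimageio.spec's Version

def load_state(model: torch.nn.Module, f, device: torch.device) -> torch.nn.Module:
    if Version(str(torch.__version__)) < Version("1.13"):
        state = torch.load(f, map_location=device)  # weights_only not available yet
    else:
        state = torch.load(f, map_location=device, weights_only=True)
    incompatible = model.load_state_dict(state)
    if hasattr(incompatible, "missing_keys") and incompatible.missing_keys:
        print("missing state dict keys:", incompatible.missing_keys)
    return model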
