|
| 1 | +from pathlib import Path |
| 2 | +from typing import Literal, Mapping, NamedTuple, assert_never |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | +import xarray as xr |
| 7 | + |
| 8 | +from bioimageio.core import ( |
| 9 | + AxisId, |
| 10 | + MemberId, |
| 11 | + PredictionPipeline, |
| 12 | + Sample, |
| 13 | + create_prediction_pipeline, |
| 14 | + load_model, |
| 15 | + predict, |
| 16 | +) |
| 17 | +from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs |
| 18 | +from bioimageio.spec import AnyModelDescr |
| 19 | + |
| 20 | + |
| 21 | +class Prep(NamedTuple): |
| 22 | + model: AnyModelDescr |
| 23 | + prediction_pipeline: PredictionPipeline |
| 24 | + input_sample: Sample |
| 25 | + output_sample: Sample |
| 26 | + |
| 27 | + |
| 28 | +@pytest.fixture(scope="module") |
| 29 | +def prep(any_model: str): |
| 30 | + model = load_model(any_model, perform_io_checks=False) |
| 31 | + input_sample = get_test_inputs(model) |
| 32 | + output_sample = get_test_outputs(model) |
| 33 | + return Prep(model, create_prediction_pipeline(model), input_sample, output_sample) |
| 34 | + |
| 35 | + |
| 36 | +def test_predict_with_pipeline(prep: Prep): |
| 37 | + out = predict( |
| 38 | + model=prep.prediction_pipeline, |
| 39 | + inputs=prep.input_sample, |
| 40 | + ) |
| 41 | + assert out == prep.output_sample |
| 42 | + |
| 43 | + |
| 44 | +@pytest.mark.parameterize("tensor_input", ["numpy", "xarray"]) |
| 45 | +def test_predict_with_model_description( |
| 46 | + tensor_input: Literal["numpy", "xarray"], prep: Prep |
| 47 | +): |
| 48 | + if tensor_input == "xarray": |
| 49 | + ipt = {m: t.data for m, t in prep.input_sample.members.items()} |
| 50 | + assert all(isinstance(v, xr.DataArray) for v in ipt.values()) |
| 51 | + elif tensor_input == "numpy": |
| 52 | + ipt = {m: t.data.data for m, t in prep.input_sample.members.items()} |
| 53 | + assert all(isinstance(v, np.ndarray) for v in ipt.values()) |
| 54 | + else: |
| 55 | + assert_never(tensor_input) |
| 56 | + |
| 57 | + out = predict( |
| 58 | + model=prep.model, |
| 59 | + inputs=ipt, |
| 60 | + sample_id=prep.input_sample.id, |
| 61 | + skip_preprocessing=False, |
| 62 | + skip_postprocessing=False, |
| 63 | + ) |
| 64 | + assert out == prep.output_sample |
| 65 | + |
| 66 | + |
| 67 | +@pytest.mark.parameterize("with_proces", [True, False]) |
| 68 | +def test_predict_with_blocking(with_procs: bool, prep: Prep): |
| 69 | + input_block_shape: Mapping[MemberId, Mapping[AxisId, int]] = { |
| 70 | + list(prep.input_sample.members)[0]: { |
| 71 | + "x": 32, # pyright: ignore[reportAssignmentType] |
| 72 | + AxisId("y"): 32, |
| 73 | + } |
| 74 | + } |
| 75 | + out = predict( |
| 76 | + model=prep.prediction_pipeline, |
| 77 | + inputs=prep.input_sample, |
| 78 | + input_block_shape=input_block_shape, |
| 79 | + sample_id=prep.input_sample.id, |
| 80 | + skip_preprocessing=with_procs, |
| 81 | + skip_postprocessing=with_procs, |
| 82 | + ) |
| 83 | + assert out == prep.output_sample |
| 84 | + |
| 85 | + |
| 86 | +def test_predict_save_output(prep: Prep, tmp_path: Path): |
| 87 | + save_path = tmp_path / "{member_id}_{sample_id}.h5" |
| 88 | + out = predict( |
| 89 | + model=prep.prediction_pipeline, |
| 90 | + inputs=prep.input_sample, |
| 91 | + save_output_path=save_path, |
| 92 | + ) |
| 93 | + assert out == prep.output_sample |
| 94 | + assert save_path.parent.exists() |
| 95 | + |
| 96 | + |
1 | 97 | # TODO: update |
2 | 98 | # from pathlib import Path |
3 | 99 |
|
|
0 commit comments