Skip to content

Commit a9d1132

Browse files
committed
WIP add test_prediction
1 parent 8a8cb00 commit a9d1132

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

tests/test_prediction.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,99 @@
1+
from pathlib import Path
2+
from typing import Literal, Mapping, NamedTuple, assert_never
3+
4+
import numpy as np
5+
import pytest
6+
import xarray as xr
7+
8+
from bioimageio.core import (
9+
AxisId,
10+
MemberId,
11+
PredictionPipeline,
12+
Sample,
13+
create_prediction_pipeline,
14+
load_model,
15+
predict,
16+
)
17+
from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs
18+
from bioimageio.spec import AnyModelDescr
19+
20+
21+
class Prep(NamedTuple):
22+
model: AnyModelDescr
23+
prediction_pipeline: PredictionPipeline
24+
input_sample: Sample
25+
output_sample: Sample
26+
27+
28+
@pytest.fixture(scope="module")
29+
def prep(any_model: str):
30+
model = load_model(any_model, perform_io_checks=False)
31+
input_sample = get_test_inputs(model)
32+
output_sample = get_test_outputs(model)
33+
return Prep(model, create_prediction_pipeline(model), input_sample, output_sample)
34+
35+
36+
def test_predict_with_pipeline(prep: Prep):
37+
out = predict(
38+
model=prep.prediction_pipeline,
39+
inputs=prep.input_sample,
40+
)
41+
assert out == prep.output_sample
42+
43+
44+
@pytest.mark.parameterize("tensor_input", ["numpy", "xarray"])
45+
def test_predict_with_model_description(
46+
tensor_input: Literal["numpy", "xarray"], prep: Prep
47+
):
48+
if tensor_input == "xarray":
49+
ipt = {m: t.data for m, t in prep.input_sample.members.items()}
50+
assert all(isinstance(v, xr.DataArray) for v in ipt.values())
51+
elif tensor_input == "numpy":
52+
ipt = {m: t.data.data for m, t in prep.input_sample.members.items()}
53+
assert all(isinstance(v, np.ndarray) for v in ipt.values())
54+
else:
55+
assert_never(tensor_input)
56+
57+
out = predict(
58+
model=prep.model,
59+
inputs=ipt,
60+
sample_id=prep.input_sample.id,
61+
skip_preprocessing=False,
62+
skip_postprocessing=False,
63+
)
64+
assert out == prep.output_sample
65+
66+
67+
@pytest.mark.parameterize("with_proces", [True, False])
68+
def test_predict_with_blocking(with_procs: bool, prep: Prep):
69+
input_block_shape: Mapping[MemberId, Mapping[AxisId, int]] = {
70+
list(prep.input_sample.members)[0]: {
71+
"x": 32, # pyright: ignore[reportAssignmentType]
72+
AxisId("y"): 32,
73+
}
74+
}
75+
out = predict(
76+
model=prep.prediction_pipeline,
77+
inputs=prep.input_sample,
78+
input_block_shape=input_block_shape,
79+
sample_id=prep.input_sample.id,
80+
skip_preprocessing=with_procs,
81+
skip_postprocessing=with_procs,
82+
)
83+
assert out == prep.output_sample
84+
85+
86+
def test_predict_save_output(prep: Prep, tmp_path: Path):
87+
save_path = tmp_path / "{member_id}_{sample_id}.h5"
88+
out = predict(
89+
model=prep.prediction_pipeline,
90+
inputs=prep.input_sample,
91+
save_output_path=save_path,
92+
)
93+
assert out == prep.output_sample
94+
assert save_path.parent.exists()
95+
96+
197
# TODO: update
298
# from pathlib import Path
399

0 commit comments

Comments
 (0)