Skip to content

Commit 8a316f7

Browse files
committed
implement and test ScaleMeanVariance
1 parent 431191e commit 8a316f7

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,30 @@ def __post_init__(self):
162162

163163
@dataclass
164164
class ScaleMeanVariance(Processing):
165-
...
165+
mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE
166+
reference_tensor: TensorName = MISSING
167+
axes: Optional[Sequence[str]] = None
168+
eps: float = 1e-6
169+
170+
def get_required_measures(self) -> RequiredMeasures:
171+
axes = None if self.axes is None else tuple(self.axes)
172+
return {
173+
self.mode: {
174+
self.tensor_name: {Mean(axes=axes), Std(axes=axes)},
175+
self.reference_tensor: {Mean(axes=axes), Std(axes=axes)},
176+
}
177+
}
178+
179+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
180+
axes = None if self.axes is None else tuple(self.axes)
181+
assert self.mode in (PER_SAMPLE, PER_DATASET)
182+
mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode)
183+
std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode)
184+
ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode)
185+
ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode)
186+
187+
tensor = (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean
188+
return ensure_dtype(tensor, dtype="float32")
166189

167190

168191
@dataclass

tests/prediction_pipeline/test_postprocessing.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import numpy as np
2+
import pytest
23
import xarray as xr
34

5+
from bioimageio.core.prediction_pipeline._measure_groups import compute_measures
46

5-
def test_binarize_postprocessing():
7+
8+
def test_binarize():
69
from bioimageio.core.prediction_pipeline._processing import Binarize
710

811
shape = (3, 32, 32)
@@ -16,3 +19,52 @@ def test_binarize_postprocessing():
1619
binarize = Binarize("data_name", threshold=threshold)
1720
res = binarize(data)
1821
xr.testing.assert_allclose(res, exp)
22+
23+
24+
@pytest.mark.parametrize("axes", [None, tuple("cy"), tuple("cyx"), tuple("x")])
25+
def test_scale_mean_variance(axes):
26+
from bioimageio.core.prediction_pipeline._processing import ScaleMeanVariance
27+
28+
shape = (3, 32, 46)
29+
ipt_axes = ("c", "y", "x")
30+
np_data = np.random.rand(*shape)
31+
ipt_data = xr.DataArray(np_data, dims=ipt_axes)
32+
ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes)
33+
34+
scale_mean_variance = ScaleMeanVariance("data_name", reference_tensor="ref_name", axes=axes)
35+
required = scale_mean_variance.get_required_measures()
36+
computed = compute_measures(required, sample={"data_name": ipt_data, "ref_name": ref_data})
37+
scale_mean_variance.set_computed_measures(computed)
38+
39+
res = scale_mean_variance(ipt_data)
40+
xr.testing.assert_allclose(res, ref_data)
41+
42+
43+
@pytest.mark.parametrize("axes", [None, tuple("cy"), tuple("y"), tuple("yx")])
44+
def test_scale_mean_variance_per_channel(axes):
45+
from bioimageio.core.prediction_pipeline._processing import ScaleMeanVariance
46+
47+
shape = (3, 32, 46)
48+
ipt_axes = ("c", "y", "x")
49+
np_data = np.random.rand(*shape)
50+
ipt_data = xr.DataArray(np_data, dims=ipt_axes)
51+
52+
# set different mean, std per channel
53+
np_ref_data = np.stack([d * i + i for i, d in enumerate(np_data, start=2)])
54+
print(np_ref_data.shape)
55+
ref_data = xr.DataArray(np_ref_data, dims=ipt_axes)
56+
57+
scale_mean_variance = ScaleMeanVariance("data_name", reference_tensor="ref_name", axes=axes)
58+
required = scale_mean_variance.get_required_measures()
59+
computed = compute_measures(required, sample={"data_name": ipt_data, "ref_name": ref_data})
60+
scale_mean_variance.set_computed_measures(computed)
61+
62+
res = scale_mean_variance(ipt_data)
63+
64+
if axes is not None and "c" not in axes:
65+
# mean,std per channel should match exactly
66+
xr.testing.assert_allclose(res, ref_data)
67+
else:
68+
# mean,std across channels should not match
69+
with pytest.raises(AssertionError):
70+
xr.testing.assert_allclose(res, ref_data)

0 commit comments

Comments
 (0)