Skip to content

Commit b6f7664

Browse files
committed
make softmax compatible with xarray <= 2024.07
1 parent be8a69e commit b6f7664

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

bioimageio/core/proc_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,8 @@ class Softmax(_SimpleOperator):
539539

540540
def _apply(self, x: Tensor, stat: Stat) -> Tensor:
541541
x_max = x.data.max(dim=self.axis, keepdims=False)
542-
exp_x_shifted: xr.DataArray = xr.ufuncs.exp(x.data - x_max)
542+
x_shifted = x.data - x_max
543+
exp_x_shifted = xr.DataArray(x_shifted.data.exp(), dims=x.dims)
543544
result = exp_x_shifted / exp_x_shifted.sum(dim=self.axis)
544545
return Tensor.from_xarray(result)
545546

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ dev = [
6060
"pytest-cov",
6161
"pytest",
6262
"python-dotenv",
63+
"scipy",
6364
"segment-anything", # for model testing
6465
"timm", # for model testing
6566
"torch>=1.6,<3",

tests/test_proc_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,23 @@ def test_softmax(tid: MemberId):
394394
dims=axes,
395395
)
396396
xr.testing.assert_allclose(exp, sample.members[tid].data, rtol=1e-5, atol=1e-7)
397+
398+
399+
def test_softmax_w_scipy(tid: MemberId):
400+
import scipy # pyright: ignore[reportMissingTypeStubs]
401+
402+
from bioimageio.core.proc_ops import Softmax
403+
404+
shape = (3, 32, 32)
405+
axes = ("channel", "y", "x")
406+
np_data = np.random.rand(*shape)
407+
data = xr.DataArray(np_data, dims=axes)
408+
sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None)
409+
softmax = Softmax(tid, tid, axis=AxisId("channel"))
410+
softmax(sample)
411+
412+
exp = xr.DataArray(
413+
scipy.special.softmax(np_data, axis=0),
414+
dims=axes,
415+
)
416+
xr.testing.assert_allclose(exp, sample.members[tid].data, rtol=1e-5, atol=1e-7)

0 commit comments

Comments
 (0)