Skip to content

Commit 3ac62bf

Browse files
committed
use scipy to avoid custom dask/numpy handling for softmax
1 parent b6f7664 commit 3ac62bf

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

bioimageio/core/proc_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414

1515
import numpy as np
16+
import scipy # pyright: ignore[reportMissingTypeStubs]
1617
import xarray as xr
1718
from typing_extensions import Self, assert_never
1819

@@ -538,11 +539,10 @@ class Softmax(_SimpleOperator):
538539
axis: AxisId = AxisId("channel")
539540

540541
def _apply(self, x: Tensor, stat: Stat) -> Tensor:
541-
x_max = x.data.max(dim=self.axis, keepdims=False)
542-
x_shifted = x.data - x_max
543-
exp_x_shifted = xr.DataArray(x_shifted.data.exp(), dims=x.dims)
544-
result = exp_x_shifted / exp_x_shifted.sum(dim=self.axis)
545-
return Tensor.from_xarray(result)
542+
axis_idx = x.dims.index(self.axis)
543+
result = scipy.special.softmax(x.data, axis=axis_idx)
544+
result_xr = xr.DataArray(result, dims=x.dims)
545+
return Tensor.from_xarray(result_xr)
546546

547547
@property
548548
def required_measures(self) -> Collection[Measure]:

bioimageio/core/tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646

4747
# TODO: complete docstrings
48+
# TODO: in the long run---with improved typing in xarray---we should probably replace `Tensor` with xr.DataArray
4849
class Tensor(MagicTensorOpsMixin):
4950
"""A wrapper around an xr.DataArray for better integration with bioimageio.spec
5051
and improved type annotations."""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"pydantic-settings>=2.5,<3",
1515
"pydantic>=2.7.0,<3",
1616
"ruyaml",
17+
"scipy",
1718
"tqdm",
1819
"typing-extensions",
1920
"xarray>=2023.01,<2025.3.0",
@@ -60,7 +61,6 @@ dev = [
6061
"pytest-cov",
6162
"pytest",
6263
"python-dotenv",
63-
"scipy",
6464
"segment-anything", # for model testing
6565
"timm", # for model testing
6666
"torch>=1.6,<3",

tests/test_proc_ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pytest
5+
import scipy # pyright: ignore[reportMissingTypeStubs]
56
import xarray as xr
67
from typing_extensions import TypeGuard
78

@@ -396,9 +397,7 @@ def test_softmax(tid: MemberId):
396397
xr.testing.assert_allclose(exp, sample.members[tid].data, rtol=1e-5, atol=1e-7)
397398

398399

399-
def test_softmax_w_scipy(tid: MemberId):
400-
import scipy # pyright: ignore[reportMissingTypeStubs]
401-
400+
def test_softmax_with_scipy(tid: MemberId):
402401
from bioimageio.core.proc_ops import Softmax
403402

404403
shape = (3, 32, 32)

0 commit comments

Comments
 (0)