Skip to content

Commit 50a61d7

Browse files
committed
ENH: Add label image downsampling support
Also switch to DASK_IMAGE_GAUSSIAN for the default intensity downsampler (better quality).
1 parent f896068 commit 50a61d7

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

itkwidgets/integrations/__init__.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itkwasm
55
import numpy as np
66
import zarr
7-
from multiscale_spatial_image import MultiscaleSpatialImage, to_multiscale, itk_image_to_multiscale
7+
from multiscale_spatial_image import MultiscaleSpatialImage, to_multiscale, itk_image_to_multiscale, Methods
88
from spatial_image import to_spatial_image, is_spatial_image
99

1010
import dask
@@ -42,8 +42,12 @@ def _make_multiscale_store(multiscale):
4242
multiscale.to_zarr(store, compute=True)
4343
return store
4444

45-
def _get_viewer_image(image):
45+
def _get_viewer_image(image, label=False):
4646
min_length = 64
47+
if label:
48+
method = Methods.DASK_IMAGE_NEAREST
49+
else:
50+
method = Methods.DASK_IMAGE_GAUSSIAN
4751
if isinstance(image, MultiscaleSpatialImage):
4852
return _make_multiscale_store(image)
4953

@@ -68,27 +72,27 @@ def _get_viewer_image(image):
6872
previous = scale_factor
6973
scale_factors.append(scale_factor)
7074

71-
multiscale = itk_image_to_multiscale(image, scale_factors=scale_factors)
75+
multiscale = itk_image_to_multiscale(image, scale_factors=scale_factors, method=method)
7276
return _make_multiscale_store(multiscale)
7377

7478
if HAVE_VTK:
7579
import vtk
7680
if isinstance(image, vtk.vtkImageData):
7781
spatial_image = vtk_image_to_spatial_image(image)
7882
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
79-
multiscale = to_multiscale(spatial_image, scale_factors)
83+
multiscale = to_multiscale(spatial_image, scale_factors, method=method)
8084
return _make_multiscale_store(multiscale)
8185

8286
if isinstance(image, dask.array.core.Array):
8387
spatial_image = to_spatial_image(image)
8488
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
85-
multiscale = to_multiscale(spatial_image, scale_factors)
89+
multiscale = to_multiscale(spatial_image, scale_factors, method=method)
8690
return _make_multiscale_store(multiscale)
8791

8892
if isinstance(image, zarr.Array):
8993
spatial_image = to_spatial_image(image)
9094
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
91-
multiscale = to_multiscale(spatial_image, scale_factors)
95+
multiscale = to_multiscale(spatial_image, scale_factors, method=method)
9296
return _make_multiscale_store(multiscale)
9397

9498
# NGFF Zarr
@@ -100,30 +104,30 @@ def _get_viewer_image(image):
100104
if isinstance(image, torch.Tensor):
101105
spatial_image = to_spatial_image(image.numpy())
102106
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
103-
multiscale = to_multiscale(spatial_image, scale_factors)
107+
multiscale = to_multiscale(spatial_image, scale_factors, method=method)
104108
return _make_multiscale_store(multiscale)
105109

106110
# Todo: preserve dask Array, if present, check if dims are NGFF -> use dims, coords
107111
# Check if coords are uniform, if not, resample
108112
if isinstance(image, xr.DataArray):
109113
if is_spatial_image(image):
110114
scale_factors = _spatial_image_scale_factors(image, min_length)
111-
multiscale = to_multiscale(image, scale_factors)
115+
multiscale = to_multiscale(image, scale_factors, method=method)
112116
return _make_multiscale_store(multiscale)
113117

114118
return xarray_data_array_to_numpy(image)
115119
if isinstance(image, xr.Dataset):
116120
da = image[next(iter(image.variables.keys()))]
117121
if is_spatial_image(da):
118122
scale_factors = _spatial_image_scale_factors(da, min_length)
119-
multiscale = to_multiscale(da, scale_factors)
123+
multiscale = to_multiscale(da, scale_factors, method=method)
120124
return _make_multiscale_store(multiscale)
121125
return xarray_data_set_to_numpy(image)
122126

123127
if isinstance(image, np.ndarray):
124128
spatial_image = to_spatial_image(image)
125129
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
126-
multiscale = to_multiscale(spatial_image, scale_factors)
130+
multiscale = to_multiscale(spatial_image, scale_factors, method=method)
127131
return _make_multiscale_store(multiscale)
128132
raise RuntimeError("Could not process the viewer image")
129133

itkwidgets/viewer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
]
2020

2121
_viewer_count = 1
22+
_codecs_registered = False
2223

2324

2425
class ViewerRPC:
@@ -79,7 +80,10 @@ async def run(self, ctx):
7980
render_type = _detect_render_type(data, input_type)
8081
key = init_key_aliases()[input_type]
8182
if render_type is RenderType.IMAGE:
82-
result = _get_viewer_image(data)
83+
if input_type == 'label_image':
84+
result = _get_viewer_image(data, label=True)
85+
else:
86+
result = _get_viewer_image(data, label=False)
8387
elif render_type is RenderType.POINT_SET:
8488
result = _get_viewer_point_sets(data)
8589
if result is None:
@@ -208,7 +212,7 @@ def set_background_color(self, bgColor: List[float]):
208212
def set_image(self, image: Image):
209213
render_type = _detect_render_type(image, 'image')
210214
if render_type is RenderType.IMAGE:
211-
image = _get_viewer_image(image)
215+
image = _get_viewer_image(image, label=False)
212216
self.queue_request('setImage', image)
213217
elif render_type is RenderType.POINT_SET:
214218
image = _get_viewer_point_sets(image)
@@ -250,7 +254,7 @@ def set_image_volume_sample_distance(self, distance: float):
250254
def set_label_image(self, label_image: Image):
251255
render_type = _detect_render_type(label_image, 'image')
252256
if render_type is RenderType.IMAGE:
253-
label_image = _get_viewer_image(label_image)
257+
label_image = _get_viewer_image(label_image, label=True)
254258
self.queue_request('setImage', label_image)
255259
elif render_type is RenderType.POINT_SET:
256260
label_image = _get_viewer_point_sets(label_image)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ keywords = [
3333
"webgpu",
3434
]
3535

36-
requires-python = ">=3.7"
36+
requires-python = ">=3.8"
3737
dependencies = [
3838
"itkwasm >= 1.0b1",
3939
"imjoy-rpc >= 0.5.16",
4040
"imjoy-utils >= 0.1.2",
4141
"numcodecs",
42-
"multiscale_spatial_image >= 0.10.1",
42+
"multiscale_spatial_image[dask-image] >= 0.11.1",
4343
"zarr",
4444
]
4545

0 commit comments

Comments
 (0)