Skip to content

Commit 9aede3d

Browse files
authored
Merge pull request #527 from thewtex/server-side-chunking
ENH: Server-side multi-scale chunking of images
2 parents 89e001d + a669fb5 commit 9aede3d

File tree

6 files changed

+179
-76
lines changed

6 files changed

+179
-76
lines changed

itkwidgets/_type_aliases.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import itkwasm
22
import zarr
33
import numpy as np
4+
import dask
45

5-
from .integrations.dask import HAVE_DASK
66
from .integrations.itk import HAVE_ITK
77
from .integrations.pytorch import HAVE_TORCH
88
from .integrations.vtk import HAVE_VTK
@@ -23,10 +23,8 @@
2323
import vtk
2424
Image = Union[Image, vtk.vtkImageData]
2525
Point_Sets = Union[Point_Sets, vtk.vtkPolyData]
26-
if HAVE_DASK:
27-
import dask
28-
Image = Union[Image, dask.array.core.Array]
29-
Point_Sets = Union[Point_Sets, dask.array.core.Array]
26+
Image = Union[Image, dask.array.core.Array]
27+
Point_Sets = Union[Point_Sets, dask.array.core.Array]
3028
if HAVE_TORCH:
3129
import torch
3230
Image = Union[Image, torch.Tensor]

itkwidgets/imjoy.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from dataclasses import dataclass, asdict
1+
from dataclasses import asdict
22

3-
from typing import Dict
3+
from typing import Dict
44

55
import itkwasm
66
import numcodecs
77
from imjoy_rpc import api
8+
import zarr
89

910
_numcodec_encoder = numcodecs.Blosc(cname='lz4', clevel=3)
1011
_numcodec_config = _numcodec_encoder.get_config()
@@ -24,6 +25,26 @@ def encode_itkwasm_image(image):
2425

2526
return image_dict
2627

28+
def encode_zarr_store(store):
29+
def getItem(key):
30+
return store[key]
31+
32+
def setItem(key, value):
33+
store[key] = value
34+
35+
def containsItem(key):
36+
return key in store
37+
38+
return {
39+
"_rintf": True,
40+
"_rtype": 'zarr-store',
41+
"getItem": getItem,
42+
"setItem": setItem,
43+
"containsItem": containsItem,
44+
}
45+
2746
def register_itkwasm_imjoy_codecs():
2847

2948
api.registerCodec({'name': 'itkwasm-image', 'type': itkwasm.Image, 'encoder': encode_itkwasm_image})
49+
api.registerCodec({'name': 'zarr-store', 'type': zarr.storage.BaseStore, 'encoder': encode_zarr_store})
50+
Lines changed: 137 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,148 @@
1+
from collections.abc import MutableMapping
2+
from re import S
3+
14
import itkwasm
25
import numpy as np
36
import zarr
7+
from multiscale_spatial_image import MultiscaleSpatialImage, to_multiscale, itk_image_to_multiscale
8+
from spatial_image import to_spatial_image, is_spatial_image
49

5-
from .dask import HAVE_DASK, dask_array_to_ndarray
10+
import dask
11+
import xarray as xr
612
from .itk import HAVE_ITK, itk_image_to_wasm_image, itk_group_spatial_object_to_wasm_point_set
713
from .pytorch import HAVE_TORCH
8-
from .vtk import HAVE_VTK, vtk_image_to_ndarray, vtk_polydata_to_vtkjs
9-
from .xarray import HAVE_XARRAY, xarray_data_array_to_numpy, xarray_data_set_to_numpy
14+
from .vtk import HAVE_VTK, vtk_image_to_spatial_image, vtk_polydata_to_vtkjs
15+
from .xarray import xarray_data_array_to_numpy, xarray_data_set_to_numpy
1016
from ..render_types import RenderType
1117

18+
def _spatial_image_scale_factors(spatial_image, min_length):
19+
sizes = dict(spatial_image.sizes)
20+
scale_factors = []
21+
dims = spatial_image.dims
22+
previous = { d: 1 for d in { 'x', 'y', 'z' }.intersection(dims) }
23+
while (np.array(list(sizes.values())) > min_length).any():
24+
max_size = np.array(list(sizes.values())).max()
25+
to_skip = { d: sizes[d] <= max_size / 2 for d in previous.keys() }
26+
scale_factor = {}
27+
for dim in previous.keys():
28+
if to_skip[dim]:
29+
scale_factor[dim] = previous[dim]
30+
continue
31+
scale_factor[dim] = 2 * previous[dim]
32+
33+
sizes[dim] = int(sizes[dim] / 2)
34+
previous = scale_factor
35+
scale_factors.append(scale_factor)
36+
37+
return scale_factors
38+
39+
def _make_multiscale_store(multiscale):
40+
# Todo: for very large images serialize to disk cache
41+
store = zarr.storage.MemoryStore(dimension_separator='/')
42+
multiscale.to_zarr(store, compute=True)
43+
return store
1244

1345
def _get_viewer_image(image):
46+
min_length = 64
47+
if isinstance(image, MultiscaleSpatialImage):
48+
return _make_multiscale_store(image)
49+
50+
# Todo: support for itkwasm.Image
1451
if HAVE_ITK:
1552
import itk
1653
if isinstance(image, itk.Image):
17-
return itk_image_to_wasm_image(image)
54+
dimension = image.GetImageDimension()
55+
size = np.array(itk.size(image))
56+
scale_factors = []
57+
dims = ('x', 'y', 'z')
58+
previous = {'x': 1, 'y': 1, 'z': 1}
59+
while (size > min_length).any():
60+
to_skip = size <= size.max() / 2
61+
scale_factor = {}
62+
for dim in range(dimension):
63+
if to_skip[dim]:
64+
scale_factor[dims[dim]] = previous[dims[dim]]
65+
continue
66+
scale_factor[dims[dim]] = 2 * previous[dims[dim]]
67+
size[dim] = int(size[dim] / 2)
68+
previous = scale_factor
69+
scale_factors.append(scale_factor)
70+
71+
multiscale = itk_image_to_multiscale(image, scale_factors=scale_factors)
72+
return _make_multiscale_store(multiscale)
73+
1874
if HAVE_VTK:
1975
import vtk
2076
if isinstance(image, vtk.vtkImageData):
21-
return vtk_image_to_ndarray(image)
22-
if HAVE_DASK:
23-
import dask
24-
if isinstance(image, dask.array.core.Array):
25-
return dask_array_to_ndarray(image)
77+
spatial_image = vtk_image_to_spatial_image(image)
78+
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
79+
multiscale = to_multiscale(spatial_image, scale_factors)
80+
return _make_multiscale_store(multiscale)
81+
82+
if isinstance(image, dask.array.core.Array):
83+
spatial_image = to_spatial_image(image)
84+
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
85+
multiscale = to_multiscale(spatial_image, scale_factors)
86+
return _make_multiscale_store(multiscale)
87+
88+
if isinstance(image, zarr.Array):
89+
spatial_image = to_spatial_image(image)
90+
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
91+
multiscale = to_multiscale(spatial_image, scale_factors)
92+
return _make_multiscale_store(multiscale)
93+
94+
# NGFF Zarr
95+
if isinstance(image, zarr.Group) and 'multiscales' in image.attrs:
96+
return _make_multiscale_store(image.store)
97+
2698
if HAVE_TORCH:
2799
import torch
28100
if isinstance(image, torch.Tensor):
29-
return image.numpy()
30-
if HAVE_XARRAY:
31-
import xarray
32-
if isinstance(image, xarray.DataArray):
33-
return xarray_data_array_to_numpy(image)
34-
if isinstance(image, xarray.Dataset):
35-
return xarray_data_set_to_numpy(image)
36-
return image
101+
spatial_image = to_spatial_image(image.numpy())
102+
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
103+
multiscale = to_multiscale(spatial_image, scale_factors)
104+
return _make_multiscale_store(multiscale)
105+
106+
# Todo: preserve dask Array, if present, check if dims are NGFF -> use dims, coords
107+
# Check if coords are uniform, if not, resample
108+
if isinstance(image, xr.DataArray):
109+
if is_spatial_image(image):
110+
scale_factors = _spatial_image_scale_factors(image, min_length)
111+
multiscale = to_multiscale(image, scale_factors)
112+
return _make_multiscale_store(multiscale)
113+
114+
return xarray_data_array_to_numpy(image)
115+
if isinstance(image, xr.Dataset):
116+
da = image[next(iter(image.variables.keys()))]
117+
if is_spatial_image(da):
118+
scale_factors = _spatial_image_scale_factors(da, min_length)
119+
multiscale = to_multiscale(da, scale_factors)
120+
return _make_multiscale_store(multiscale)
121+
return xarray_data_set_to_numpy(image)
122+
123+
if isinstance(image, np.ndarray):
124+
spatial_image = to_spatial_image(image)
125+
scale_factors = _spatial_image_scale_factors(spatial_image, min_length)
126+
multiscale = to_multiscale(spatial_image, scale_factors)
127+
return _make_multiscale_store(multiscale)
128+
raise RuntimeError("Could not process the viewer image")
37129

38130

39131
def _get_viewer_point_sets(point_sets):
40132
if HAVE_VTK:
41133
import vtk
42134
if isinstance(point_sets, vtk.vtkPolyData):
43135
return vtk_polydata_to_vtkjs(point_sets)
44-
if HAVE_DASK:
45-
import dask
46-
if isinstance(point_sets, dask.array.core.Array):
47-
return dask_array_to_ndarray(point_sets)
136+
if isinstance(point_sets, dask.array.core.Array):
137+
return np.asarray(point_sets)
48138
if HAVE_TORCH:
49139
import torch
50140
if isinstance(point_sets, torch.Tensor):
51141
return point_sets.numpy()
52-
if HAVE_XARRAY:
53-
import xarray
54-
if isinstance(point_sets, xarray.DataArray):
55-
return xarray_data_array_to_numpy(point_sets)
56-
if isinstance(point_sets, xarray.Dataset):
57-
return xarray_data_set_to_numpy(point_sets)
142+
if isinstance(point_sets, xr.DataArray):
143+
return xarray_data_array_to_numpy(point_sets)
144+
if isinstance(point_sets, xr.Dataset):
145+
return xarray_data_set_to_numpy(point_sets)
58146
return point_sets
59147

60148

@@ -63,6 +151,13 @@ def _detect_render_type(data, input_type) -> RenderType:
63151
return RenderType.IMAGE
64152
elif isinstance(data, itkwasm.PointSet):
65153
return RenderType.POINT_SET
154+
elif isinstance(data, MultiscaleSpatialImage):
155+
return RenderType.IMAGE
156+
elif isinstance(data, (zarr.Array, zarr.Group)):
157+
# For now assume zarr.Group is an image
158+
# In the future, once NGFF supports point sets fully
159+
# We may need to do more introspection
160+
return RenderType.IMAGE
66161
elif isinstance(data, np.ndarray):
67162
if input_type == 'point_sets':
68163
return RenderType.POINT_SET
@@ -83,29 +178,25 @@ def _detect_render_type(data, input_type) -> RenderType:
83178
return RenderType.IMAGE
84179
elif isinstance(data, vtk.vtkPolyData):
85180
return RenderType.POINT_SET
86-
if HAVE_DASK:
87-
import dask
88-
if isinstance(data, dask.array.core.Array):
89-
if input_type == 'point_sets':
90-
return RenderType.POINT_SET
91-
else:
92-
return RenderType.IMAGE
181+
if isinstance(data, dask.array.core.Array):
182+
if input_type == 'point_sets':
183+
return RenderType.POINT_SET
184+
else:
185+
return RenderType.IMAGE
93186
if HAVE_TORCH:
94187
import torch
95188
if isinstance(data, torch.Tensor):
96189
if input_type == 'point_sets':
97190
return RenderType.POINT_SET
98191
else:
99192
return RenderType.IMAGE
100-
if HAVE_XARRAY:
101-
import xarray
102-
if isinstance(data, xarray.DataArray):
103-
if input_type == 'point_sets':
104-
return RenderType.POINT_SET
105-
else:
106-
return RenderType.IMAGE
107-
if isinstance(data, xarray.Dataset):
108-
if input_type == 'point_sets':
109-
return RenderType.POINT_SET
110-
else:
111-
return RenderType.IMAGE
193+
if isinstance(data, xr.DataArray):
194+
if input_type == 'point_sets':
195+
return RenderType.POINT_SET
196+
else:
197+
return RenderType.IMAGE
198+
if isinstance(data, xr.Dataset):
199+
if input_type == 'point_sets':
200+
return RenderType.POINT_SET
201+
else:
202+
return RenderType.IMAGE

itkwidgets/integrations/dask.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

itkwidgets/integrations/vtk.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,23 @@
66
except ImportError:
77
pass
88

9+
from spatial_image import to_spatial_image
910

10-
def vtk_image_to_ndarray(image):
11+
12+
def vtk_image_to_spatial_image(image):
1113
array = vtk_to_numpy(image.GetPointData().GetScalars())
12-
dims = list(image.GetDimensions())
13-
array.shape = dims[::-1]
14-
return array
14+
dimensions = list(image.GetDimensions())
15+
array.shape = dimensions[::-1]
16+
17+
origin = image.GetOrigin()
18+
translation = { 'x': origin[0], 'y': origin[1], 'z': origin[2] }
19+
20+
spacing = image.GetSpacing()
21+
scale = { 'x': spacing[0], 'y': spacing[1], 'z': spacing[2] }
22+
23+
spatial_image = to_spatial_image(array, scale=scale, translation=translation)
24+
25+
return spatial_image
1526

1627
def vtk_polydata_to_vtkjs(point_set):
1728
array = vtk_to_numpy(point_set.GetPoints().GetData())

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"imjoy-rpc >= 0.5.13",
4040
"imjoy-utils >= 0.1.2",
4141
"numcodecs",
42+
"multiscale_spatial_image >= 0.10.1",
4243
"zarr",
4344
]
4445

0 commit comments

Comments
 (0)