1
- from collections .abc import MutableMapping
2
- from re import S
3
-
4
1
import itkwasm
5
2
import numpy as np
6
3
import zarr
7
- from multiscale_spatial_image import MultiscaleSpatialImage , to_multiscale , itk_image_to_multiscale , Methods
8
- from spatial_image import to_spatial_image , is_spatial_image
4
+ from ngff_zarr import to_multiscales , to_ngff_zarr , to_ngff_image , itk_image_to_ngff_image , Methods
9
5
10
6
import dask
11
- import xarray as xr
12
- from .itk import HAVE_ITK , itk_image_to_wasm_image , itk_group_spatial_object_to_wasm_point_set
7
+ from .itk import HAVE_ITK , itk_group_spatial_object_to_wasm_point_set
13
8
from .pytorch import HAVE_TORCH
14
- from .vtk import HAVE_VTK , vtk_image_to_spatial_image , vtk_polydata_to_vtkjs
15
- from .xarray import xarray_data_array_to_numpy , xarray_data_set_to_numpy
9
+ from .vtk import HAVE_VTK , vtk_image_to_ngff_image , vtk_polydata_to_vtkjs
10
+ from .xarray import HAVE_XARRAY , HAVE_MULTISCALE_SPATIAL_IMAGE , xarray_data_array_to_numpy , xarray_data_set_to_numpy
16
11
from ..render_types import RenderType
17
12
18
13
def _spatial_image_scale_factors (spatial_image , min_length ):
@@ -36,99 +31,99 @@ def _spatial_image_scale_factors(spatial_image, min_length):
36
31
37
32
return scale_factors
38
33
39
- def _make_multiscale_store (multiscale ):
34
+ def _make_multiscale_store ():
40
35
# Todo: for very large images serialize to disk cache
36
+ # -> create DirectoryStore in cache directory and return as the chunk_store
41
37
store = zarr .storage .MemoryStore (dimension_separator = '/' )
42
- multiscale .to_zarr (store , compute = True )
43
- return store
38
+ return store , None
44
39
45
40
def _get_viewer_image (image , label = False ):
41
+ # NGFF Zarr
42
+ if isinstance (image , zarr .Group ) and 'multiscales' in image .attrs :
43
+ return image .store
44
+
46
45
min_length = 64
47
46
if label :
48
47
method = Methods .DASK_IMAGE_NEAREST
49
48
else :
50
49
method = Methods .DASK_IMAGE_GAUSSIAN
51
- if isinstance (image , MultiscaleSpatialImage ):
52
- return _make_multiscale_store (image )
50
+
51
+ store , chunk_store = _make_multiscale_store ()
52
+
53
+ if HAVE_MULTISCALE_SPATIAL_IMAGE :
54
+ if isinstance (image , MultiscaleSpatialImage ):
55
+ image .to_zarr (store , compute = True )
56
+ return store
57
+
58
+ if isinstance (image , itkwasm .Image ):
59
+ ngff_image = itk_image_to_ngff_image (image )
60
+ multiscales = to_multiscales (ngff_image , method = method )
61
+ to_ngff_zarr (store , multiscales , chunk_store = chunk_store )
53
62
54
- # Todo: support for itkwasm.Image
55
63
if HAVE_ITK :
56
64
import itk
57
65
if isinstance (image , itk .Image ):
58
- dimension = image .GetImageDimension ()
59
- size = np .array (itk .size (image ))
60
- scale_factors = []
61
- dims = ('x' , 'y' , 'z' )
62
- previous = {'x' : 1 , 'y' : 1 , 'z' : 1 }
63
- while (size > min_length ).any ():
64
- to_skip = size <= size .max () / 2
65
- scale_factor = {}
66
- for dim in range (dimension ):
67
- if to_skip [dim ]:
68
- scale_factor [dims [dim ]] = previous [dims [dim ]]
69
- continue
70
- scale_factor [dims [dim ]] = 2 * previous [dims [dim ]]
71
- size [dim ] = int (size [dim ] / 2 )
72
- previous = scale_factor
73
- scale_factors .append (scale_factor )
74
-
75
- multiscale = itk_image_to_multiscale (image , scale_factors = scale_factors , method = method )
76
- return _make_multiscale_store (multiscale )
66
+ ngff_image = itk_image_to_ngff_image (image )
67
+ multiscales = to_multiscales (ngff_image , method = method )
68
+ to_ngff_zarr (store , multiscales , chunk_store = chunk_store )
69
+ return store
77
70
78
71
if HAVE_VTK :
79
72
import vtk
80
73
if isinstance (image , vtk .vtkImageData ):
81
- spatial_image = vtk_image_to_spatial_image (image )
82
- scale_factors = _spatial_image_scale_factors ( spatial_image , min_length )
83
- multiscale = to_multiscale ( spatial_image , scale_factors , method = method )
84
- return _make_multiscale_store ( multiscale )
74
+ ngff_image = vtk_image_to_ngff_image (image )
75
+ multiscales = to_multiscales ( ngff_image , method = method )
76
+ to_ngff_zarr ( store , multiscales , chunk_store = chunk_store )
77
+ return store
85
78
86
79
if isinstance (image , dask .array .core .Array ):
87
- spatial_image = to_spatial_image (image )
88
- scale_factors = _spatial_image_scale_factors ( spatial_image , min_length )
89
- multiscale = to_multiscale ( spatial_image , scale_factors , method = method )
90
- return _make_multiscale_store ( multiscale )
80
+ ngff_image = to_ngff_image (image )
81
+ multiscales = to_multiscales ( ngff_image , method = method )
82
+ to_ngff_zarr ( store , multiscales , chunk_store = chunk_store )
83
+ return store
91
84
92
85
if isinstance (image , zarr .Array ):
93
- spatial_image = to_spatial_image (image )
94
- scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
95
- multiscale = to_multiscale (spatial_image , scale_factors , method = method )
96
- return _make_multiscale_store (multiscale )
97
-
98
- # NGFF Zarr
99
- if isinstance (image , zarr .Group ) and 'multiscales' in image .attrs :
100
- return image .store
86
+ ngff_image = to_ngff_image (image )
87
+ multiscales = to_multiscales (ngff_image , method = method )
88
+ to_ngff_zarr (store , multiscales , chunk_store = chunk_store )
89
+ return store
101
90
102
91
if HAVE_TORCH :
103
92
import torch
104
93
if isinstance (image , torch .Tensor ):
105
- spatial_image = to_spatial_image (image .numpy ())
106
- scale_factors = _spatial_image_scale_factors ( spatial_image , min_length )
107
- multiscale = to_multiscale ( spatial_image , scale_factors , method = method )
108
- return _make_multiscale_store ( multiscale )
94
+ ngff_image = to_ngff_image (image .numpy ())
95
+ multiscales = to_multiscales ( ngff_image , method = method )
96
+ to_ngff_zarr ( store , multiscales , chunk_store = chunk_store )
97
+ return store
109
98
110
99
# Todo: preserve dask Array, if present, check if dims are NGFF -> use dims, coords
111
100
# Check if coords are uniform, if not, resample
112
- if isinstance (image , xr .DataArray ):
113
- if is_spatial_image (image ):
114
- scale_factors = _spatial_image_scale_factors (image , min_length )
115
- multiscale = to_multiscale (image , scale_factors , method = method )
116
- return _make_multiscale_store (multiscale )
117
-
118
- return xarray_data_array_to_numpy (image )
119
- if isinstance (image , xr .Dataset ):
120
- da = image [next (iter (image .variables .keys ()))]
121
- if is_spatial_image (da ):
122
- scale_factors = _spatial_image_scale_factors (da , min_length )
123
- multiscale = to_multiscale (da , scale_factors , method = method )
124
- return _make_multiscale_store (multiscale )
125
- return xarray_data_set_to_numpy (image )
101
+ if HAVE_XARRAY :
102
+ import xarray as xr
103
+ if isinstance (image , xr .DataArray ):
104
+ # if HAVE_MULTISCALE_SPATIAL_IMAGE:
105
+ # from spatial_image import is_spatial_image
106
+ # if is_spatial_image(image):
107
+ # from multiscale_spatial_image import to_multiscale
108
+ # scale_factors = _spatial_image_scale_factors(image, min_length)
109
+ # multiscale = to_multiscale(image, scale_factors, method=method)
110
+ # return _make_multiscale_store(multiscale)
111
+
112
+ return xarray_data_array_to_numpy (image )
113
+ if isinstance (image , xr .Dataset ):
114
+ # da = image[next(iter(image.variables.keys()))]
115
+ # if is_spatial_image(da):
116
+ # scale_factors = _spatial_image_scale_factors(da, min_length)
117
+ # multiscale = to_multiscale(da, scale_factors, method=method)
118
+ # return _make_multiscale_store(multiscale)
119
+ return xarray_data_set_to_numpy (image )
126
120
127
121
if isinstance (image , np .ndarray ):
128
- spatial_image = to_spatial_image (image )
129
- scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
130
- multiscale = to_multiscale (spatial_image , scale_factors , method = method )
131
- return _make_multiscale_store (multiscale )
122
+ ngff_image = to_ngff_image (image )
123
+ multiscales = to_multiscales (ngff_image , method = method )
124
+ to_ngff_zarr (store , multiscales , chunk_store = chunk_store )
125
+ return store
126
+
132
127
raise RuntimeError ("Could not process the viewer image" )
133
128
134
129
@@ -143,10 +138,12 @@ def _get_viewer_point_sets(point_sets):
143
138
import torch
144
139
if isinstance (point_sets , torch .Tensor ):
145
140
return point_sets .numpy ()
146
- if isinstance (point_sets , xr .DataArray ):
147
- return xarray_data_array_to_numpy (point_sets )
148
- if isinstance (point_sets , xr .Dataset ):
149
- return xarray_data_set_to_numpy (point_sets )
141
+ if HAVE_XARRAY :
142
+ import xarray as xr
143
+ if isinstance (point_sets , xr .DataArray ):
144
+ return xarray_data_array_to_numpy (point_sets )
145
+ if isinstance (point_sets , xr .Dataset ):
146
+ return xarray_data_set_to_numpy (point_sets )
150
147
return point_sets
151
148
152
149
@@ -155,8 +152,10 @@ def _detect_render_type(data, input_type) -> RenderType:
155
152
return RenderType .IMAGE
156
153
elif isinstance (data , itkwasm .PointSet ):
157
154
return RenderType .POINT_SET
158
- elif isinstance (data , MultiscaleSpatialImage ):
159
- return RenderType .IMAGE
155
+ elif HAVE_MULTISCALE_SPATIAL_IMAGE :
156
+ from multiscale_spatial_image import MultiscaleSpatialImage
157
+ if isinstance (data , MultiscaleSpatialImage ):
158
+ return RenderType .IMAGE
160
159
elif isinstance (data , (zarr .Array , zarr .Group )):
161
160
# For now assume zarr.Group is an image
162
161
# In the future, once NGFF supports point sets fully
@@ -194,11 +193,13 @@ def _detect_render_type(data, input_type) -> RenderType:
194
193
return RenderType .POINT_SET
195
194
else :
196
195
return RenderType .IMAGE
197
- if isinstance (data , xr .DataArray ):
198
- if input_type == 'point_sets' :
199
- return RenderType .POINT_SET
200
- else :
201
- return RenderType .IMAGE
196
+ if HAVE_XARRAY :
197
+ import xarray as xr
198
+ if isinstance (data , xr .DataArray ):
199
+ if input_type == 'point_sets' :
200
+ return RenderType .POINT_SET
201
+ else :
202
+ return RenderType .IMAGE
202
203
if isinstance (data , xr .Dataset ):
203
204
if input_type == 'point_sets' :
204
205
return RenderType .POINT_SET
0 commit comments