1+ from collections .abc import MutableMapping
2+ from re import S
3+
14import itkwasm
25import numpy as np
36import zarr
7+ from multiscale_spatial_image import MultiscaleSpatialImage , to_multiscale , itk_image_to_multiscale
8+ from spatial_image import to_spatial_image , is_spatial_image
49
5- from .dask import HAVE_DASK , dask_array_to_ndarray
10+ import dask
11+ import xarray as xr
612from .itk import HAVE_ITK , itk_image_to_wasm_image , itk_group_spatial_object_to_wasm_point_set
713from .pytorch import HAVE_TORCH
8- from .vtk import HAVE_VTK , vtk_image_to_ndarray , vtk_polydata_to_vtkjs
9- from .xarray import HAVE_XARRAY , xarray_data_array_to_numpy , xarray_data_set_to_numpy
14+ from .vtk import HAVE_VTK , vtk_image_to_spatial_image , vtk_polydata_to_vtkjs
15+ from .xarray import xarray_data_array_to_numpy , xarray_data_set_to_numpy
1016from ..render_types import RenderType
1117
18+ def _spatial_image_scale_factors (spatial_image , min_length ):
19+ sizes = dict (spatial_image .sizes )
20+ scale_factors = []
21+ dims = spatial_image .dims
22+ previous = { d : 1 for d in { 'x' , 'y' , 'z' }.intersection (dims ) }
23+ while (np .array (list (sizes .values ())) > min_length ).any ():
24+ max_size = np .array (list (sizes .values ())).max ()
25+ to_skip = { d : sizes [d ] <= max_size / 2 for d in previous .keys () }
26+ scale_factor = {}
27+ for dim in previous .keys ():
28+ if to_skip [dim ]:
29+ scale_factor [dim ] = previous [dim ]
30+ continue
31+ scale_factor [dim ] = 2 * previous [dim ]
32+
33+ sizes [dim ] = int (sizes [dim ] / 2 )
34+ previous = scale_factor
35+ scale_factors .append (scale_factor )
36+
37+ return scale_factors
38+
39+ def _make_multiscale_store (multiscale ):
40+ # Todo: for very large images serialize to disk cache
41+ store = zarr .storage .MemoryStore (dimension_separator = '/' )
42+ multiscale .to_zarr (store , compute = True )
43+ return store
1244
1345def _get_viewer_image (image ):
46+ min_length = 64
47+ if isinstance (image , MultiscaleSpatialImage ):
48+ return _make_multiscale_store (image )
49+
50+ # Todo: support for itkwasm.Image
1451 if HAVE_ITK :
1552 import itk
1653 if isinstance (image , itk .Image ):
17- return itk_image_to_wasm_image (image )
54+ dimension = image .GetImageDimension ()
55+ size = np .array (itk .size (image ))
56+ scale_factors = []
57+ dims = ('x' , 'y' , 'z' )
58+ previous = {'x' : 1 , 'y' : 1 , 'z' : 1 }
59+ while (size > min_length ).any ():
60+ to_skip = size <= size .max () / 2
61+ scale_factor = {}
62+ for dim in range (dimension ):
63+ if to_skip [dim ]:
64+ scale_factor [dims [dim ]] = previous [dims [dim ]]
65+ continue
66+ scale_factor [dims [dim ]] = 2 * previous [dims [dim ]]
67+ size [dim ] = int (size [dim ] / 2 )
68+ previous = scale_factor
69+ scale_factors .append (scale_factor )
70+
71+ multiscale = itk_image_to_multiscale (image , scale_factors = scale_factors )
72+ return _make_multiscale_store (multiscale )
73+
1874 if HAVE_VTK :
1975 import vtk
2076 if isinstance (image , vtk .vtkImageData ):
21- return vtk_image_to_ndarray (image )
22- if HAVE_DASK :
23- import dask
24- if isinstance (image , dask .array .core .Array ):
25- return dask_array_to_ndarray (image )
77+ spatial_image = vtk_image_to_spatial_image (image )
78+ scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
79+ multiscale = to_multiscale (spatial_image , scale_factors )
80+ return _make_multiscale_store (multiscale )
81+
82+ if isinstance (image , dask .array .core .Array ):
83+ spatial_image = to_spatial_image (image )
84+ scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
85+ multiscale = to_multiscale (spatial_image , scale_factors )
86+ return _make_multiscale_store (multiscale )
87+
88+ if isinstance (image , zarr .Array ):
89+ spatial_image = to_spatial_image (image )
90+ scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
91+ multiscale = to_multiscale (spatial_image , scale_factors )
92+ return _make_multiscale_store (multiscale )
93+
94+ # NGFF Zarr
95+ if isinstance (image , zarr .Group ) and 'multiscales' in image .attrs :
96+ return _make_multiscale_store (image .store )
97+
2698 if HAVE_TORCH :
2799 import torch
28100 if isinstance (image , torch .Tensor ):
29- return image .numpy ()
30- if HAVE_XARRAY :
31- import xarray
32- if isinstance (image , xarray .DataArray ):
33- return xarray_data_array_to_numpy (image )
34- if isinstance (image , xarray .Dataset ):
35- return xarray_data_set_to_numpy (image )
36- return image
101+ spatial_image = to_spatial_image (image .numpy ())
102+ scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
103+ multiscale = to_multiscale (spatial_image , scale_factors )
104+ return _make_multiscale_store (multiscale )
105+
106+ # Todo: preserve dask Array, if present, check if dims are NGFF -> use dims, coords
107+ # Check if coords are uniform, if not, resample
108+ if isinstance (image , xr .DataArray ):
109+ if is_spatial_image (image ):
110+ scale_factors = _spatial_image_scale_factors (image , min_length )
111+ multiscale = to_multiscale (image , scale_factors )
112+ return _make_multiscale_store (multiscale )
113+
114+ return xarray_data_array_to_numpy (image )
115+ if isinstance (image , xr .Dataset ):
116+ da = image [next (iter (image .variables .keys ()))]
117+ if is_spatial_image (da ):
118+ scale_factors = _spatial_image_scale_factors (da , min_length )
119+ multiscale = to_multiscale (da , scale_factors )
120+ return _make_multiscale_store (multiscale )
121+ return xarray_data_set_to_numpy (image )
122+
123+ if isinstance (image , np .ndarray ):
124+ spatial_image = to_spatial_image (image )
125+ scale_factors = _spatial_image_scale_factors (spatial_image , min_length )
126+ multiscale = to_multiscale (spatial_image , scale_factors )
127+ return _make_multiscale_store (multiscale )
128+ raise RuntimeError ("Could not process the viewer image" )
37129
38130
39131def _get_viewer_point_sets (point_sets ):
40132 if HAVE_VTK :
41133 import vtk
42134 if isinstance (point_sets , vtk .vtkPolyData ):
43135 return vtk_polydata_to_vtkjs (point_sets )
44- if HAVE_DASK :
45- import dask
46- if isinstance (point_sets , dask .array .core .Array ):
47- return dask_array_to_ndarray (point_sets )
136+ if isinstance (point_sets , dask .array .core .Array ):
137+ return np .asarray (point_sets )
48138 if HAVE_TORCH :
49139 import torch
50140 if isinstance (point_sets , torch .Tensor ):
51141 return point_sets .numpy ()
52- if HAVE_XARRAY :
53- import xarray
54- if isinstance (point_sets , xarray .DataArray ):
55- return xarray_data_array_to_numpy (point_sets )
56- if isinstance (point_sets , xarray .Dataset ):
57- return xarray_data_set_to_numpy (point_sets )
142+ if isinstance (point_sets , xr .DataArray ):
143+ return xarray_data_array_to_numpy (point_sets )
144+ if isinstance (point_sets , xr .Dataset ):
145+ return xarray_data_set_to_numpy (point_sets )
58146 return point_sets
59147
60148
@@ -63,6 +151,13 @@ def _detect_render_type(data, input_type) -> RenderType:
63151 return RenderType .IMAGE
64152 elif isinstance (data , itkwasm .PointSet ):
65153 return RenderType .POINT_SET
154+ elif isinstance (data , MultiscaleSpatialImage ):
155+ return RenderType .IMAGE
156+ elif isinstance (data , (zarr .Array , zarr .Group )):
157+ # For now assume zarr.Group is an image
158+ # In the future, once NGFF supports point sets fully
159+ # We may need to do more introspection
160+ return RenderType .IMAGE
66161 elif isinstance (data , np .ndarray ):
67162 if input_type == 'point_sets' :
68163 return RenderType .POINT_SET
@@ -83,29 +178,25 @@ def _detect_render_type(data, input_type) -> RenderType:
83178 return RenderType .IMAGE
84179 elif isinstance (data , vtk .vtkPolyData ):
85180 return RenderType .POINT_SET
86- if HAVE_DASK :
87- import dask
88- if isinstance (data , dask .array .core .Array ):
89- if input_type == 'point_sets' :
90- return RenderType .POINT_SET
91- else :
92- return RenderType .IMAGE
181+ if isinstance (data , dask .array .core .Array ):
182+ if input_type == 'point_sets' :
183+ return RenderType .POINT_SET
184+ else :
185+ return RenderType .IMAGE
93186 if HAVE_TORCH :
94187 import torch
95188 if isinstance (data , torch .Tensor ):
96189 if input_type == 'point_sets' :
97190 return RenderType .POINT_SET
98191 else :
99192 return RenderType .IMAGE
100- if HAVE_XARRAY :
101- import xarray
102- if isinstance (data , xarray .DataArray ):
103- if input_type == 'point_sets' :
104- return RenderType .POINT_SET
105- else :
106- return RenderType .IMAGE
107- if isinstance (data , xarray .Dataset ):
108- if input_type == 'point_sets' :
109- return RenderType .POINT_SET
110- else :
111- return RenderType .IMAGE
193+ if isinstance (data , xr .DataArray ):
194+ if input_type == 'point_sets' :
195+ return RenderType .POINT_SET
196+ else :
197+ return RenderType .IMAGE
198+ if isinstance (data , xr .Dataset ):
199+ if input_type == 'point_sets' :
200+ return RenderType .POINT_SET
201+ else :
202+ return RenderType .IMAGE
0 commit comments