@@ -55,6 +55,8 @@ def _get_shape_element(sdata, element_name) -> tuple[int, int]:
5555 _ , x , y = shape
5656 elif len (shape ) == 2 :
5757 x , y = shape
58+ else :
59+ raise ValueError (f"Unsupported shape for element '{ element_name } ': expected 2D or 3D array, got { shape } ." )
5860 return x , y
5961
6062
@@ -115,6 +117,7 @@ def plot_image(
115117 if ax is not None :
116118 if dpi is not None :
117119 warnings .warn ("DPI is ignored when an axis is provided." , stacklevel = 2 )
120+ fig = ax .figure
118121 else :
119122 # get size of spatialdata object to plot (required for calculating figure size if DPI is set)
120123 x , y = _get_shape_element (sdata , image_name )
@@ -189,6 +192,7 @@ def plot_segmentation_mask(
189192 if ax is not None :
190193 if dpi is not None :
191194 warnings .warn ("DPI is ignored when an axis is provided." , stacklevel = 2 )
195+ fig = ax .figure
192196 else :
193197 # get size of spatialdata object to plot (required for calculating figure size if DPI is set)
194198 x , y = _get_shape_element (sdata , masks [0 ])
@@ -217,20 +221,27 @@ def plot_segmentation_mask(
217221 if selected_channels is not None :
218222 if not isinstance (selected_channels , Iterable ):
219223 selected_channels = [selected_channels ]
224+ if any (i < 0 or i >= len (channel_names ) for i in selected_channels ):
225+ raise ValueError (
226+ f"selected_channels contains out-of-range indices for background image '{ background_image } '."
227+ )
228+ if len (selected_channels ) > len (PALETTE ):
229+ raise ValueError ("selected_channels has more entries than the available palette length." )
220230 channel_names = [channel_names [i ] for i in selected_channels ]
221231 c = len (channel_names )
222- palette = [ PALETTE [x ] for x in selected_channels ]
232+ palette = PALETTE [: c ]
223233 else :
224234 if c > max_channels_to_plot :
225- c = 4
235+ c = min ( c , max_channels_to_plot )
226236 palette = PALETTE [:c ]
227237 channel_names = list (channel_names [:c ])
228238
229239 sdata .pl .render_images (background_image , channel = channel_names , palette = palette ).pl .show (ax = ax , colorbar = False )
230240
231241 # plot selected segmentation masks
232242 for mask in masks :
233- assert mask in sdata , f"Mask { mask } not found in sdata object."
243+ if mask not in sdata :
244+ raise KeyError (f"Mask { mask } not found in sdata object." )
234245 if f"{ mask } _vectorized" not in sdata :
235246 sdata [f"{ mask } _vectorized" ] = spatialdata .to_polygons (sdata [mask ])
236247 sdata .pl .render_shapes (
@@ -298,6 +309,7 @@ def plot_shapes(
298309 if ax is not None :
299310 if dpi is not None :
300311 warnings .warn ("DPI is ignored when an axis is provided." , stacklevel = 2 )
312+ fig = ax .figure
301313 else :
302314 # get size of spatialdata object to plot (required for calculating figure size if DPI is set)
303315 x , y = _get_shape_element (sdata , shapes_layer )
@@ -306,7 +318,8 @@ def plot_shapes(
306318 fig , ax = _create_figure_dpi (x = x , y = y , dpi = dpi )
307319
308320 # plot selected shapes layer
309- assert shapes_layer in sdata , f"Shapes layer { shapes_layer } not found in sdata object."
321+ if shapes_layer not in sdata :
322+ raise KeyError (f"Shapes layer { shapes_layer } not found in sdata object." )
310323
311324 sdata .pl .render_shapes (
312325 f"{ shapes_layer } " ,
@@ -374,6 +387,7 @@ def plot_labels(
374387 if ax is not None :
375388 if dpi is not None :
376389 warnings .warn ("DPI is ignored when an axis is provided." , stacklevel = 2 )
390+ fig = ax .figure
377391 else :
378392 # get size of spatialdata object to plot (required for calculating figure size if DPI is set)
379393 x , y = _get_shape_element (sdata , label_layer )
@@ -410,18 +424,25 @@ def plot_labels(
410424 annotating_table = spatialdata .models .TableModel .parse (annotating_table )
411425 break
412426 if found_annotation is not None :
427+ had_annotation = "_annotation" in sdata
428+ prev_annotation = sdata ["_annotation" ] if had_annotation else None
413429 sdata ["_annotation" ] = annotating_table
414- sdata .pl .render_shapes (
415- f"{ label_layer } _vectorized" ,
416- color = color ,
417- fill_alpha = fill_alpha ,
418- outline_alpha = 0 ,
419- cmap = cmap ,
420- palette = palette ,
421- groups = groups ,
422- norm = norm ,
423- ).pl .show (ax = ax )
424- del sdata ["_annotation" ] # delete element again after plotting
430+ try :
431+ sdata .pl .render_shapes (
432+ f"{ label_layer } _vectorized" ,
433+ color = color ,
434+ fill_alpha = fill_alpha ,
435+ outline_alpha = 0 ,
436+ cmap = cmap ,
437+ palette = palette ,
438+ groups = groups ,
439+ norm = norm ,
440+ ).pl .show (ax = ax )
441+ finally :
442+ if had_annotation :
443+ sdata ["_annotation" ] = prev_annotation
444+ else :
445+ del sdata ["_annotation" ] # delete element again after plotting
425446 else :
426447 try :
427448 sdata .pl .render_labels (
0 commit comments