Skip to content

Commit 907bf19

Browse files
committed
[FIX} _get_shape should always return 3D tuple, if image is only 2D the first is set to np.NaN
1 parent 570e850 commit 907bf19

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/scportrait/spdata/write/_helper.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,35 @@
11
from typing import Any, Literal, TypeAlias
22

3+
import numpy as np
34
from spatialdata import SpatialData
45
from spatialdata.models import get_model
56
from xarray import DataArray, DataTree
67

78
ObjectType: TypeAlias = Literal["images", "labels", "points", "tables", "shapes"]
89

910

10-
def _get_shape(elem: DataArray | DataTree) -> tuple[int, int] | tuple[int, int, int]:
11+
def _get_shape(elem: DataArray | DataTree) -> tuple[int, int, int]:
1112
"""Get the shape of the element.
1213
1314
Args:
1415
elem: Element to get the shape of
1516
1617
Returns:
17-
Tuple of the shape of the element
18+
Tuple of the shape of the element with c, x, y dimensions. If the element is 2D, the first dimension is set to np.nan.
1819
"""
1920
if isinstance(elem, DataArray):
20-
return elem.shape
21+
shape = elem.shape
2122
elif isinstance(elem, DataTree):
22-
return elem.scale0.image.shape
23+
shape = elem.scale0.image.shape
2324
else:
2425
raise ValueError(f"Element type {type(elem)} not supported.")
2526

27+
if len(shape) == 2:
28+
shape = (np.nan, shape[0], shape[1])
29+
elif len(shape) == 3:
30+
shape = (shape[0], shape[1], shape[2])
31+
return shape
32+
2633

2734
def _make_key_lookup(sdata: SpatialData) -> dict:
2835
"""Make a lookup dictionary for the keys in the SpatialData object.

0 commit comments

Comments
 (0)