|
1 | 1 | from typing import Any, Literal, TypeAlias |
2 | 2 |
|
| 3 | +import numpy as np |
3 | 4 | from spatialdata import SpatialData |
4 | 5 | from spatialdata.models import get_model |
5 | 6 | from xarray import DataArray, DataTree |
6 | 7 |
|
7 | 8 | ObjectType: TypeAlias = Literal["images", "labels", "points", "tables", "shapes"] |
8 | 9 |
|
9 | 10 |
|
10 | | -def _get_shape(elem: DataArray | DataTree) -> tuple[int, int] | tuple[int, int, int]: |
| 11 | +def _get_shape(elem: DataArray | DataTree) -> tuple[int, int, int]: |
11 | 12 | """Get the shape of the element. |
12 | 13 |
|
13 | 14 | Args: |
14 | 15 | elem: Element to get the shape of |
15 | 16 |
|
16 | 17 | Returns: |
17 | | - Tuple of the shape of the element |
| 18 | + Tuple of the shape of the element with c, x, y dimensions. If the element is 2D, the first dimension is set to np.nan. |
18 | 19 | """ |
19 | 20 | if isinstance(elem, DataArray): |
20 | | - return elem.shape |
| 21 | + shape = elem.shape |
21 | 22 | elif isinstance(elem, DataTree): |
22 | | - return elem.scale0.image.shape |
| 23 | + shape = elem.scale0.image.shape |
23 | 24 | else: |
24 | 25 | raise ValueError(f"Element type {type(elem)} not supported.") |
25 | 26 |
|
| 27 | + if len(shape) == 2: |
| 28 | + shape = (np.nan, shape[0], shape[1]) |
| 29 | + elif len(shape) == 3: |
| 30 | + shape = (shape[0], shape[1], shape[2]) |
| 31 | + return shape |
| 32 | + |
26 | 33 |
|
27 | 34 | def _make_key_lookup(sdata: SpatialData) -> dict: |
28 | 35 | """Make a lookup dictionary for the keys in the SpatialData object. |
|
0 commit comments