Skip to content

Commit 79729ef

Browse files
committed
ensure we can override and add test
1 parent 8aa110e commit 79729ef

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

src/ome_arrow/core.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
tcz: Tuple[int, int, int] = (0, 0, 0),
6060
column_name: str = "ome_arrow",
6161
row_index: int = 0,
62+
image_type: str | None = None,
6263
) -> None:
6364
"""
6465
Construct an OMEArrow from:
@@ -71,6 +72,7 @@ def __init__(
7172
with from_numpy defaults)
7273
- a dict already matching the OME-Arrow schema
7374
- a pa.StructScalar already typed to OME_ARROW_STRUCT
75+
- optionally override/set image_type metadata on ingest
7476
"""
7577

7678
# set the tcz for viewing
@@ -83,6 +85,7 @@ def __init__(
8385
default_dim_for_unspecified="C",
8486
map_series_to="T",
8587
clamp_to_uint16=True,
88+
image_type=image_type,
8689
)
8790

8891
# --- 2) String path/URL: OME-Zarr / OME-Parquet / OME-TIFF ---------------
@@ -98,6 +101,14 @@ def __init__(
98101
or (path.exists() and path.is_dir() and path.suffix.lower() == ".zarr")
99102
):
100103
self.data = from_ome_zarr(s)
104+
if image_type is not None:
105+
self.data = pa.scalar(
106+
{
107+
**self.data.as_py(),
108+
"image_type": str(image_type),
109+
},
110+
type=OME_ARROW_STRUCT,
111+
)
101112

102113
# OME-Parquet
103114
elif s.lower().endswith((".parquet", ".pq")) or path.suffix.lower() in {
@@ -107,18 +118,42 @@ def __init__(
107118
self.data = from_ome_parquet(
108119
s, column_name=column_name, row_index=row_index
109120
)
121+
if image_type is not None:
122+
self.data = pa.scalar(
123+
{
124+
**self.data.as_py(),
125+
"image_type": str(image_type),
126+
},
127+
type=OME_ARROW_STRUCT,
128+
)
110129

111130
# Vortex
112131
elif s.lower().endswith(".vortex") or path.suffix.lower() == ".vortex":
113132
self.data = from_ome_vortex(
114133
s, column_name=column_name, row_index=row_index
115134
)
135+
if image_type is not None:
136+
self.data = pa.scalar(
137+
{
138+
**self.data.as_py(),
139+
"image_type": str(image_type),
140+
},
141+
type=OME_ARROW_STRUCT,
142+
)
116143

117144
# TIFF
118145
elif path.suffix.lower() in {".tif", ".tiff"} or s.lower().endswith(
119146
(".tif", ".tiff")
120147
):
121148
self.data = from_tiff(s)
149+
if image_type is not None:
150+
self.data = pa.scalar(
151+
{
152+
**self.data.as_py(),
153+
"image_type": str(image_type),
154+
},
155+
type=OME_ARROW_STRUCT,
156+
)
122157

123158
elif path.exists() and path.is_dir():
124159
raise ValueError(
@@ -140,15 +175,33 @@ def __init__(
140175
# Uses from_numpy defaults: dim_order="TCZYX", clamp_to_uint16=True, etc.
141176
# If the array is YX/ZYX/CYX/etc.,
142177
# from_numpy will expand/reorder accordingly.
143-
self.data = from_numpy(data)
178+
self.data = from_numpy(data, image_type=image_type)
144179

145180
# --- 4) Already-typed Arrow scalar ---------------------------------------
146181
elif isinstance(data, pa.StructScalar):
147-
self.data = data
182+
if image_type is None:
183+
self.data = data
184+
else:
185+
self.data = pa.scalar(
186+
{
187+
**data.as_py(),
188+
"image_type": str(image_type),
189+
},
190+
type=OME_ARROW_STRUCT,
191+
)
148192

149193
# --- 5) Plain dict matching the schema -----------------------------------
150194
elif isinstance(data, dict):
151-
self.data = pa.scalar(data, type=OME_ARROW_STRUCT)
195+
if image_type is None:
196+
self.data = pa.scalar(data, type=OME_ARROW_STRUCT)
197+
else:
198+
self.data = pa.scalar(
199+
{
200+
**data,
201+
"image_type": str(image_type),
202+
},
203+
type=OME_ARROW_STRUCT,
204+
)
152205

153206
# --- otherwise ------------------------------------------------------------
154207
else:

tests/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,18 @@ def test_vortex_roundtrip(tmp_path: pathlib.Path) -> None:
369369
assert reloaded.info() == oa.info()
370370

371371

372+
def test_parquet_roundtrip_preserves_image_type(tmp_path: pathlib.Path) -> None:
373+
"""Ensure image_type round-trips through OME-Parquet."""
374+
arr = np.arange(16, dtype=np.uint16).reshape(1, 1, 1, 4, 4)
375+
oa = OMEArrow(arr, image_type="label")
376+
out = tmp_path / "example.ome.parquet"
377+
378+
oa.export(how="omeparquet", out=str(out))
379+
reloaded = OMEArrow(str(out))
380+
381+
assert reloaded.data.as_py()["image_type"] == "label"
382+
383+
372384
def test_vortex_custom_column_name(tmp_path: pathlib.Path) -> None:
373385
"""Ensure custom Vortex column names are preserved on round-trip."""
374386
pytest.importorskip(

0 commit comments

Comments
 (0)