Skip to content

Commit b210ce4

Browse files
committed
fixed scaling of masks
1 parent ba9e120 commit b210ce4

File tree

3 files changed

+96
-76
lines changed

3 files changed

+96
-76
lines changed
Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,43 @@
11
transcripts:
2-
filename: "transcripts.parquet"
3-
x: "global_x"
4-
y: "global_y"
5-
z: "global_z"
6-
id: "transcript_id"
7-
label: "gene"
8-
nuclear_column: "CellComp"
9-
nuclear_value: "Nuclear"
10-
filter_substrings:
11-
- "NegPrb_"
12-
- "SystemControl"
13-
- "Negative"
14-
xy:
15-
- "global_x"
16-
- "global_y"
17-
xyz:
18-
- "global_x"
19-
- "global_y"
20-
- "global_z"
21-
columns:
22-
- "global_x"
23-
- "global_y"
24-
- "global_z"
25-
- "gene"
26-
- "cell_id"
27-
- "CellComp"
28-
- "transcript_id"
2+
filename: "detected_transcripts.parquet"
3+
x: "global_x"
4+
y: "global_y"
5+
z: "global_z"
6+
id: "transcript_id"
7+
label: "gene"
8+
nuclear_column: "overlaps_nucleus"
9+
nuclear_value: 1
10+
filter_substrings:
11+
- "Blank-"
12+
- "BLANK"
13+
xy:
14+
- "global_x"
15+
- "global_y"
16+
xyz:
17+
- "global_x"
18+
- "global_y"
19+
- "global_z"
20+
columns:
21+
- "global_x"
22+
- "global_y"
23+
- "global_z"
24+
- "gene"
25+
- "cell_id"
26+
- "overlaps_nucleus"
27+
- "transcript_id"
2928

3029
boundaries:
31-
filename: "nucleus_boundaries.parquet"
32-
x: "global_x"
33-
y: "global_y"
34-
id: "cell"
35-
label: "cell"
30+
filename: "cellpose_nucleus_micron_space.parquet"
31+
geometry: "Geometry"
32+
id: "EntityID"
33+
label: "EntityID"
34+
x: "centroid_x"
35+
y: "centroid_y"
3636
xy:
37-
- "x_global_px"
38-
- "y_global_px"
37+
- "centroid_x"
38+
- "centroid_y"
3939
columns:
40-
- "x_global_px"
41-
- "y_global_px"
42-
- "cell"
40+
- "Geometry"
41+
- "EntityID"
42+
- "centroid_x"
43+
- "centroid_y"

src/segger/data/parquet/_utils.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pandas as pd
22
import geopandas as gpd
33
import shapely
4+
from shapely.affinity import scale
45
from pyarrow import parquet as pq
56
import numpy as np
67
import scipy as sp
@@ -127,11 +128,22 @@ def read_parquet_region(
127128

128129
columns = list({x, y} | set(extra_columns))
129130

130-
region = pd.read_parquet(
131-
filepath,
132-
filters=filters,
133-
columns=columns,
134-
)
131+
# Check if 'Geometry', 'geometry', 'polygon', or 'Polygon' is in the columns
132+
if any(col in columns for col in ['Geometry', 'geometry', 'polygon', 'Polygon']):
133+
import geopandas as gpd
134+
# If geometry columns are present, read with geopandas
135+
region = gpd.read_parquet(
136+
filepath,
137+
filters=filters,
138+
columns=columns,
139+
)
140+
else:
141+
# Otherwise, read with pandas
142+
region = pd.read_parquet(
143+
filepath,
144+
filters=filters,
145+
columns=columns,
146+
)
135147
return region
136148

137149

@@ -140,7 +152,7 @@ def get_polygons_from_xy(
140152
x: str,
141153
y: str,
142154
label: str,
143-
buffer_ratio: float = 1.0,
155+
scale_factor: float = 1.0,
144156
) -> gpd.GeoSeries:
145157
"""
146158
Convert boundary coordinates from a DataFrame to a GeoSeries of polygons.
@@ -156,8 +168,8 @@ def get_polygons_from_xy(
156168
The name of the column representing the y-coordinate.
157169
label : str
158170
The name of the column representing the cell or nucleus label.
159-
buffer_ratio : float, optional
160-
A ratio to expand or shrink the polygons. A value of 1.0 means no change,
171+
scale_factor : float, optional
172+
A ratio to scale the polygons. A value of 1.0 means no change,
161173
greater than 1.0 expands the polygons, and less than 1.0 shrinks the polygons
162174
(default is 1.0).
163175
@@ -181,19 +193,18 @@ def get_polygons_from_xy(
181193
)
182194
gs = gpd.GeoSeries(polygons, index=np.unique(ids))
183195

184-
if buffer_ratio != 1.0:
185-
# Calculate buffer distance based on polygon area
186-
areas = gs.area
187-
# Use the square root of the area to get a linear distance
188-
buffer_distances = np.sqrt(areas / np.pi) * (buffer_ratio - 1.0)
189-
# Apply buffer to each polygon with its specific distance
196+
# print(gs)
197+
198+
if scale_factor != 1.0:
199+
# Scale polygons around their centroid
190200
gs = gpd.GeoSeries(
191201
[
192-
geom.buffer(dist) if dist != 0 else geom
193-
for geom, dist in zip(gs, buffer_distances)
202+
scale(geom, xfact=scale_factor, yfact=scale_factor, origin='centroid')
203+
for geom in gs
194204
],
195205
index=gs.index,
196206
)
207+
# print(gs)
197208

198209
return gs
199210

src/segger/data/parquet/sample.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
self,
3636
base_dir: os.PathLike,
3737
n_workers: Optional[int] = 1,
38-
buffer_ratio: Optional[float] = 1.0,
38+
scale_factor: Optional[float] = 1.0,
3939
sample_type: str = None,
4040
weights: pd.DataFrame = None,
4141
):
@@ -52,8 +52,8 @@ def __init__(
5252
The sample type of the raw data, e.g., 'xenium' or 'merscope'.
5353
weights : Optional[pd.DataFrame], default None
5454
DataFrame containing weights for transcript embedding.
55-
buffer_ratio : Optional[float], default None
56-
The buffer ratio to be used for expanding the boundary extents
55+
scale_factor : Optional[float], default 1.0
56+
The scale factor to be used for expanding the boundary extents
5757
during spatial queries. If not provided, the default from settings
5858
will be used.
5959
@@ -71,15 +71,15 @@ def __init__(
7171
boundaries_fn = self.settings.boundaries.filename
7272
self._boundaries_filepath = self._base_dir / boundaries_fn
7373
self.n_workers = n_workers
74-
self.settings.boundaries.buffer_ratio = 1
74+
self.settings.boundaries.scale_factor = 1
7575
nuclear_column = getattr(self.settings.transcripts, "nuclear_column", None)
76-
if nuclear_column is None or self.settings.boundaries.buffer_ratio != 1.0:
76+
if nuclear_column is None or self.settings.boundaries.scale_factor != 1.0:
7777
print(
7878
"Boundary-transcript overlap information has not been pre-computed. It will be calculated during tile generation."
7979
)
80-
# Set buffer ratio if provided
81-
if buffer_ratio != 1.0:
82-
self.settings.boundaries.buffer_ratio = buffer_ratio
80+
# Set scale factor if provided
81+
if scale_factor != 1.0:
82+
self.settings.boundaries.scale_factor = scale_factor
8383

8484
# Ensure transcript IDs exist
8585
utils.ensure_transcript_ids(
@@ -1164,13 +1164,12 @@ def get_boundary_props(
11641164
of the code.
11651165
"""
11661166
# Get polygons from coordinates
1167-
polygons = utils.get_polygons_from_xy(
1168-
self.boundaries,
1169-
x=self.settings.boundaries.x,
1170-
y=self.settings.boundaries.y,
1171-
label=self.settings.boundaries.label,
1172-
buffer_ratio=self.settings.boundaries.buffer_ratio,
1173-
)
1167+
# Use getattr to check for the geometry column
1168+
geometry_column = getattr(self.settings.boundaries, 'geometry', None)
1169+
if geometry_column and geometry_column in self.boundaries.columns:
1170+
polygons = self.boundaries[geometry_column]
1171+
else:
1172+
polygons = self.boundaries['geometry'] # Fall back to the default 'geometry' column when no geometry column is configured
11741173
# Geometric properties of polygons
11751174
props = self.get_polygon_props(polygons)
11761175
props = torch.as_tensor(props.values).float()
@@ -1230,13 +1229,22 @@ def to_pyg_dataset(
12301229
pyg_data["tx", "neighbors", "tx"].edge_index = nbrs_edge_idx
12311230

12321231
# Set up Boundary nodes
1233-
polygons = utils.get_polygons_from_xy(
1234-
self.boundaries,
1235-
self.settings.boundaries.x,
1236-
self.settings.boundaries.y,
1237-
self.settings.boundaries.label,
1238-
self.settings.boundaries.buffer_ratio,
1239-
)
1232+
# Check if boundaries have geometries
1233+
geometry_column = getattr(self.settings.boundaries, 'geometry', None)
1234+
if geometry_column and geometry_column in self.boundaries.columns:
1235+
polygons = gpd.GeoSeries(self.boundaries[geometry_column], index=self.boundaries.index)
1236+
else:
1237+
# Fallback: compute polygons
1238+
polygons = utils.get_polygons_from_xy(
1239+
self.boundaries,
1240+
x=self.settings.boundaries.x,
1241+
y=self.settings.boundaries.y,
1242+
label=self.settings.boundaries.label,
1243+
scale_factor=self.settings.boundaries.scale_factor,
1244+
)
1245+
1246+
# Ensure self.boundaries is a GeoDataFrame with correct geometry
1247+
self.boundaries = gpd.GeoDataFrame(self.boundaries.copy(), geometry=polygons)
12401248
centroids = polygons.centroid.get_coordinates()
12411249
pyg_data["bd"].id = polygons.index.to_numpy()
12421250
pyg_data["bd"].pos = torch.tensor(centroids.values, dtype=torch.float32)
@@ -1273,7 +1281,7 @@ def to_pyg_dataset(
12731281
nuclear_column = getattr(self.settings.transcripts, "nuclear_column", None)
12741282
nuclear_value = getattr(self.settings.transcripts, "nuclear_value", None)
12751283

1276-
if nuclear_column is None or self.settings.boundaries.buffer_ratio != 1.0:
1284+
if nuclear_column is None or self.settings.boundaries.scale_factor != 1.0:
12771285
is_nuclear = utils.compute_nuclear_transcripts(
12781286
polygons=polygons,
12791287
transcripts=self.transcripts,

0 commit comments

Comments
 (0)