Skip to content

Commit 98994e2

Browse files
committed
feat: add simplification level and precision params to merger
1 parent 3db3403 commit 98994e2

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

cellseg_models_pytorch/inference/wsi_segmenter.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,14 @@ def segment(self, save_dir: str, maptype: str = "amap") -> None:
9999

100100
self._has_processed = True
101101

102-
def merge_instances(self, src: str, dst: str, clear_in_dir: bool = False) -> None:
102+
def merge_instances(
103+
self,
104+
src: str,
105+
dst: str,
106+
clear_in_dir: bool = False,
107+
simplify_level: int = 0.3,
108+
precision: int = None,
109+
) -> None:
103110
"""Merge the instances at the image boundaries.
104111
105112
Parameters:
@@ -110,14 +117,19 @@ def merge_instances(self, src: str, dst: str, clear_in_dir: bool = False) -> Non
110117
'.parquet', '.geojson', and '.feather'.
111118
clear_in_dir (bool, default=False):
112119
Whether to clear the source directory after merging.
120+
simplify_level (int, default=1):
121+
The level of simplification to apply to the merged instances.
122+
precision (int, optional):
123+
The precision level to apply to the merged instances. If None, no rounding
124+
is applied.
113125
"""
114126
if not self._has_processed:
115127
raise ValueError("You must segment the instances first.")
116128

117129
in_dir = Path(src)
118130
gdf = gpd.read_parquet(in_dir)
119131
merger = InstMerger(gdf, self.coordinates)
120-
merger.merge(dst)
132+
merger.merge(dst, simplify_level=simplify_level, precision=precision)
121133

122134
if clear_in_dir:
123135
for f in in_dir.glob("*"):

cellseg_models_pytorch/wsi/inst_merger.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from functools import partial
12
from pathlib import Path
23
from typing import List, Tuple, Union
34

45
import geopandas as gpd
56
import numpy as np
67
import pandas as pd
78
from libpysal.cg import alpha_shape_auto
9+
from shapely import wkt
810
from shapely.geometry import LineString, Polygon, box
11+
from shapely.geometry.base import BaseGeometry
12+
from shapely.wkt import dumps
913
from tqdm import tqdm
1014

1115
__all__ = ["InstMerger"]
@@ -29,7 +33,7 @@ def __init__(
2933
self.gdf = gdf
3034

3135
def merge(
32-
self, dst: str = None, simplify_level: int = 1
36+
self, dst: str = None, simplify_level: int = 1, precision: int = None
3337
) -> Union[gpd.GeoDataFrame, None]:
3438
"""Merge the instances at the image boundaries.
3539
@@ -39,6 +43,9 @@ def merge(
3943
If None, the merged GeoDataFrame is returned.
4044
simplify_level (int, default=1):
4145
The level of simplification to apply to the merged instances.
46+
precision (int, optional):
47+
The precision level to apply to the merged instances. If None, no rounding
48+
is applied.
4249
4350
Returns:
4451
Union[gpd.GeoDataFrame, None]:
@@ -67,6 +74,11 @@ def merge(
6774
drop=True
6875
)
6976
merged.geometry = merged.geometry.simplify(simplify_level)
77+
merged = _set_uid(_set_crs(merged), drop=True)
78+
79+
if precision is not None:
80+
set_prec = partial(_set_geom_precision, precision=precision)
81+
merged["geometry"] = merged["geometry"].apply(set_prec)
7082

7183
if dst is not None:
7284
if suff == ".parquet":
@@ -238,3 +250,22 @@ def _get_classes(
238250
class_names.append(objs.loc[objs.area.idxmax()]["class_name"])
239251

240252
return class_names
253+
254+
255+
def _set_uid(
256+
gdf: gpd.GeoDataFrame, start_ix: int = 0, id_col: str = "uid", drop: bool = False
257+
) -> gpd.GeoDataFrame:
258+
# if id_col not in gdf.columns:
259+
gdf = gdf.assign(**{id_col: range(start_ix, len(gdf) + start_ix)})
260+
gdf = gdf.set_index(id_col, drop=drop)
261+
262+
return gdf
263+
264+
265+
def _set_crs(gdf: gpd.GeoDataFrame, crs: int = 4328) -> bool:
266+
return gdf.set_crs(epsg=crs, allow_override=True)
267+
268+
269+
def _set_geom_precision(geom: BaseGeometry, precision: int = 6) -> BaseGeometry:
270+
wkt_str = dumps(geom, rounding_precision=precision)
271+
return wkt.loads(wkt_str)

0 commit comments

Comments
 (0)