Skip to content

Commit e3fbd30

Browse files
committed
adapt selection workflow to work with the new py-lmd version + improve selection performance, even for large datasets
see MannLabs/py-lmd#11 for more information
1 parent 405ea46 commit e3fbd30

File tree

1 file changed

+168
-27
lines changed

1 file changed

+168
-27
lines changed

src/scportrait/pipeline/selection.py

Lines changed: 168 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,78 @@
44
from alphabase.io import tempmmap
55
from lmd.lib import SegmentationLoader
66

7+
import h5py
8+
import timeit
9+
import pandas as pd
10+
import pickle
11+
from scipy.sparse import coo_array
12+
from tqdm.auto import tqdm
13+
from functools import partial as func_partial
14+
import multiprocessing as mp
15+
716
from scportrait.pipeline._base import ProcessingStep
17+
from scportrait.pipeline._utils.helper import flatten
818

19+
import matplotlib.pyplot as plt
920

1021
class LMDSelection(ProcessingStep):
1122
"""
1223
Select single cells from a segmented sdata file and generate cutting data for the Leica LMD microscope.
1324
This method class relies on the functionality of the pylmd library.
1425
"""
1526

16-
# define all valid path optimization methods used with the "path_optimization" argument in the configuration
17-
VALID_PATH_OPTIMIZERS = ["none", "hilbert", "greedy"]
18-
1927
def __init__(self, *args, **kwargs):
    """Set up the selection step and validate its configuration.

    All positional and keyword arguments are forwarded unchanged to the
    ``ProcessingStep`` base class.
    """
    super().__init__(*args, **kwargs)
    self._check_config()

    # selection state; populated later when process() is called
    self.name = None
    self.cell_sets = None
    self.calibration_marker = None

    # flag for deep debugging by developers (plots each cropped mask,
    # only supported with a single thread)
    self.deep_debug = False
36+
37+
def _check_config(self):
    """Validate the step configuration and populate instance attributes.

    Mandatory config keys:
        segmentation_channel: name of the segmentation layer in the sdata object.

    Optional config keys (with defaults): cell_width (100), threads (10),
    batch_size_coordinate_extraction (100), orientation_transform
    (90-degree rotation matrix), processes_cell_sets (1).

    Raises:
        AssertionError: if a mandatory key is missing or a value is invalid.
    """
    assert "segmentation_channel" in self.config, "segmentation_channel not defined in config"
    self.segmentation_channel_to_select = self.config["segmentation_channel"]

    # check for optional config parameters

    # this defines how large the box mask around the center of a cell is for the
    # coordinate extraction; the assumption is that all pixels belonging to each
    # mask are within the box, otherwise they will be cut off during cutting
    # contour generation
    self.cell_radius = self.config.get("cell_width", 100)

    self.threads = self.config.get("threads", 10)
    # check the type before the comparison: a non-numeric config value would
    # otherwise raise an uninformative TypeError on the "> 0" comparison
    # instead of the intended assertion message
    assert isinstance(self.threads, int), "threads must be an integer"
    assert self.threads > 0, "threads must be greater than 0"

    self.batch_size = self.config.get("batch_size_coordinate_extraction", 100)
    assert isinstance(self.batch_size, int), "batch_size_coordinate_extraction must be an integer"
    assert self.batch_size > 0, "batch_size_coordinate_extraction must be greater than 0"

    if "orientation_transform" in self.config:
        self.orientation_transform = self.config["orientation_transform"]
    else:
        self.orientation_transform = np.array([[0, -1], [1, 0]])
        # ensure it is also in config so it is passed on to the segmentation loader
        self.config["orientation_transform"] = self.orientation_transform

    self.processes_cell_sets = self.config.get("processes_cell_sets", 1)
    assert isinstance(self.processes_cell_sets, int), "processes_cell_sets must be an integer"
    assert self.processes_cell_sets > 0, "processes_cell_sets must be greater than 0"
77+
2678
def _setup_selection(self):
27-
# set orientation transform
28-
self.config["orientation_transform"] = np.array([[0, -1], [1, 0]])
2979

3080
# configure name of extraction
3181
if self.name is None:
@@ -39,6 +89,102 @@ def _setup_selection(self):
3989
savename = name.replace(" ", "_") + ".xml"
4090
self.savepath = os.path.join(self.directory, savename)
4191

92+
#check that the segmentation label exists
93+
assert self.segmentation_channel_to_select in self.project.filehandler.get_sdata()._shared_keys, f"Segmentation channel {self.segmentation_channel_to_select} not found in sdata."
94+
95+
def __get_coords(self,
                 cell_ids: list,
                 centers: list[tuple[int, int]],
                 width: int = 60) -> list[tuple[int, np.ndarray]]:
    """Extract the pixel coordinates of each cell mask from the segmentation.

    For every cell id a square crop of side ``2 * width`` around the cell
    center is read from the segmentation layer and the coordinates of all
    pixels matching the cell id are collected.

    Args:
        cell_ids: ids of the cells to look up.
        centers: per-cell (y, x) center coordinates, parallel to ``cell_ids``.
        width: half the side length of the bounding box around each center.

    Returns:
        List of (cell_id, coordinate_array) tuples; each coordinate array has
        shape (n_pixels, 2) in global image coordinates.

    Raises:
        ValueError: if deep_debug plotting is requested with multiple threads.
    """
    import warnings

    results = []

    _sdata = self.project.filehandler.get_sdata()
    for i, _id in enumerate(cell_ids):
        values = centers[i]

        # clamp the crop origin to the image border
        x_start = max(int(values[0]) - width, 0)
        y_start = max(int(values[1]) - width, 0)

        x_end = x_start + width * 2
        y_end = y_start + width * 2

        _cropped = _sdata[self.segmentation_channel_to_select][slice(x_start, x_end), slice(y_start, y_end)].compute()

        # optional plotting output for deep debugging
        if self.deep_debug:
            if self.threads == 1:
                plt.figure()
                plt.imshow(_cropped)
                plt.show()
            else:
                raise ValueError("Deep debug is not supported with multiple threads.")

        sparse = coo_array(_cropped == _id)

        # NOTE(review): this membership test on a sparse array looks intended
        # to detect masks touching the crop border — confirm it does what is
        # meant (e.g. vs. checking 0 in sparse.coords[0]/[1]).
        if 0 in sparse:
            # the original code only constructed a Warning object, which is
            # silently discarded; warnings.warn actually emits the warning
            warnings.warn(f"Cell {i} with id {_id} is potentially not fully contained in the bounding mask. Consider increasing the value for the 'cell_width' parameter in your config.")

        # shift crop-local coordinates back into global image coordinates
        x = sparse.coords[0] + x_start
        y = sparse.coords[1] + y_start

        results.append((_id, np.array(list(zip(x, y, strict=True)))))

    return results
133+
134+
def _get_coords_multi(self, width: int, arg: tuple[list[int], np.ndarray]) -> list[tuple[int, np.ndarray]]:
    """Multiprocessing adapter: unpack one batched (cell_ids, centers) tuple
    and forward it to the coordinate extraction."""
    batch_ids, batch_centers = arg
    batch_results = self.__get_coords(batch_ids, batch_centers, width)
    return batch_results
138+
139+
def _get_coords(self,
                cell_ids: list,
                centers: list[tuple[int, int]],
                width: int = 60,
                batch_size: int = 100,
                threads: int = 10) -> dict:
    """Build a cell-id -> pixel-coordinate lookup, batched and optionally parallel.

    Args:
        cell_ids: ids of the cells to look up.
        centers: per-cell (y, x) center coordinates, parallel to ``cell_ids``.
        width: half the side length of the crop around each cell center.
        batch_size: number of cells processed per worker invocation.
        threads: number of worker processes; 1 runs in-process.

    Returns:
        Dict mapping each cell id to its coordinate array.
    """
    # build batches by stepped slicing; this also handles empty input
    # correctly (the previous (n_batches - 1) * batch_size arithmetic
    # produced a negative slice start when cell_ids was empty)
    batched_args = [
        (cell_ids[start:start + batch_size], centers[start:start + batch_size])
        for start in range(0, len(cell_ids), batch_size)
    ]

    f = func_partial(self._get_coords_multi, width)

    if threads == 1:
        # if only one thread is used, the function is called directly to
        # avoid the overhead of multiprocessing
        results = [f(arg) for arg in batched_args]
    else:
        with mp.get_context(self.context).Pool(processes=threads) as pool:
            results = list(
                tqdm(
                    pool.imap(f, batched_args),
                    total=len(batched_args),
                    desc="Processing cell batches",
                )
            )
            pool.close()
            pool.join()

    return dict(flatten(results))
172+
173+
def _get_cell_ids(self, cell_sets: list[dict]) -> list[int]:
174+
cell_ids = []
175+
for cell_set in cell_sets:
176+
if "classes" in cell_set:
177+
cell_ids.extend(cell_set["classes"])
178+
else:
179+
Warning(f"Cell set {cell_set['name']} does not contain any classes.")
180+
return(cell_ids)
181+
182+
def _get_centers(self, cell_ids: list[int]) -> list[tuple[int, int]]:
    """Look up the center coordinates for the given cell ids.

    Returns:
        One (y, x) pair per cell id — yx order is required to match the
        coordinate system as saved in spatialdata objects.
    """
    sdata = self.project.filehandler.get_sdata()
    all_centers = sdata["centers_cells"].compute()
    selected = all_centers.loc[cell_ids, :]
    return selected[["y", "x"]].values.tolist()
187+
42188
def _post_processing_cleanup(self, vars_to_delete: list | None = None):
43189
if vars_to_delete is not None:
44190
self._clear_cache(vars_to_delete=vars_to_delete)
@@ -51,7 +197,6 @@ def _post_processing_cleanup(self, vars_to_delete: list | None = None):
51197

52198
def process(
53199
self,
54-
segmentation_name: str,
55200
cell_sets: list[dict],
56201
calibration_marker: np.array,
57202
name: str | None = None,
@@ -61,9 +206,9 @@ def process(
61206
Under the hood this method relies on the pylmd library and utilizies its `SegmentationLoader` Class.
62207
63208
Args:
64-
segmentation_name (str): Name of the segmentation to be used for shape generation in the sdata object.
65209
cell_sets (list of dict): List of dictionaries containing the sets of cells which should be sorted into a single well. Mandatory keys for each dictionary are: name, classes. Optional keys are: well.
66210
calibration_marker (numpy.array): Array of size ‘(3,2)’ containing the calibration marker coordinates in the ‘(row, column)’ format.
211+
name (str, optional): Name of the output file. If not provided, the name will be generated based on the names of the cell sets or if also not specified set to "selected_cells".
67212
68213
Example:
69214
@@ -77,7 +222,6 @@ def process(
77222
# A numpy Array of shape (3, 2) should be passed.
78223
calibration_marker = np.array([marker_0, marker_1, marker_2])
79224
80-
81225
# Sets of cells can be defined by providing a name and a list of classes in a dictionary.
82226
cells_to_select = [{"name": "dataset1", "classes": [1, 2, 3]}]
83227
@@ -122,7 +266,7 @@ def process(
122266
convolution_smoothing: 25
123267
124268
# fold reduction of datapoints for compression
125-
poly_compression_factor: 30
269+
rdp: 0.7
126270
127271
# Optimization of the cutting path inbetween shapes
128272
# optimized paths improve the cutting time and the microscopes focus
@@ -160,32 +304,29 @@ def process(
160304

161305
self._setup_selection()
162306

163-
## TO Do
164-
# check if classes and seglookup table already exist as pickle file
165-
# if not create them
166-
# else load them and proceed with selection
167-
168-
# load segmentation from hdf5
169-
self.path_seg_mask = self.filehandler._load_seg_to_memmap(
170-
[segmentation_name], tmp_dir_abs_path=self._tmp_dir_path
171-
)
307+
print("Here", flush=True)
172308

173-
segmentation = tempmmap.mmap_array_from_path(self.path_seg_mask)
309+
start_time = timeit.default_timer()
310+
cell_ids = self._get_cell_ids(cell_sets)
311+
centers = self._get_centers(cell_ids)
312+
coord_index = self._get_coords(cell_ids = cell_ids,
313+
centers = centers,
314+
width = self.cell_radius,
315+
batch_size = self.batch_size,
316+
threads = self.threads)
317+
self.log(f"Coordinate lookup index calculation took {timeit.default_timer() - start_time} seconds.")
174318

175-
# create segmentation loader
176319
sl = SegmentationLoader(
177320
config=self.config,
178321
verbose=self.debug,
179322
processes=self.config["processes_cell_sets"],
180323
)
181324

182-
if len(segmentation.shape) == 3:
183-
segmentation = np.squeeze(segmentation)
184-
else:
185-
raise ValueError(f"Segmentation shape is not correct. Expected 2D array, got {segmentation.shape}")
325+
shape_collection = sl(None,
326+
self.cell_sets,
327+
self.calibration_marker,
328+
coords_lookup=coord_index)
186329

187-
# get shape collections
188-
shape_collection = sl(segmentation, self.cell_sets, self.calibration_marker)
189330

190331
if self.debug:
191332
shape_collection.plot(calibration=True)
@@ -196,4 +337,4 @@ def process(
196337
self.log(f"Saved output at {self.savepath}")
197338

198339
# perform post processing cleanup
199-
self._post_processing_cleanup(vars_to_delete=[shape_collection, sl, segmentation])
340+
self._post_processing_cleanup(vars_to_delete=[shape_collection, sl, coord_index])

0 commit comments

Comments
 (0)