import multiprocessing as mp
import pickle
import timeit
import warnings
from functools import partial as func_partial

import h5py
import matplotlib.pyplot as plt
import pandas as pd
from alphabase.io import tempmmap
from lmd.lib import SegmentationLoader
from scipy.sparse import coo_array
from tqdm.auto import tqdm

from scportrait.pipeline._base import ProcessingStep
from scportrait.pipeline._utils.helper import flatten
1021class LMDSelection (ProcessingStep ):
1122 """
1223 Select single cells from a segmented sdata file and generate cutting data for the Leica LMD microscope.
1324 This method class relies on the functionality of the pylmd library.
1425 """
1526
16- # define all valid path optimization methods used with the "path_optimization" argument in the configuration
17- VALID_PATH_OPTIMIZERS = ["none" , "hilbert" , "greedy" ]
18-
1927 def __init__ (self , * args , ** kwargs ):
2028 super ().__init__ (* args , ** kwargs )
29+ self ._check_config ()
2130
2231 self .name = None
2332 self .cell_sets = None
2433 self .calibration_marker = None
2534
35+ self .deep_debug = False #flag for deep debugging by developers
36+
37+ def _check_config (self ):
38+ assert "segmentation_channel" in self .config , "segmentation_channel not defined in config"
39+ self .segmentation_channel_to_select = self .config ["segmentation_channel" ]
40+
41+ # check for optional config parameters
42+
43+ #this defines how large the box mask around the center of a cell is for the coordinate extraction
44+ #assumption is that all pixels belonging to each mask are within the box otherwise they will be cut off during cutting contour generation
45+
46+ if "cell_width" in self .config :
47+ self .cell_radius = self .config ["cell_width" ]
48+ else :
49+ self .cell_radius = 100
50+
51+ if "threads" in self .config :
52+ self .threads = self .config ["threads" ]
53+ assert self .threads > 0 , "threads must be greater than 0"
54+ assert isinstance (self .threads , int ), "threads must be an integer"
55+ else :
56+ self .threads = 10
57+
58+ if "batch_size_coordinate_extraction" in self .config :
59+ self .batch_size = self .config ["batch_size_coordinate_extraction" ]
60+ assert self .batch_size > 0 , "batch_size_coordinate_extraction must be greater than 0"
61+ assert isinstance (self .batch_size , int ), "batch_size_coordinate_extraction must be an integer"
62+ else :
63+ self .batch_size = 100
64+
65+ if "orientation_transform" in self .config :
66+ self .orientation_transform = self .config ["orientation_transform" ]
67+ else :
68+ self .orientation_transform = np .array ([[0 , - 1 ], [1 , 0 ]])
69+ self .config ["orientation_transform" ] = self .orientation_transform #ensure its also in config so its passed on to the segmentation loader
70+
71+ if "processes_cell_sets" in self .config :
72+ self .processes_cell_sets = self .config ["processes_cell_sets" ]
73+ assert self .processes_cell_sets > 0 , "processes_cell_sets must be greater than 0"
74+ assert isinstance (self .processes_cell_sets , int ), "processes_cell_sets must be an integer"
75+ else :
76+ self .processes_cell_sets = 1
77+
2678 def _setup_selection (self ):
27- # set orientation transform
28- self .config ["orientation_transform" ] = np .array ([[0 , - 1 ], [1 , 0 ]])
2979
3080 # configure name of extraction
3181 if self .name is None :
@@ -39,6 +89,102 @@ def _setup_selection(self):
3989 savename = name .replace (" " , "_" ) + ".xml"
4090 self .savepath = os .path .join (self .directory , savename )
4191
92+ #check that the segmentation label exists
93+ assert self .segmentation_channel_to_select in self .project .filehandler .get_sdata ()._shared_keys , f"Segmentation channel { self .segmentation_channel_to_select } not found in sdata."
94+
95+ def __get_coords (self ,
96+ cell_ids : list ,
97+ centers :list [tuple [int , int ]],
98+ width :int = 60 ) -> list [tuple [int , np .ndarray ]]:
99+ results = []
100+
101+ _sdata = self .project .filehandler .get_sdata ()
102+ for i , _id in enumerate (cell_ids ):
103+ values = centers [i ]
104+
105+ x_start = np .max ([int (values [0 ]) - width , 0 ])
106+ y_start = np .max ([int (values [1 ]) - width , 0 ])
107+
108+ x_end = x_start + width * 2
109+ y_end = y_start + width * 2
110+
111+ _cropped = _sdata [self .segmentation_channel_to_select ][slice (x_start , x_end ), slice (y_start , y_end )].compute ()
112+
113+ #optional plotting output for deep debugging
114+ if self .deep_debug :
115+ if self .threads == 1 :
116+ plt .figure ()
117+ plt .imshow (_cropped )
118+ plt .show ()
119+ else :
120+ raise ValueError ("Deep debug is not supported with multiple threads." )
121+
122+ sparse = coo_array (_cropped == _id )
123+
124+ if 0 in sparse :
125+ Warning (f"Cell { i } with id { _id } is potentially not fully contained in the bounding mask. Consider increasing the value for the 'cell_width' parameter in your config." )
126+
127+ x = sparse .coords [0 ] + x_start
128+ y = sparse .coords [1 ] + y_start
129+
130+ results .append ((_id , np .array (list (zip (x , y , strict = True )))))
131+
132+ return (results )
133+
134+ def _get_coords_multi (self , width :int , arg : tuple [list [int ], np .ndarray ]) -> list [tuple [int , np .ndarray ]]:
135+ cell_ids , centers = arg
136+ results = self .__get_coords (cell_ids , centers , width )
137+ return (results )
138+
139+ def _get_coords (self ,
140+ cell_ids : list ,
141+ centers :list [tuple [int , int ]],
142+ width :int = 60 ,
143+ batch_size :int = 100 ,
144+ threads :int = 10 ) -> dict :
145+
146+ #create batches
147+ n_batches = int (np .ceil (len (cell_ids )/ batch_size ))
148+ slices = [(i * batch_size , i * batch_size + batch_size ) for i in range (n_batches - 1 )]
149+ slices .append (((n_batches - 1 )* batch_size , len (cell_ids )))
150+
151+ batched_args = [(cell_ids [start :end ], centers [start :end ]) for start , end in slices ]
152+
153+ f = func_partial (self ._get_coords_multi ,
154+ width
155+ )
156+
157+ if threads == 1 : # if only one thread is used, the function is called directly to avoid the overhead of multiprocessing
158+ results = [f (arg ) for arg in batched_args ]
159+ else :
160+ with mp .get_context (self .context ).Pool (processes = threads ) as pool :
161+ results = list (tqdm (
162+ pool .imap (f , batched_args ),
163+ total = len (batched_args ),
164+ desc = "Processing cell batches" ,
165+ )
166+ )
167+ pool .close ()
168+ pool .join ()
169+
170+ results = flatten (results )
171+ return (dict (results ))
172+
173+ def _get_cell_ids (self , cell_sets : list [dict ]) -> list [int ]:
174+ cell_ids = []
175+ for cell_set in cell_sets :
176+ if "classes" in cell_set :
177+ cell_ids .extend (cell_set ["classes" ])
178+ else :
179+ Warning (f"Cell set { cell_set ['name' ]} does not contain any classes." )
180+ return (cell_ids )
181+
182+ def _get_centers (self , cell_ids : list [int ]) -> list [tuple [int , int ]]:
183+ _sdata = self .project .filehandler .get_sdata ()
184+ centers = _sdata ["centers_cells" ].compute ()
185+ centers = centers .loc [cell_ids , :]
186+ return (centers [["y" , "x" ]].values .tolist ()) #needs to be returned as yx to match the coordinate system as saved in spatialdataobjects
187+
42188 def _post_processing_cleanup (self , vars_to_delete : list | None = None ):
43189 if vars_to_delete is not None :
44190 self ._clear_cache (vars_to_delete = vars_to_delete )
@@ -51,7 +197,6 @@ def _post_processing_cleanup(self, vars_to_delete: list | None = None):
51197
52198 def process (
53199 self ,
54- segmentation_name : str ,
55200 cell_sets : list [dict ],
56201 calibration_marker : np .array ,
57202 name : str | None = None ,
@@ -61,9 +206,9 @@ def process(
61206 Under the hood this method relies on the pylmd library and utilizies its `SegmentationLoader` Class.
62207
63208 Args:
64- segmentation_name (str): Name of the segmentation to be used for shape generation in the sdata object.
65209 cell_sets (list of dict): List of dictionaries containing the sets of cells which should be sorted into a single well. Mandatory keys for each dictionary are: name, classes. Optional keys are: well.
66210 calibration_marker (numpy.array): Array of size ‘(3,2)’ containing the calibration marker coordinates in the ‘(row, column)’ format.
211+ name (str, optional): Name of the output file. If not provided, the name will be generated based on the names of the cell sets or if also not specified set to "selected_cells".
67212
68213 Example:
69214
@@ -77,7 +222,6 @@ def process(
77222 # A numpy Array of shape (3, 2) should be passed.
78223 calibration_marker = np.array([marker_0, marker_1, marker_2])
79224
80-
81225 # Sets of cells can be defined by providing a name and a list of classes in a dictionary.
82226 cells_to_select = [{"name": "dataset1", "classes": [1, 2, 3]}]
83227
@@ -122,7 +266,7 @@ def process(
122266 convolution_smoothing: 25
123267
124268 # fold reduction of datapoints for compression
125- poly_compression_factor: 30
269+ rdp: 0.7
126270
127271 # Optimization of the cutting path inbetween shapes
128272 # optimized paths improve the cutting time and the microscopes focus
@@ -160,32 +304,29 @@ def process(
160304
161305 self ._setup_selection ()
162306
163- ## TO Do
164- # check if classes and seglookup table already exist as pickle file
165- # if not create them
166- # else load them and proceed with selection
167-
168- # load segmentation from hdf5
169- self .path_seg_mask = self .filehandler ._load_seg_to_memmap (
170- [segmentation_name ], tmp_dir_abs_path = self ._tmp_dir_path
171- )
307+ print ("Here" , flush = True )
172308
173- segmentation = tempmmap .mmap_array_from_path (self .path_seg_mask )
309+ start_time = timeit .default_timer ()
310+ cell_ids = self ._get_cell_ids (cell_sets )
311+ centers = self ._get_centers (cell_ids )
312+ coord_index = self ._get_coords (cell_ids = cell_ids ,
313+ centers = centers ,
314+ width = self .cell_radius ,
315+ batch_size = self .batch_size ,
316+ threads = self .threads )
317+ self .log (f"Coordinate lookup index calculation took { timeit .default_timer () - start_time } seconds." )
174318
175- # create segmentation loader
176319 sl = SegmentationLoader (
177320 config = self .config ,
178321 verbose = self .debug ,
179322 processes = self .config ["processes_cell_sets" ],
180323 )
181324
182- if len ( segmentation . shape ) == 3 :
183- segmentation = np . squeeze ( segmentation )
184- else :
185- raise ValueError ( f"Segmentation shape is not correct. Expected 2D array, got { segmentation . shape } " )
325+ shape_collection = sl ( None ,
326+ self . cell_sets ,
327+ self . calibration_marker ,
328+ coords_lookup = coord_index )
186329
187- # get shape collections
188- shape_collection = sl (segmentation , self .cell_sets , self .calibration_marker )
189330
190331 if self .debug :
191332 shape_collection .plot (calibration = True )
@@ -196,4 +337,4 @@ def process(
196337 self .log (f"Saved output at { self .savepath } " )
197338
198339 # perform post processing cleanup
199- self ._post_processing_cleanup (vars_to_delete = [shape_collection , sl , segmentation ])
340+ self ._post_processing_cleanup (vars_to_delete = [shape_collection , sl , coord_index ])
0 commit comments