@@ -57,7 +57,13 @@ def __init__(self, *args, **kwargs):
5757 self .overwrite_run_path = self .overwrite
5858
5959 def _get_compression_type (self ):
60- self .compression_type = "lzf" if self .compression else None
60+ if (self .compression is True ) or (self .compression == "lzf" ):
61+ self .compression_type = "lzf"
62+ elif self .compression == "gzip" :
63+ self .compression_type = "gzip"
64+ else :
65+ self .compression_type = None
66+ self .log (f"Compression algorithm: { self .compression_type } " )
6167 return self .compression_type
6268
6369 def _check_config (self ):
@@ -261,24 +267,55 @@ def _get_segmentation_info(self):
261267 f"Found no segmentation masks with key { self .segmentation_key } . Cannot proceed with extraction."
262268 )
263269
264- # get relevant segmentation masks to perform extraction on
265- nucleus_key = f"{ self .segmentation_key } _nucleus"
270+ # initialize default values to track what should be extracted
271+ self .nucleus_key = None
272+ self .cytosol_key = None
273+ self .extract_nucleus_mask = False
274+ self .extract_cytosol_mask = False
266275
267- if nucleus_key in relevant_masks :
268- self . extract_nucleus_mask = True
269- self .nucleus_key = nucleus_key
270- else :
271- self .extract_nucleus_mask = False
272- self .nucleus_key = None
276+ if "segmentation_mask" in self . config :
277+ allowed_mask_values = [ "nucleus" , "cytosol" ]
278+ allowed_mask_values = [ f" { self .segmentation_key } _ { x } " for x in allowed_mask_values ]
279+
280+ if isinstance ( self .config [ "segmentation_mask" ], str ):
281+ assert self .config [ "segmentation_mask" ] in allowed_mask_values
273282
274- cytosol_key = f"{ self .segmentation_key } _cytosol"
283+ if "nucleus" in self .config ["segmentation_mask" ]:
284+ self .nucleus_key = self .config ["segmentation_mask" ]
285+ self .extract_nucleus_mask = True
286+
287+ elif "cytosol" in self .config ["segmentation_mask" ]:
288+ self .cytosol_key = self .config ["segmentation_mask" ]
289+ self .extract_cytosol_mask = True
290+ else :
291+ raise ValueError (
292+ f"Segmentation mask { self .config ['segmentation_mask' ]} is not a valid mask to extract from."
293+ )
294+
295+ elif isinstance (self .config ["segmentation_mask" ], list ):
296+ assert all (x in allowed_mask_values for x in self .config ["segmentation_mask" ])
297+
298+ for x in self .config ["segmentation_mask" ]:
299+ if "nucleus" in x :
300+ self .nucleus_key = x
301+ self .extract_nucleus_mask = True
302+ if "cytosol" in x :
303+ self .cytosol_key = x
304+ self .extract_cytosol_mask = True
275305
276- if cytosol_key in relevant_masks :
277- self .extract_cytosol_mask = True
278- self .cytosol_key = cytosol_key
279306 else :
280- self .extract_cytosol_mask = False
281- self .cytosol_key = None
307+ # get relevant segmentation masks to perform extraction on
308+ nucleus_key = f"{ self .segmentation_key } _nucleus"
309+
310+ if nucleus_key in relevant_masks :
311+ self .extract_nucleus_mask = True
312+ self .nucleus_key = nucleus_key
313+
314+ cytosol_key = f"{ self .segmentation_key } _cytosol"
315+
316+ if cytosol_key in relevant_masks :
317+ self .extract_cytosol_mask = True
318+ self .cytosol_key = cytosol_key
282319
283320 self .n_masks = np .sum ([self .extract_nucleus_mask , self .extract_cytosol_mask ])
284321 self .masks = [x for x in [self .nucleus_key , self .cytosol_key ] if x is not None ]
@@ -415,7 +452,7 @@ def _save_removed_classes(self, classes):
415452 # define path where classes should be saved
416453 filtered_path = os .path .join (
417454 self .project_location ,
418- self .DEFAULT_SEGMENTATION_DIR_NAME ,
455+ self .DEFAULT_EXTRACTION_DIR_NAME ,
419456 self .DEFAULT_REMOVED_CLASSES_FILE ,
420457 )
421458
@@ -636,7 +673,7 @@ def _transfer_tempmmap_to_hdf5(self):
636673 axs [i ].imshow (img , vmin = 0 , vmax = 1 )
637674 axs [i ].axis ("off" )
638675 fig .tight_layout ()
639- fig .show ()
676+ plt .show (fig )
640677
641678 self .log ("Transferring extracted single cells to .hdf5" )
642679
@@ -651,7 +688,8 @@ def _transfer_tempmmap_to_hdf5(self):
651688 ) # increase to 64 bit otherwise information may become truncated
652689
653690 self .log ("single-cell index created." )
654- self ._clear_cache (vars_to_delete = [cell_ids ])
691+ del cell_ids
692+ # self._clear_cache(vars_to_delete=[cell_ids]) # this is not working as expected so we will just delete the variable directly
655693
656694 _ , c , x , y = _tmp_single_cell_data .shape
657695 single_cell_data = hf .create_dataset (
@@ -668,7 +706,8 @@ def _transfer_tempmmap_to_hdf5(self):
668706 single_cell_data [ix ] = _tmp_single_cell_data [i ]
669707
670708 self .log ("single-cell data created" )
671- self ._clear_cache (vars_to_delete = [single_cell_data ])
709+ del single_cell_data
710+ # self._clear_cache(vars_to_delete=[single_cell_data]) # this is not working as expected so we will just delete the variable directly
672711
673712 # also transfer labelled index to HDF5
674713 index_labelled = _tmp_single_cell_index [keep_index ]
@@ -684,18 +723,27 @@ def _transfer_tempmmap_to_hdf5(self):
684723 hf .create_dataset ("single_cell_index_labelled" , data = index_labelled , chunks = None , dtype = dt )
685724
686725 self .log ("single-cell index labelled created." )
687- self ._clear_cache (vars_to_delete = [index_labelled ])
726+ del index_labelled
727+ # self._clear_cache(vars_to_delete=[index_labelled]) # this is not working as expected so we will just delete the variable directly
688728
689729 hf .create_dataset (
690730 "channel_information" ,
691731 data = np .char .encode (self .channel_names .astype (str )),
692732 dtype = h5py .special_dtype (vlen = str ),
693733 )
694734
735+ hf .create_dataset (
736+ "n_masks" ,
737+ data = self .n_masks ,
738+ dtype = int ,
739+ )
740+
695741 self .log ("channel information created." )
696742
697743 # cleanup memory
698- self ._clear_cache (vars_to_delete = [_tmp_single_cell_index , index_labelled ])
744+ del _tmp_single_cell_index
745+ # self._clear_cache(vars_to_delete=[_tmp_single_cell_index]) # this is not working as expected so we will just delete the variable directly
746+
699747 os .remove (self ._tmp_single_cell_data_path )
700748 os .remove (self ._tmp_single_cell_index_path )
701749
@@ -808,7 +856,6 @@ def process(self, partial=False, n_cells=None, seed=42):
808856 # directory where intermediate results should be saved
809857 cache: "/mnt/temp/cache"
810858 """
811-
812859 total_time_start = timeit .default_timer ()
813860
814861 start_setup = timeit .default_timer ()
@@ -871,31 +918,33 @@ def process(self, partial=False, n_cells=None, seed=42):
871918
872919 self .log ("Running in single threaded mode." )
873920 results = []
874- for arg in tqdm (args ):
921+ for arg in tqdm (args , total = len ( args ), desc = "Processing cell batches" ):
875922 x = f (arg )
876923 results .append (x )
877924 else :
878925 # set up function for multi-threaded processing
879926 f = func_partial (self ._extract_classes_multi , self .px_centers )
880- batched_args = self ._generate_batched_args (args )
927+ args = self ._generate_batched_args (args )
881928
882929 self .log (f"Running in multiprocessing mode with { self .threads } threads." )
883930 with mp .get_context ("fork" ).Pool (
884931 processes = self .threads
885932 ) as pool : # both spawn and fork work but fork is faster so forcing fork here
886933 results = list (
887934 tqdm (
888- pool .imap (f , batched_args ),
889- total = len (batched_args ),
935+ pool .imap (f , args ),
936+ total = len (args ),
890937 desc = "Processing cell batches" ,
891938 )
892939 )
893940 pool .close ()
894941 pool .join ()
895- print ("multiprocessing done." )
896942
897943 self .save_index_to_remove = flatten (results )
898944
945+ # cleanup memory and remove any no longer required variables
946+ del results , args
947+ # self._clear_cache(vars_to_delete=["results", "args"]) # this is not working as expected at the moment so need to manually delete the variables
899948 stop_extraction = timeit .default_timer ()
900949
901950 # calculate duration
@@ -912,7 +961,6 @@ def process(self, partial=False, n_cells=None, seed=42):
912961 self .DEFAULT_LOG_NAME = "processing.log" # change log name back to default
913962
914963 self ._post_extraction_cleanup ()
915-
916964 total_time_stop = timeit .default_timer ()
917965 total_time = total_time_stop - total_time_start
918966
0 commit comments