Skip to content

Commit 758e7c5

Browse files
Merge pull request #145 from MannLabs/development
merge development branch
2 parents b46b1e9 + 7391e85 commit 758e7c5

File tree

15 files changed

+511
-153
lines changed

15 files changed

+511
-153
lines changed

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,10 @@ torch
2929
pytorch-lightning
3030
torchvision
3131

32-
spatialdata
32+
spatialdata>=0.2.0
3333
napari-spatialdata
3434
pyqt5
3535
lxml_html_clean
3636
ashlar>=1.19.0
3737
networkx
38-
py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
38+
py-lmd>=1.3.1

requirements_dev.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,13 @@ torch
2929
pytorch-lightning
3030
torchvision
3131

32-
spatialdata
32+
spatialdata>=0.2.0
3333
napari-spatialdata
3434
pyqt5
3535
lxml_html_clean
3636
ashlar>=1.19.0
3737
networkx
38-
py-lmd @ git+https://github.com/MannLabs/py-lmd.git@refs/pull/11/head#egg=py-lmd
38+
py-lmd>=1.3.1
3939

4040
#packages for building the documentation
4141
sphinx

src/scportrait/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,18 @@
11
"""Top-level package for scPortrait"""
22

3+
# silence warnings
4+
import warnings
5+
36
from scportrait import io
47
from scportrait import pipeline as pipeline
58
from scportrait import plotting as pl
69
from scportrait import processing as pp
710
from scportrait import tools as tl
11+
12+
# silence warning from spatialdata resulting in an older dask version see #139
13+
warnings.filterwarnings("ignore", message="ignoring keyword argument 'read_only'")
14+
15+
# silence warning from cellpose resulting in missing parameter set in model call see #141
16+
warnings.filterwarnings(
17+
"ignore", message=r"You are using `torch.load` with `weights_only=False`.*", category=FutureWarning
18+
)

src/scportrait/pipeline/_base.py

Lines changed: 52 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
import numpy as np
1010
import torch
1111

12+
from scportrait.pipeline._utils.helper import read_config
13+
1214

1315
class Logable:
1416
"""Create log entries.
@@ -92,6 +94,27 @@ def _clean_log_file(self):
9294
if os.path.exists(log_file_path):
9395
os.remove(log_file_path)
9496

97+
# def _clear_cache(self, vars_to_delete=None):
98+
# """Helper function to help clear memory usage. Mainly relevant for GPU based segmentations.
99+
100+
# Args:
101+
# vars_to_delete (list): List of variable names (as strings) to delete.
102+
# """
103+
104+
# # delete all specified variables
105+
# if vars_to_delete is not None:
106+
# for var_name in vars_to_delete:
107+
# if var_name in globals():
108+
# del globals()[var_name]
109+
110+
# if torch.cuda.is_available():
111+
# torch.cuda.empty_cache()
112+
113+
# if torch.backends.mps.is_available():
114+
# torch.mps.empty_cache()
115+
116+
# gc.collect()
117+
95118
def _clear_cache(self, vars_to_delete=None):
96119
"""Helper function to help clear memory usage. Mainly relevant for GPU based segmentations."""
97120

@@ -137,7 +160,7 @@ class ProcessingStep(Logable):
137160
DEFAULT_SEGMENTATION_DIR_NAME = "segmentation"
138161
DEFAULT_TILES_FOLDER = "tiles"
139162

140-
DEFAULT_EXTRACTIN_DIR_NAME = "extraction"
163+
DEFAULT_EXTRACTION_DIR_NAME = "extraction"
141164
DEFAULT_DATA_DIR = "data"
142165

143166
DEFAULT_IMAGE_DTYPE = np.uint16
@@ -155,19 +178,41 @@ class ProcessingStep(Logable):
155178
DEFAULT_SELECTION_DIR_NAME = "selection"
156179

157180
def __init__(
158-
self, config, directory, project_location, debug=False, overwrite=False, project=None, filehandler=None
181+
self,
182+
config,
183+
directory=None,
184+
project_location=None,
185+
debug=False,
186+
overwrite=False,
187+
project=None,
188+
filehandler=None,
189+
from_project: bool = False,
159190
):
160191
super().__init__(directory=directory)
161192

162193
self.debug = debug
163194
self.overwrite = overwrite
164-
self.project_location = project_location
165-
self.config = config
195+
if from_project:
196+
self.project_run = True
197+
self.project_location = project_location
198+
self.project = project
199+
self.filehandler = filehandler
200+
else:
201+
self.project_run = False
202+
self.project_location = None
203+
self.project = None
204+
self.filehandler = None
205+
206+
if isinstance(config, str):
207+
config = read_config(config)
208+
if self.__class__.__name__ in config.keys():
209+
self.config = config[self.__class__.__name__]
210+
else:
211+
self.config = config
212+
else:
213+
self.config = config
166214
self.overwrite = overwrite
167215

168-
self.project = project
169-
self.filehandler = filehandler
170-
171216
self.get_context()
172217

173218
self.deep_debug = False

src/scportrait/pipeline/_utils/helper.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,20 @@
11
from typing import TypeVar
22

3+
import yaml
4+
35
T = TypeVar("T")
46

57

6-
def flatten(nested_list: list[list[T]]) -> list[T]:
8+
def read_config(config_path: str) -> dict:
9+
with open(config_path) as stream:
10+
try:
11+
config = yaml.safe_load(stream)
12+
except yaml.YAMLError as exc:
13+
print(exc)
14+
return config
15+
16+
17+
def flatten(nested_list: list[list[T]]) -> list[T | tuple[T]]:
718
"""Flatten a list of lists into a single list.
819
920
Args:

src/scportrait/pipeline/_utils/sdata_io.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,10 +71,12 @@ def _read_sdata(self) -> SpatialData:
7171
_sdata = SpatialData()
7272
_sdata.write(self.sdata_path, overwrite=True)
7373

74+
allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
7475
for key in _sdata.labels:
75-
segmentation_object = _sdata.labels[key]
76-
if not hasattr(segmentation_object.attrs, "cell_ids"):
77-
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
76+
if key in allowed_labels:
77+
segmentation_object = _sdata.labels[key]
78+
if not hasattr(segmentation_object.attrs, "cell_ids"):
79+
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)
7880

7981
return _sdata
8082

src/scportrait/pipeline/extraction.py

Lines changed: 76 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,13 @@ def __init__(self, *args, **kwargs):
5757
self.overwrite_run_path = self.overwrite
5858

5959
def _get_compression_type(self):
60-
self.compression_type = "lzf" if self.compression else None
60+
if (self.compression is True) or (self.compression == "lzf"):
61+
self.compression_type = "lzf"
62+
elif self.compression == "gzip":
63+
self.compression_type = "gzip"
64+
else:
65+
self.compression_type = None
66+
self.log(f"Compression algorithm: {self.compression_type}")
6167
return self.compression_type
6268

6369
def _check_config(self):
@@ -261,24 +267,55 @@ def _get_segmentation_info(self):
261267
f"Found no segmentation masks with key {self.segmentation_key}. Cannot proceed with extraction."
262268
)
263269

264-
# get relevant segmentation masks to perform extraction on
265-
nucleus_key = f"{self.segmentation_key}_nucleus"
270+
# intialize default values to track what should be extracted
271+
self.nucleus_key = None
272+
self.cytosol_key = None
273+
self.extract_nucleus_mask = False
274+
self.extract_cytosol_mask = False
266275

267-
if nucleus_key in relevant_masks:
268-
self.extract_nucleus_mask = True
269-
self.nucleus_key = nucleus_key
270-
else:
271-
self.extract_nucleus_mask = False
272-
self.nucleus_key = None
276+
if "segmentation_mask" in self.config:
277+
allowed_mask_values = ["nucleus", "cytosol"]
278+
allowed_mask_values = [f"{self.segmentation_key}_{x}" for x in allowed_mask_values]
279+
280+
if isinstance(self.config["segmentation_mask"], str):
281+
assert self.config["segmentation_mask"] in allowed_mask_values
273282

274-
cytosol_key = f"{self.segmentation_key}_cytosol"
283+
if "nucleus" in self.config["segmentation_mask"]:
284+
self.nucleus_key = self.config["segmentation_mask"]
285+
self.extract_nucleus_mask = True
286+
287+
elif "cytosol" in self.config["segmentation_mask"]:
288+
self.cytosol_key = self.config["segmentation_mask"]
289+
self.extract_cytosol_mask = True
290+
else:
291+
raise ValueError(
292+
f"Segmentation mask {self.config['segmentation_mask']} is not a valid mask to extract from."
293+
)
294+
295+
elif isinstance(self.config["segmentation_mask"], list):
296+
assert all(x in allowed_mask_values for x in self.config["segmentation_mask"])
297+
298+
for x in self.config["segmentation_mask"]:
299+
if "nucleus" in x:
300+
self.nucleus_key = x
301+
self.extract_nucleus_mask = True
302+
if "cytosol" in x:
303+
self.cytosol_key = x
304+
self.extract_cytosol_mask = True
275305

276-
if cytosol_key in relevant_masks:
277-
self.extract_cytosol_mask = True
278-
self.cytosol_key = cytosol_key
279306
else:
280-
self.extract_cytosol_mask = False
281-
self.cytosol_key = None
307+
# get relevant segmentation masks to perform extraction on
308+
nucleus_key = f"{self.segmentation_key}_nucleus"
309+
310+
if nucleus_key in relevant_masks:
311+
self.extract_nucleus_mask = True
312+
self.nucleus_key = nucleus_key
313+
314+
cytosol_key = f"{self.segmentation_key}_cytosol"
315+
316+
if cytosol_key in relevant_masks:
317+
self.extract_cytosol_mask = True
318+
self.cytosol_key = cytosol_key
282319

283320
self.n_masks = np.sum([self.extract_nucleus_mask, self.extract_cytosol_mask])
284321
self.masks = [x for x in [self.nucleus_key, self.cytosol_key] if x is not None]
@@ -415,7 +452,7 @@ def _save_removed_classes(self, classes):
415452
# define path where classes should be saved
416453
filtered_path = os.path.join(
417454
self.project_location,
418-
self.DEFAULT_SEGMENTATION_DIR_NAME,
455+
self.DEFAULT_EXTRACTION_DIR_NAME,
419456
self.DEFAULT_REMOVED_CLASSES_FILE,
420457
)
421458

@@ -636,7 +673,7 @@ def _transfer_tempmmap_to_hdf5(self):
636673
axs[i].imshow(img, vmin=0, vmax=1)
637674
axs[i].axis("off")
638675
fig.tight_layout()
639-
fig.show()
676+
plt.show(fig)
640677

641678
self.log("Transferring extracted single cells to .hdf5")
642679

@@ -651,7 +688,8 @@ def _transfer_tempmmap_to_hdf5(self):
651688
) # increase to 64 bit otherwise information may become truncated
652689

653690
self.log("single-cell index created.")
654-
self._clear_cache(vars_to_delete=[cell_ids])
691+
del cell_ids
692+
# self._clear_cache(vars_to_delete=[cell_ids]) # this is not working as expected so we will just delete the variable directly
655693

656694
_, c, x, y = _tmp_single_cell_data.shape
657695
single_cell_data = hf.create_dataset(
@@ -668,7 +706,8 @@ def _transfer_tempmmap_to_hdf5(self):
668706
single_cell_data[ix] = _tmp_single_cell_data[i]
669707

670708
self.log("single-cell data created")
671-
self._clear_cache(vars_to_delete=[single_cell_data])
709+
del single_cell_data
710+
# self._clear_cache(vars_to_delete=[single_cell_data]) # this is not working as expected so we will just delete the variable directly
672711

673712
# also transfer labelled index to HDF5
674713
index_labelled = _tmp_single_cell_index[keep_index]
@@ -684,18 +723,27 @@ def _transfer_tempmmap_to_hdf5(self):
684723
hf.create_dataset("single_cell_index_labelled", data=index_labelled, chunks=None, dtype=dt)
685724

686725
self.log("single-cell index labelled created.")
687-
self._clear_cache(vars_to_delete=[index_labelled])
726+
del index_labelled
727+
# self._clear_cache(vars_to_delete=[index_labelled]) # this is not working as expected so we will just delete the variable directly
688728

689729
hf.create_dataset(
690730
"channel_information",
691731
data=np.char.encode(self.channel_names.astype(str)),
692732
dtype=h5py.special_dtype(vlen=str),
693733
)
694734

735+
hf.create_dataset(
736+
"n_masks",
737+
data=self.n_masks,
738+
dtype=int,
739+
)
740+
695741
self.log("channel information created.")
696742

697743
# cleanup memory
698-
self._clear_cache(vars_to_delete=[_tmp_single_cell_index, index_labelled])
744+
del _tmp_single_cell_index
745+
# self._clear_cache(vars_to_delete=[_tmp_single_cell_index]) # this is not working as expected so we will just delete the variable directly
746+
699747
os.remove(self._tmp_single_cell_data_path)
700748
os.remove(self._tmp_single_cell_index_path)
701749

@@ -808,7 +856,6 @@ def process(self, partial=False, n_cells=None, seed=42):
808856
# directory where intermediate results should be saved
809857
cache: "/mnt/temp/cache"
810858
"""
811-
812859
total_time_start = timeit.default_timer()
813860

814861
start_setup = timeit.default_timer()
@@ -871,31 +918,33 @@ def process(self, partial=False, n_cells=None, seed=42):
871918

872919
self.log("Running in single threaded mode.")
873920
results = []
874-
for arg in tqdm(args):
921+
for arg in tqdm(args, total=len(args), desc="Processing cell batches"):
875922
x = f(arg)
876923
results.append(x)
877924
else:
878925
# set up function for multi-threaded processing
879926
f = func_partial(self._extract_classes_multi, self.px_centers)
880-
batched_args = self._generate_batched_args(args)
927+
args = self._generate_batched_args(args)
881928

882929
self.log(f"Running in multiprocessing mode with {self.threads} threads.")
883930
with mp.get_context("fork").Pool(
884931
processes=self.threads
885932
) as pool: # both spawn and fork work but fork is faster so forcing fork here
886933
results = list(
887934
tqdm(
888-
pool.imap(f, batched_args),
889-
total=len(batched_args),
935+
pool.imap(f, args),
936+
total=len(args),
890937
desc="Processing cell batches",
891938
)
892939
)
893940
pool.close()
894941
pool.join()
895-
print("multiprocessing done.")
896942

897943
self.save_index_to_remove = flatten(results)
898944

945+
# cleanup memory and remove any no longer required variables
946+
del results, args
947+
# self._clear_cache(vars_to_delete=["results", "args"]) # this is not working as expected at the moment so need to manually delete the variables
899948
stop_extraction = timeit.default_timer()
900949

901950
# calculate duration
@@ -912,7 +961,6 @@ def process(self, partial=False, n_cells=None, seed=42):
912961
self.DEFAULT_LOG_NAME = "processing.log" # change log name back to default
913962

914963
self._post_extraction_cleanup()
915-
916964
total_time_stop = timeit.default_timer()
917965
total_time = total_time_stop - total_time_start
918966

0 commit comments

Comments
 (0)