Skip to content

Commit 74bf9db

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 32f8a0e commit 74bf9db

File tree

19 files changed

+147
-151
lines changed

19 files changed

+147
-151
lines changed

examples/code_snippets/plotting/plot_rgb_single_cell_image_grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
import matplotlib.pyplot as plt
88
import numpy as np
9-
10-
from scportrait.data._single_cell_images import dataset2_h5sc
119
from scportrait.pl.h5sc import _plot_image_grid
1210
from scportrait.pl.vis import colorize
1311

12+
from scportrait.data._single_cell_images import dataset2_h5sc
13+
1414
# get dataset
1515
h5sc = dataset2_h5sc()
1616

examples/code_snippets/plotting/plot_rgb_single_cell_images.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
"""
55

66
import matplotlib.pyplot as plt
7+
from scportrait.pl.vis import generate_composite
78

89
from scportrait.data._single_cell_images import dataset2_h5sc
9-
from scportrait.pl.vis import generate_composite
1010

1111
# select images you want to plot and colorize
1212
h5sc = dataset2_h5sc()

src/scportrait/pipeline/_utils/sdata_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,9 @@ def _load_seg_to_memmap(
445445
"""
446446
_sdata = self._check_sdata_status(return_sdata=True)
447447

448-
assert all(
449-
seg in _sdata.labels for seg in seg_name
450-
), "Not all passed segmentation elements found in sdata object."
448+
assert all(seg in _sdata.labels for seg in seg_name), (
449+
"Not all passed segmentation elements found in sdata object."
450+
)
451451

452452
seg_objects = [_sdata.labels[seg] for seg in seg_name]
453453

src/scportrait/pipeline/extraction.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def _check_config(self):
116116

117117
if normalization_range is not None:
118118
assert len(normalization_range) == 2, "Normalization range must be a tuple or list of length 2."
119-
assert all(
120-
isinstance(x, float | int) and (0 <= x <= 1) for x in normalization_range
121-
), "Normalization range must be defined as a float between 0 and 1."
119+
assert all(isinstance(x, float | int) and (0 <= x <= 1) for x in normalization_range), (
120+
"Normalization range must be defined as a float between 0 and 1."
121+
)
122122

123123
# conver to tuple to ensure consistency
124124
if isinstance(normalization_range, list):
@@ -388,9 +388,9 @@ def _get_centers(self) -> None:
388388
self.centers_cell_ids = _sdata[centers_name].index.values.compute()
389389

390390
# ensure that the centers ids are unique
391-
assert len(self.centers_cell_ids) == len(
392-
set(self.centers_cell_ids)
393-
), "Cell ids in centers are not unique. Cannot proceed with extraction."
391+
assert len(self.centers_cell_ids) == len(set(self.centers_cell_ids)), (
392+
"Cell ids in centers are not unique. Cannot proceed with extraction."
393+
)
394394

395395
# double check that the cell_ids contained in the seg masks match to those from centers
396396
# THIS NEEDS TO BE IMPLEMENTED HERE

src/scportrait/pipeline/featurization.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ def _setup_channel_selection(self) -> None:
128128
if "channel_selection" in self.config.keys():
129129
channel_selection = self.config["channel_selection"]
130130
if isinstance(channel_selection, list):
131-
assert all(
132-
isinstance(x, int) for x in channel_selection
133-
), "channel_selection should be a list of integers"
131+
assert all(isinstance(x, int) for x in channel_selection), (
132+
"channel_selection should be a list of integers"
133+
)
134134
self.channel_selection = channel_selection
135135

136136
elif isinstance(channel_selection, int):
@@ -212,21 +212,21 @@ def _get_single_cell_datafile_specs(self) -> None:
212212

213213
# check to ensure that metadata that must be consistent between datasets is
214214
assert (x == n_masks[0] for x in n_masks), "number of masks are not consistent over all passed inputfiles."
215-
assert (
216-
x == n_channels[0] for x in n_channels
217-
), "number of channels are not consistent over all passed input files."
218-
assert (
219-
x == n_image_channels[0] for x in n_image_channels
220-
), "number of image channels are not consistent over all passed input files."
221-
assert (
222-
x == channel_mapping[0] for x in channel_mapping
223-
), "channel mapping is not consistent over all passed input files."
224-
assert (
225-
x == channel_names[0] for x in channel_names
226-
), "channel names are not consistent over all passed input files."
227-
assert (
228-
x == segmentation_channel[0] for x in segmentation_channel
229-
), "segmentation channel is not consistent over all passed input files."
215+
assert (x == n_channels[0] for x in n_channels), (
216+
"number of channels are not consistent over all passed input files."
217+
)
218+
assert (x == n_image_channels[0] for x in n_image_channels), (
219+
"number of image channels are not consistent over all passed input files."
220+
)
221+
assert (x == channel_mapping[0] for x in channel_mapping), (
222+
"channel mapping is not consistent over all passed input files."
223+
)
224+
assert (x == channel_names[0] for x in channel_names), (
225+
"channel names are not consistent over all passed input files."
226+
)
227+
assert (x == segmentation_channel[0] for x in segmentation_channel), (
228+
"segmentation channel is not consistent over all passed input files."
229+
)
230230

231231
# set variable names after assertions have passed to the first instance of each value
232232
self.n_masks = n_masks[0]
@@ -567,25 +567,25 @@ def generate_dataloader(
567567
# generate dataset
568568
self.log(f"Reading data from path: {dataset_paths}")
569569

570-
assert isinstance(
571-
self.transforms, transforms.Compose
572-
), f"Transforms should be a torchvision.transforms.Compose object but recieved {self.transforms.__class__} instead."
570+
assert isinstance(self.transforms, transforms.Compose), (
571+
f"Transforms should be a torchvision.transforms.Compose object but recieved {self.transforms.__class__} instead."
572+
)
573573
t = self.transforms
574574

575575
if self.expected_imagesize is not None:
576576
self.log(f"Expected image size is set to {self.expected_imagesize}. Resizing images to this size.")
577577
t = transforms.Compose([t, transforms.Resize(self.expected_imagesize)])
578578

579579
if isinstance(dataset_paths, list):
580-
assert isinstance(
581-
dataset_labels, list
582-
), "If multiple directories are provided, multiple labels must be provided."
580+
assert isinstance(dataset_labels, list), (
581+
"If multiple directories are provided, multiple labels must be provided."
582+
)
583583
paths = dataset_paths
584584
dataset_labels = dataset_labels
585585
elif isinstance(dataset_paths, str):
586-
assert isinstance(
587-
dataset_labels, int
588-
), "If only one directory is provided, only one label must be provided."
586+
assert isinstance(dataset_labels, int), (
587+
"If only one directory is provided, only one label must be provided."
588+
)
589589
paths = [dataset_paths]
590590
dataset_labels = [dataset_labels]
591591

@@ -923,9 +923,9 @@ def __init__(self, *args, **kwargs):
923923
self._clean_log_file()
924924

925925
# checks for additional essential parameters in the config file
926-
assert (
927-
self.label is not None
928-
), "'label' must be specified in the config file. This is the label used to save the results."
926+
assert self.label is not None, (
927+
"'label' must be specified in the config file. This is the label used to save the results."
928+
)
929929

930930
def _get_network_dir(self) -> pl.LightningModule:
931931
if self.network_dir in self.PRETRAINED_MODEL_NAMES:
@@ -1107,7 +1107,7 @@ def process(
11071107
path = os.path.join(self.run_path, f"{output_name}.csv")
11081108

11091109
# self._write_results_csv(results, path)
1110-
self._write_results_sdata(results, label=f"{self.__class__.__name__ }_{self.label}_{model.__name__}")
1110+
self._write_results_sdata(results, label=f"{self.__class__.__name__}_{self.label}_{model.__name__}")
11111111
else:
11121112
all_results.append(results)
11131113

@@ -1143,9 +1143,9 @@ def __init__(self, *args, **kwargs):
11431143
self._clean_log_file()
11441144

11451145
# checks for additional essential parameters in the config file
1146-
assert (
1147-
self.label is not None
1148-
), "'label' musst be specified in the config file. This is the label used to save the results."
1146+
assert self.label is not None, (
1147+
"'label' musst be specified in the config file. This is the label used to save the results."
1148+
)
11491149

11501150
def _setup_transforms(self):
11511151
if self.transforms is not None:
@@ -1273,7 +1273,7 @@ class based on the previous single-cell extraction. Therefore, no parameters nee
12731273
os.path.join(self.run_path, f"{output_name}.csv")
12741274

12751275
# self._write_results_csv(results, path)
1276-
self._write_results_sdata(results, label=f"{self.__class__.__name__ }_{model_name}")
1276+
self._write_results_sdata(results, label=f"{self.__class__.__name__}_{model_name}")
12771277
else:
12781278
all_results[model_name] = results
12791279

src/scportrait/pipeline/project.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -720,15 +720,15 @@ def plot_input_image(
720720

721721
if channels is not None:
722722
if isinstance(channels[0], int):
723-
assert all(
724-
x in range(c) for x in channels
725-
), "The specified channel indices are not found in the spatialdata object."
723+
assert all(x in range(c) for x in channels), (
724+
"The specified channel indices are not found in the spatialdata object."
725+
)
726726
valid_channels = [i for i in channels if isinstance(i, int)]
727727
channel_names = [channel_names[i] for i in valid_channels]
728728
if isinstance(channels[0], str):
729-
assert all(
730-
x in channel_names for x in channels
731-
), "The specified channel names are not found in the spatialdata object."
729+
assert all(x in channel_names for x in channels), (
730+
"The specified channel names are not found in the spatialdata object."
731+
)
732732
channel_names = channels
733733

734734
c = len(channels)
@@ -1358,9 +1358,9 @@ def load_input_from_dask(self, dask_array, channel_names: list[str], overwrite:
13581358

13591359
self._cleanup_sdata_object()
13601360

1361-
assert (
1362-
len(channel_names) == dask_array.shape[0]
1363-
), "Number of channel names does not match number of input images."
1361+
assert len(channel_names) == dask_array.shape[0], (
1362+
"Number of channel names does not match number of input images."
1363+
)
13641364

13651365
self.channel_names = channel_names
13661366

@@ -1646,9 +1646,9 @@ def featurize(
16461646

16471647
# check that prerequisits are fullfilled to featurize cells
16481648
assert self.featurization_f is not None, "No featurization method defined."
1649-
assert (
1650-
self.nuc_seg_status or self.cyto_seg_status
1651-
), "No nucleus or cytosol segmentation loaded. Please load a segmentation first."
1649+
assert self.nuc_seg_status or self.cyto_seg_status, (
1650+
"No nucleus or cytosol segmentation loaded. Please load a segmentation first."
1651+
)
16521652
assert self.extraction_status, "No single cell data extracted. Please extract single cell data first."
16531653

16541654
extraction_dir = self.extraction_f.get_directory()

src/scportrait/pipeline/segmentation/segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,9 +1145,9 @@ def complete_segmentation(self, input_image, force_run=False):
11451145
sharding_plan = self._get_sharding_plan(overwrite=False, force_read=True)
11461146

11471147
# check to make sure that calculated sharding plan matches to existing sharding results
1148-
assert (
1149-
len(sharding_plan) == len(tile_directories)
1150-
), "Calculated a different number of shards than found shard directories. This indicates a mismatch between the current loaded config file and the config file used to generate the exisiting partial segmentation. Please rerun the complete segmentation to ensure accurate results."
1148+
assert len(sharding_plan) == len(tile_directories), (
1149+
"Calculated a different number of shards than found shard directories. This indicates a mismatch between the current loaded config file and the config file used to generate the exisiting partial segmentation. Please rerun the complete segmentation to ensure accurate results."
1150+
)
11511151

11521152
# select only those shards that did not complete successfully for further processing
11531153
sharding_plan_complete = sharding_plan

src/scportrait/pipeline/segmentation/workflows/_base_segmentation_workflow.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,25 @@ def _setup_maximum_intensity_projection(self):
3636
# check if channels that should be maximum intensity projected and combined are defined in the config
3737
if "combine_cytosol_channels" in self.config.keys():
3838
self.combine_cytosol_channels = self.config["combine_cytosol_channels"]
39-
assert isinstance(
40-
self.combine_cytosol_channels, list
41-
), "combine_cytosol_channels must be a list of integers specifying the indexes of the channels to combine."
42-
assert (
43-
len(self.combine_cytosol_channels) > 1
44-
), "combine_cytosol_channels must contain at least two integers specifying the indexes of the channels to combine."
39+
assert isinstance(self.combine_cytosol_channels, list), (
40+
"combine_cytosol_channels must be a list of integers specifying the indexes of the channels to combine."
41+
)
42+
assert len(self.combine_cytosol_channels) > 1, (
43+
"combine_cytosol_channels must contain at least two integers specifying the indexes of the channels to combine."
44+
)
4545
self.maximum_project_cytosol = True
4646
else:
4747
self.combine_cytosol_channels = None
4848
self.maximum_project_cytosol = False
4949

5050
if "combine_nucleus_channels" in self.config.keys():
5151
self.combine_nucleus_channels = self.config["combine_nucleus_channels"]
52-
assert isinstance(
53-
self.combine_nucleus_channels, list
54-
), "combine_nucleus_channels must be a list of integers specifying the indexes of the channels to combine."
55-
assert (
56-
len(self.combine_nucleus_channels) > 1
57-
), "combine_nucleus_channels must contain at least two integers specifying the indexes of the channels to combine."
52+
assert isinstance(self.combine_nucleus_channels, list), (
53+
"combine_nucleus_channels must be a list of integers specifying the indexes of the channels to combine."
54+
)
55+
assert len(self.combine_nucleus_channels) > 1, (
56+
"combine_nucleus_channels must contain at least two integers specifying the indexes of the channels to combine."
57+
)
5858
self.maximum_project_nucleus = True
5959
else:
6060
self.combine_nucleus_channels = None
@@ -88,14 +88,14 @@ def _define_channels_to_extract_for_segmentation(self):
8888

8989
# check validity of resulting list of segmentation channels
9090
assert len(self.segmentation_channels) > 0, "No segmentation channels specified in config file."
91-
assert (
92-
len(self.segmentation_channels) >= self.N_INPUT_CHANNELS
93-
), f"Fewer segmentation channels {self.segmentation_channels} provided than expected by segmentation method {self.N_INPUT_CHANNELS}."
91+
assert len(self.segmentation_channels) >= self.N_INPUT_CHANNELS, (
92+
f"Fewer segmentation channels {self.segmentation_channels} provided than expected by segmentation method {self.N_INPUT_CHANNELS}."
93+
)
9494

9595
if len(self.segmentation_channels) > self.N_INPUT_CHANNELS:
96-
assert (
97-
self.maximum_project_nucleus or self.maximum_project_cytosol
98-
), "More input channels provided than accepted by the segmentation method and no maximum intensity projection performed on any of the input values."
96+
assert self.maximum_project_nucleus or self.maximum_project_cytosol, (
97+
"More input channels provided than accepted by the segmentation method and no maximum intensity projection performed on any of the input values."
98+
)
9999

100100
def _remap_maximum_intensity_projection_channels(self):
101101
"""After selecting channels that are passed to the segmentation update indexes of the channels for maximum intensity projection so that they reflect the provided image subset"""
@@ -153,9 +153,9 @@ def _transform_input_image(self, input_image):
153153

154154
input_image = np.vstack(values)
155155

156-
assert (
157-
input_image.shape[0] == self.N_INPUT_CHANNELS
158-
), f"Number of channels in input image {input_image.shape[0]} does not match the number of channels expected by segmentation method {self.N_INPUT_CHANNELS}."
156+
assert input_image.shape[0] == self.N_INPUT_CHANNELS, (
157+
f"Number of channels in input image {input_image.shape[0]} does not match the number of channels expected by segmentation method {self.N_INPUT_CHANNELS}."
158+
)
159159

160160
stop_transform = timeit.default_timer()
161161
self.transform_time = stop_transform - start_transform
@@ -474,9 +474,9 @@ def _check_for_size_filtering(self, mask_types: list[str]) -> None:
474474
If size filtering is turned on, the thresholds for filtering are loaded from the config file.
475475
"""
476476

477-
assert all(
478-
mask_type in self.MASK_NAMES for mask_type in mask_types
479-
), f"mask_types must be a list of strings that are valid mask names {self.MASK_NAMES}."
477+
assert all(mask_type in self.MASK_NAMES for mask_type in mask_types), (
478+
f"mask_types must be a list of strings that are valid mask names {self.MASK_NAMES}."
479+
)
480480

481481
if "filter_masks_size" in self.config.keys():
482482
self.filter_size = self.config["filter_masks_size"]
@@ -648,9 +648,9 @@ def _check_for_mask_matching_filtering(self) -> None:
648648
# sanity check provided values
649649
assert isinstance(self.filter_match_masks, bool), "`match_masks` must be a boolean value."
650650
if self.filter_match_masks:
651-
assert isinstance(
652-
self.mask_matching_filtering_threshold, float
653-
), "`filtering_threshold_mask_matching` for mask matching must be a float."
651+
assert isinstance(self.mask_matching_filtering_threshold, float), (
652+
"`filtering_threshold_mask_matching` for mask matching must be a float."
653+
)
654654

655655
def _perform_mask_matching_filtering(
656656
self,

src/scportrait/pipeline/selection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def _setup_selection(self):
8989
self.savepath = os.path.join(self.directory, savename)
9090

9191
# check that the segmentation label exists
92-
assert (
93-
self.segmentation_channel_to_select in self.project.filehandler.get_sdata()._shared_keys
94-
), f"Segmentation channel {self.segmentation_channel_to_select} not found in sdata."
92+
assert self.segmentation_channel_to_select in self.project.filehandler.get_sdata()._shared_keys, (
93+
f"Segmentation channel {self.segmentation_channel_to_select} not found in sdata."
94+
)
9595

9696
def __get_coords(
9797
self, cell_ids: list, centers: list[tuple[int, int]], width: int = 60

src/scportrait/plotting/h5sc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,9 @@ def cell_grid_multi_channel(
298298
_cell_ids = adata.obs[DEFAULT_CELL_ID_NAME].sample(n_cells).values
299299

300300
if row_labels is not None:
301-
assert (
302-
show_cell_id is False
303-
), "If manually providing row_labels, can not automatically annotate rows with cell IDs. Set `show_cell_id` to False."
301+
assert show_cell_id is False, (
302+
"If manually providing row_labels, can not automatically annotate rows with cell IDs. Set `show_cell_id` to False."
303+
)
304304
assert len(row_labels) == n_cells, "Length of `row_labels` must match the number of cells to be visualized."
305305
else:
306306
row_labels = [f"cell ID {_id}" for _id in _cell_ids] if show_cell_id else None

0 commit comments

Comments
 (0)