Skip to content

Commit 2840769

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2d778db commit 2840769

File tree

138 files changed

+1108
-1950
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

138 files changed

+1108
-1950
lines changed

dance/atlas/sc_similarity/anndata_similarity.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,7 @@ def compute_mmd_alternative(self) -> float:
231231
def compute_mmd(self) -> float:
232232
"""Compute Maximum Mean Discrepancy between datasets.
233233
234-
Returns
235-
-------
236-
float
237-
Normalized MMD similarity score between 0 and 1
234+
Returns ------- float Normalized MMD similarity score between 0 and 1
238235
239236
"""
240237
X = self.X
@@ -276,10 +273,8 @@ def data_company(self):
276273
def wasserstein_dist(self) -> float:
277274
"""Compute Wasserstein distance between datasets.
278275
279-
Returns
280-
-------
281-
float
282-
Normalized Wasserstein similarity score between 0 and 1
276+
Returns ------- float Normalized Wasserstein similarity score between 0 and
277+
1
283278
284279
"""
285280
X = self.X
@@ -359,10 +354,7 @@ def get_dataset_meta_sim(self):
359354
"""Compute metadata similarity between datasets based on discrete and continuous
360355
features.
361356
362-
Returns
363-
-------
364-
float
365-
Average similarity score across all metadata features
357+
Returns ------- float Average similarity score across all metadata features
366358
367359
"""
368360
# dis_cols=['assay', 'cell_type', 'development_stage','disease','is_primary_data','self_reported_ethnicity','sex', 'suspension_type', 'tissue','tissue_type', 'tissue_general']

dance/data/base.py

Lines changed: 65 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -204,34 +204,30 @@ def y(self):
204204
def config(self) -> Dict[str, Any]:
205205
"""Return the dance data object configuration dict.
206206
207-
Notes
208-
-----
209-
The configuration dictionary is saved in the ``data`` attribute, which is an :class:`~anndata.AnnData`
210-
object. Inparticular, the config will be saved in the ``.uns`` attribute with the key ``"dance_config"``.
207+
Notes ----- The configuration dictionary is saved in the ``data`` attribute,
208+
which is an :class:`~anndata.AnnData` object. Inparticular, the config will be
209+
saved in the ``.uns`` attribute with the key ``"dance_config"``.
211210
212211
"""
213212
return self._data.uns["dance_config"]
214213

215214
def set_config(self, *, overwrite: bool = False, **kwargs):
216215
"""Set dance data object configuration.
217216
218-
See
219-
:meth: `~BaseData.set_config_from_dict`.
217+
See :meth: `~BaseData.set_config_from_dict`.
220218
221219
"""
222220
self.set_config_from_dict(kwargs, overwrite=overwrite)
223221

224222
def set_config_from_dict(self, config_dict: Dict[str, Any], *, overwrite: bool = False):
225223
"""Set dance data object configuration from a config dict.
226224
227-
Parameters
228-
----------
229-
config_dict
230-
Configuration dictionary.
231-
overwrite
232-
Used to determine the behaviour of resolving config conflicts. In the case of a conflict, where the config
233-
dict passed contains a key with value that differs from an existing setting, if ``overwrite`` is set to
234-
``False``, then raise a ``KeyError``. Otherwise, overwrite the configuration with the new values.
225+
Parameters ---------- config_dict Configuration dictionary. overwrite
226+
Used to determine the behaviour of resolving config conflicts. In the case of a
227+
conflict, where the config dict passed contains a key with value that
228+
differs from an existing setting, if ``overwrite`` is set to ``False``, then
229+
raise a ``KeyError``. Otherwise, overwrite the configuration with the new
230+
values.
235231
236232
"""
237233
# Check config key validity
@@ -304,29 +300,20 @@ def copy(self):
304300
def set_split_idx(self, split_name: str, split_idx: Sequence[int]):
305301
"""Set cell indices for a particular split.
306302
307-
Parameters
308-
----------
309-
split_name
310-
Name of the split to set.
311-
split_idx
312-
Indices to be used in this split.
303+
Parameters ---------- split_name Name of the split to set. split_idx
304+
Indices to be used in this split.
313305
314306
"""
315307
self._split_idx_dict[split_name] = split_idx
316308

317309
def get_split_idx(self, split_name: str, error_on_miss: bool = False):
318310
"""Obtain cell indices for a particular split.
319311
320-
Parameters
321-
----------
322-
split_name
323-
Name of the split to retrieve.
324-
error_on_miss
325-
If set to True, raise KeyError if the queried split does not exit, otherwise return None.
312+
Parameters ---------- split_name Name of the split to retrieve.
313+
error_on_miss If set to True, raise KeyError if the queried split does not
314+
exit, otherwise return None.
326315
327-
See Also
328-
--------
329-
:meth:`~get_split_mask`
316+
See Also -------- :meth:`~get_split_mask`
330317
331318
"""
332319
if split_name is None:
@@ -341,12 +328,8 @@ def get_split_idx(self, split_name: str, error_on_miss: bool = False):
341328
def get_split_mask(self, split_name: str, return_type: FeatType = "numpy") -> Union[np.ndarray, torch.Tensor]:
342329
"""Obtain mask representation of a particular split.
343330
344-
Parameters
345-
----------
346-
split_name
347-
Name of the split to retrieve.
348-
return_type
349-
Return numpy array if set to 'numpy', or torch Tensor if set to 'torch'.
331+
Parameters ---------- split_name Name of the split to retrieve. return_type
332+
Return numpy array if set to 'numpy', or torch Tensor if set to 'torch'.
350333
351334
"""
352335
split_idx = self.get_split_idx(split_name, error_on_miss=True)
@@ -362,10 +345,7 @@ def get_split_mask(self, split_name: str, return_type: FeatType = "numpy") -> Un
362345
def get_split_data(self, split_name: str) -> Union[anndata.AnnData, mudata.MuData]:
363346
"""Obtain the underlying data of a particular split.
364347
365-
Parameters
366-
----------
367-
split_name
368-
Name of the split to retrieve.
348+
Parameters ---------- split_name Name of the split to retrieve.
369349
370350
"""
371351
split_idx = self.get_split_idx(split_name, error_on_miss=True)
@@ -417,24 +397,20 @@ def get_feature(self, *, split_name: Optional[str] = None, return_type: FeatType
417397
mod: Optional[str] = None): # yapf: disable
418398
"""Retrieve features from data.
419399
420-
Parameters
421-
----------
422-
split_name
423-
Name of the split to retrieve. If not set, return all.
424-
return_type
425-
How should the features be returned. **sparse**: return as a sparse matrix; **numpy**: return as a numpy
426-
array; **torch**: return as a torch tensor; **anndata**: return as an anndata object.
427-
channel
428-
Return a particular channel as features. If ``channel_type`` is ``X`` or ``raw_X``, then return ``.X`` or
429-
the ``.raw.X`` attribute from the :class:`~anndata.AnnData` directly. If ``channel_type`` is ``obs``, return
430-
the column named by ``channel``, similarly for ``var``. Finally, if ``channel_type`` is ``obsm``, ``obsp``,
431-
``varm``, ``varp``, ``layers``, or ``uns``, then return the value correspond to the ``channel`` in the
432-
dictionary.
433-
channel_type
434-
Channel type to use, default to ``obsm`` (will be changed to ``X`` in the near future).
435-
mod
436-
Modality to use, default to ``None``. Options other than ``None`` are only available when the underlying
437-
data object is :class:`~mudata.Mudata`.
400+
Parameters ---------- split_name Name of the split to retrieve. If not set,
401+
return all. return_type How should the features be returned. **sparse**:
402+
return as a sparse matrix; **numpy**: return as a numpy array; **torch**:
403+
return as a torch tensor; **anndata**: return as an anndata object. channel
404+
Return a particular channel as features. If ``channel_type`` is ``X`` or
405+
``raw_X``, then return ``.X`` or the ``.raw.X`` attribute from the
406+
:class:`~anndata.AnnData` directly. If ``channel_type`` is ``obs``, return
407+
the column named by ``channel``, similarly for ``var``. Finally, if
408+
``channel_type`` is ``obsm``, ``obsp``, ``varm``, ``varp``, ``layers``, or
409+
``uns``, then return the value correspond to the ``channel`` in the
410+
dictionary. channel_type Channel type to use, default to ``obsm`` (will be
411+
changed to ``X`` in the near future). mod Modality to use, default to
412+
``None``. Options other than ``None`` are only available when the underlying
413+
data object is :class:`~mudata.Mudata`.
438414
439415
"""
440416
feature = self._get_feature(self.data, channel, channel_type, mod)
@@ -486,27 +462,22 @@ def append(
486462
):
487463
"""Append another dance data object to the current data object.
488464
489-
Parameters
490-
----------
491-
data
492-
New dance data object to be added.
493-
mode
494-
How to combine the splits from the new data and the current data. (1) ``"merge"``: merge the splits from
495-
the data, e.g., the training indexes from both data are used as the training indexes in the new combined
496-
data. (2) ``"rename"``: rename the splits of the new data and add to the current split index dictionary,
497-
e.g., renaming 'train' to 'ref'. Requires passing the ``rename_dict``. Raise an error if the newly renamed
498-
key is already used in the current split index dictionary. (3) ``"new_split"``: assign the whole new data
499-
to a new split. Requires pssing the ``new_split_name`` that is not already used as a split name in the
500-
current data. (4) ``None``: do not specify split index to the newly added data.
501-
rename_dict
502-
Optional argument that is only used when ``mode="rename"``. A dictionary to map the split names in the new
503-
data to other names.
504-
new_split_name
505-
Optional argument that is only used when ``mode="new_split"``. Name of the split to assign to the new data.
506-
label_batch
507-
Add "batch" column to ``.obs`` when set to True.
508-
**concat_kwargs
509-
See :meth:`anndata.concat`.
465+
Parameters ---------- data New dance data object to be added. mode How
466+
to combine the splits from the new data and the current data. (1) ``"merge"``:
467+
merge the splits from the data, e.g., the training indexes from both data
468+
are used as the training indexes in the new combined data. (2) ``"rename"``:
469+
rename the splits of the new data and add to the current split index dictionary,
470+
e.g., renaming 'train' to 'ref'. Requires passing the ``rename_dict``. Raise an
471+
error if the newly renamed key is already used in the current split index
472+
dictionary. (3) ``"new_split"``: assign the whole new data to a new split.
473+
Requires pssing the ``new_split_name`` that is not already used as a split name
474+
in the current data. (4) ``None``: do not specify split index to the newly
475+
added data. rename_dict Optional argument that is only used when
476+
``mode="rename"``. A dictionary to map the split names in the new data to
477+
other names. new_split_name Optional argument that is only used when
478+
``mode="new_split"``. Name of the split to assign to the new data. label_batch
479+
Add "batch" column to ``.obs`` when set to True. **concat_kwargs See
480+
:meth:`anndata.concat`.
510481
511482
"""
512483
offset = self.shape[0]
@@ -580,29 +551,21 @@ def pop(self, *, split_name: str):
580551
def filter_cells(self, **kwargs):
581552
"""Apply cell filtering using scanpy.pp.filter_cells and update splits.
582553
583-
Filters the cells in `self.data` based on the provided criteria,
584-
similar to `scanpy.pp.filter_cells`. Crucially, this method also
585-
updates the internal split indices (`train_idx`, `val_idx`, etc.)
586-
to reflect the cells remaining after filtering.
554+
Filters the cells in `self.data` based on the provided criteria, similar to
555+
`scanpy.pp.filter_cells`. Crucially, this method also updates the internal split
556+
indices (`train_idx`, `val_idx`, etc.) to reflect the cells remaining after
557+
filtering.
587558
588-
Parameters
589-
----------
590-
**kwargs
591-
Arguments passed directly to `scanpy.pp.filter_cells`.
592-
Common arguments include `min_counts`, `max_counts`,
593-
`min_genes`, `max_genes`. Note: `inplace` is forced to `False`
594-
internally to get the filter mask, then applied effectively inplace.
559+
Parameters ---------- **kwargs Arguments passed directly to
560+
`scanpy.pp.filter_cells`. Common arguments include `min_counts`,
561+
`max_counts`, `min_genes`, `max_genes`. Note: `inplace` is forced to `False`
562+
internally to get the filter mask, then applied effectively inplace.
595563
596-
Returns
597-
-------
598-
self
599-
Returns the instance to allow method chaining.
564+
Returns ------- self Returns the instance to allow method chaining.
600565
601-
Raises
602-
------
603-
NotImplementedError
604-
If the underlying `self.data` is not an `anndata.AnnData` object.
605-
Filtering `MuData` requires more careful consideration of modalities.
566+
Raises ------ NotImplementedError If the underlying `self.data` is not an
567+
`anndata.AnnData` object. Filtering `MuData` requires more careful
568+
consideration of modalities.
606569
607570
"""
608571
if not isinstance(self.data, anndata.AnnData):
@@ -856,13 +819,10 @@ def get_data(
856819
) -> Tuple[Any, Any]:
857820
"""Retrieve cell features and labels from a particular split.
858821
859-
Parameters
860-
----------
861-
split_name
862-
Name of the split to retrieve. If not set, return all.
863-
return_type
864-
How should the features be returned. **numpy**: return as a numpy array; **torch**: return as a torch
865-
tensor; **anndata**: return as an anndata object.
822+
Parameters ---------- split_name Name of the split to retrieve. If not set,
823+
return all. return_type How should the features be returned. **numpy**:
824+
return as a numpy array; **torch**: return as a torch tensor; **anndata**:
825+
return as an anndata object.
866826
867827
"""
868828
x = self.get_x(split_name, return_type, **x_kwargs)

dance/datasets/base.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@
1515
class BaseDataset(ABC):
1616
"""BaseDataset abstract object.
1717
18-
Parameters
19-
----------
20-
root
21-
Root directory of the dataset.
22-
full_download
23-
If set to ``True``, then attempt to download all raw files of the dataset.
18+
Parameters ---------- root Root directory of the dataset. full_download If
19+
set to ``True``, then attempt to download all raw files of the dataset.
2420
2521
"""
2622

dance/datasets/singlemodality.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,10 @@ def _load_dfs(paths: List[str], *, index_col: Optional[int] = 0, transpose: bool
375375
def get_map_dict(map_file_path: str, tissue: str) -> Dict[str, Set[str]]:
376376
"""Load cell-type mappings.
377377
378-
Parameters
379-
----------
380-
map_file_path
381-
Path to the mapping file.
382-
tissue
383-
Tissue of interest.
378+
Parameters ---------- map_file_path Path to the mapping file. tissue
379+
Tissue of interest.
384380
385-
Notes
386-
-----
387-
Merge mapping across all test sets for the required tissue.
381+
Notes ----- Merge mapping across all test sets for the required tissue.
388382
389383
"""
390384
map_df = pd.read_excel(osp.join(map_file_path, "map.xlsx"))
@@ -399,12 +393,9 @@ def get_map_dict(map_file_path: str, tissue: str) -> Dict[str, Set[str]]:
399393
class ClusteringDataset(BaseDataset):
400394
"""Data downloading and loading for clustering.
401395
402-
Parameters
403-
----------
404-
data_dir
405-
Path to store datasets.
406-
dataset
407-
Choice of dataset. Available options are '10X_PBMC', 'mouse_bladder_cell', 'mouse_ES_cell', 'worm_neuron_cell'.
396+
Parameters ---------- data_dir Path to store datasets. dataset Choice of
397+
dataset. Available options are '10X_PBMC', 'mouse_bladder_cell', 'mouse_ES_cell',
398+
'worm_neuron_cell'.
408399
409400
"""
410401

dance/datasets/spatial.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,9 @@ def _raw_to_dance(self, raw_data):
161161
class CellTypeDeconvoDataset(BaseDataset):
162162
"""Load raw data.
163163
164-
Parameters
165-
----------
166-
subset_common_celltypes
167-
If set to True, then subset both the reference and the real data to contain only cell types that are
168-
present in both reference and real.
164+
Parameters ---------- subset_common_celltypes If set to True, then subset both
165+
the reference and the real data to contain only cell types that are present in
166+
both reference and real.
169167
170168
"""
171169

dance/models/nn/gnn.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@ class AdaptiveSAGE(nn.Module):
1818
Parameters
1919
----------
2020
dim_in
21-
Input feature dimensions.
21+
Input feature dimensions.
2222
dim_out
23-
output feature dimensions.
23+
output feature dimensions.
2424
alpha
25-
Shared learnable parameters containing gene-cell interaction strengths and those for the cell and gene
26-
self-loops.
25+
Shared learnable parameters containing gene-cell interaction strengths and those for the cell and gene
26+
self-loops.
2727
dropout_layer
28-
Dropout layer.
28+
Dropout layer.
2929
act_layer
30-
Activation layer.
30+
Activation layer.
3131
norm_layer
32-
Normalization layer.
32+
Normalization layer.
3333
3434
Note
3535
----
@@ -62,10 +62,11 @@ def __init__(
6262
def message_func(self, edges):
6363
"""Message update function.
6464
65-
Reweight messages based on 1) the shared learnable interaction strengths and 2) the underlying edgeweights of
66-
the graph. In particular, for 1), gene-cell interaction (undirectional) will be weighted by the gene specific
67-
``beta`` value, and the cell and gene self-interactions will be weighted based on the corresponding ``alpha``
68-
values.
65+
Reweight messages based on 1) the shared learnable interaction strengths and 2)
66+
the underlying edgeweights of the graph. In particular, for 1), gene-cell
67+
interaction (undirectional) will be weighted by the gene specific ``beta``
68+
value, and the cell and gene self-interactions will be weighted based on the
69+
corresponding ``alpha`` values.
6970
7071
"""
7172
number_of_edges = edges.src["h"].shape[0]

0 commit comments

Comments
 (0)