Skip to content

Commit c1bc39e

Browse files
author
Andrew Ramirez
committed
Better documentation of funcsions
1 parent caa9143 commit c1bc39e

File tree

5 files changed

+416
-12
lines changed

5 files changed

+416
-12
lines changed

RISE/factorization.py

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,31 @@
1111

1212

1313
def correct_conditions(X: anndata.AnnData):
14-
"""Correct the conditions factors by overall read depth."""
14+
"""Correct the condition factors by normalizing for overall read depth.
15+
16+
This function adjusts condition factors (stored in X.uns["Pf2_A"]) to account for
17+
differences in sequencing depth across conditions. It uses linear regression to
18+
model the relationship between total read counts and condition factor magnitudes,
19+
then applies a correction.
20+
21+
Parameters
22+
----------
23+
X : anndata.AnnData
24+
AnnData object containing RISE decomposition results. Must have:
25+
- X.obs["condition_unique_idxs"]: 0-indexed condition assignments
26+
- X.uns["Pf2_A"]: Condition factors from PARAFAC2 decomposition
27+
28+
Returns
29+
-------
30+
numpy.ndarray
31+
Corrected condition factors normalized by sequencing depth
32+
33+
Examples
34+
--------
35+
>>> from RISE.factorization import pf2, correct_conditions
36+
>>> X = pf2(adata, rank=20)
37+
>>> corrected_factors = correct_conditions(X)
38+
"""
1539
sgIndex = X.obs["condition_unique_idxs"]
1640

1741
counts = np.zeros((np.amax(sgIndex.to_numpy()) + 1, 1))
@@ -39,7 +63,61 @@ def pf2(
3963
tolerance=1e-9,
4064
max_iter: int = 500,
4165
):
42-
"""Run Pf2 model and store results in anndata file"""
66+
"""Perform PARAFAC2 tensor decomposition on single-cell RNA-seq data.
67+
68+
This is the main function for running RISE analysis. It decomposes the
69+
multi-condition single-cell data into condition factors, eigen-state factors,
70+
and gene factors, revealing patterns across experimental conditions.
71+
72+
Parameters
73+
----------
74+
X : anndata.AnnData
75+
Preprocessed AnnData object containing single-cell RNA-seq data.
76+
Must have X.obs["condition_unique_idxs"] indicating which condition
77+
each cell belongs to (0-indexed).
78+
rank : int
79+
Number of components to extract. Determines the complexity of the
80+
decomposition. Typically chosen based on variance explained and
81+
Factor Match Score analysis (see plot_r2x and plot_fms_diff_ranks).
82+
random_state : int, optional (default: 1)
83+
Random seed for reproducibility of the decomposition.
84+
doEmbedding : bool, optional (default: True)
85+
If True, automatically computes PaCMAP embedding of cell projections
86+
and stores in X.obsm["X_pf2_PaCMAP"]. This enables visualization
87+
functions like plot_labels_pacmap.
88+
tolerance : float, optional (default: 1e-9)
89+
Convergence threshold for the optimization algorithm. Lower values
90+
increase precision but may require more iterations.
91+
max_iter : int, optional (default: 500)
92+
Maximum number of iterations for the optimization algorithm.
93+
94+
Returns
95+
-------
96+
anndata.AnnData
97+
The input AnnData object with added RISE decomposition results:
98+
99+
- X.uns["Pf2_weights"]: Component weights (shape: rank,)
100+
- X.uns["Pf2_A"]: Condition factors (shape: n_conditions, rank)
101+
- X.uns["Pf2_B"]: Eigen-state factors (shape: rank, rank)
102+
- X.varm["Pf2_C"]: Gene factors (shape: n_genes, rank)
103+
- X.obsm["projections"]: Cell projections (shape: n_cells, rank)
104+
- X.obsm["weighted_projections"]: Weighted cell projections (shape: n_cells, rank)
105+
- X.obsm["X_pf2_PaCMAP"]: PaCMAP embedding (shape: n_cells, 2) if doEmbedding=True
106+
107+
Examples
108+
--------
109+
>>> from RISE.factorization import pf2
110+
>>> # Perform decomposition with 20 components
111+
>>> X = pf2(adata, rank=20, random_state=42)
112+
>>> # Access results
113+
>>> condition_factors = X.uns["Pf2_A"]
114+
>>> gene_factors = X.varm["Pf2_C"]
115+
116+
See Also
117+
--------
118+
rise_pca_r2x : Compute variance explained for different ranks
119+
plot_fms_diff_ranks : Evaluate factor stability across ranks
120+
"""
43121
pf_out, _ = parafac2_nd(
44122
X, rank=rank, random_state=random_state, tol=tolerance, n_iter_max=max_iter
45123
)
@@ -54,7 +132,44 @@ def pf2(
54132

55133

56134
def rise_pca_r2x(X: anndata.AnnData, ranks):
57-
"""Run RISE/PCA on data and save R2X values"""
135+
"""Compute variance explained (R²X) for RISE and PCA across different ranks.
136+
137+
This function evaluates how much variance in the data is explained by
138+
RISE (PARAFAC2) and PCA decompositions at different component ranks.
139+
Used to determine the optimal number of components for RISE analysis.
140+
141+
Parameters
142+
----------
143+
X : anndata.AnnData
144+
Preprocessed AnnData object containing single-cell RNA-seq data.
145+
Must have X.obs["condition_unique_idxs"] for RISE decomposition.
146+
ranks : array-like of int
147+
Array of rank values to test (e.g., [1, 5, 10, 15, 20, 25, 30]).
148+
Each rank represents a different number of components.
149+
150+
Returns
151+
-------
152+
tuple of numpy.ndarray
153+
(rise_r2x, pca_r2x) where:
154+
155+
- rise_r2x: Variance explained by RISE for each rank (shape: len(ranks),)
156+
- pca_r2x: Variance explained by PCA for each rank (shape: len(ranks),)
157+
158+
Examples
159+
--------
160+
>>> from RISE.factorization import rise_pca_r2x
161+
>>> ranks = [1, 5, 10, 15, 20]
162+
>>> rise_r2x, pca_r2x = rise_pca_r2x(adata, ranks)
163+
>>> # Plot results
164+
>>> import matplotlib.pyplot as plt
165+
>>> plt.plot(ranks, rise_r2x, label='RISE')
166+
>>> plt.plot(ranks, pca_r2x, label='PCA')
167+
168+
See Also
169+
--------
170+
plot_r2x : Convenience function to plot variance explained
171+
pf2 : Perform PARAFAC2 decomposition at chosen rank
172+
"""
58173
X = X.to_memory()
59174
XX = sps.csr_array(X.X)
60175

RISE/figures/commonFuncs/plotFactors.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,50 @@ def plot_condition_factors(
2222
color_key=None,
2323
group_cond=False,
2424
):
25-
"""Plots condition factors"""
25+
"""Plot condition factors as a heatmap showing how conditions contribute to components.
26+
27+
This visualization shows how each experimental condition (rows) contributes to
28+
each RISE component (columns). High values indicate strong association between
29+
a condition and a component's pattern. Log transformation and normalization
30+
help reveal relative differences across conditions.
31+
32+
Parameters
33+
----------
34+
data : anndata.AnnData
35+
AnnData object with RISE decomposition results. Must contain:
36+
- data.uns["Pf2_A"]: Condition factors (n_conditions, rank)
37+
- data.obs[cond]: Condition labels for each cell
38+
ax : matplotlib.axes.Axes
39+
Matplotlib axes object to plot on.
40+
cond : str, optional (default: "Condition")
41+
Name of column in data.obs containing condition labels.
42+
log_transform : bool, optional (default: True)
43+
If True, applies log10 transformation to condition factors before plotting.
44+
This helps visualize differences when values span orders of magnitude.
45+
cond_group_labels : pandas.Series, optional (default: None)
46+
Series mapping conditions to group labels for colored row annotations.
47+
Useful for grouping related conditions (e.g., drug classes, patient cohorts).
48+
ThomsonNorm : bool, optional (default: False)
49+
If True, normalizes factors using only control conditions (those containing 'CTRL').
50+
color_key : list, optional (default: None)
51+
Custom colors for condition group labels. If None, uses default palette.
52+
group_cond : bool, optional (default: False)
53+
If True and cond_group_labels provided, sorts conditions by group.
54+
55+
Examples
56+
--------
57+
>>> from RISE.figures.commonFuncs.plotFactors import plot_condition_factors
58+
>>> import matplotlib.pyplot as plt
59+
>>> fig, ax = plt.subplots(figsize=(8, 8))
60+
>>> plot_condition_factors(adata, ax=ax, cond="Condition", log_transform=True)
61+
>>> plt.tight_layout()
62+
>>> plt.show()
63+
64+
See Also
65+
--------
66+
plot_eigenstate_factors : Visualize eigen-state factors
67+
plot_gene_factors : Visualize gene factors
68+
"""
2669
pd.set_option("display.max_rows", None)
2770
yt = pd.Series(np.unique(data.obs[cond]))
2871
X = np.array(data.uns["Pf2_A"])
@@ -94,7 +137,36 @@ def plot_condition_factors(
94137

95138

96139
def plot_eigenstate_factors(data: anndata.AnnData, ax: Axes):
97-
"""Plots Pf2 eigenstate factors"""
140+
"""Plot eigen-state factors as a heatmap showing cell state patterns.
141+
142+
Eigen-state factors represent the underlying cell state patterns across components.
143+
Each row represents an eigen-state (a summary of similar cells), and each column
144+
represents a component. High values indicate strong association between a cell
145+
state pattern and a component.
146+
147+
Parameters
148+
----------
149+
data : anndata.AnnData
150+
AnnData object with RISE decomposition results. Must contain:
151+
- data.uns["Pf2_B"]: Eigen-state factors (rank, rank)
152+
ax : matplotlib.axes.Axes
153+
Matplotlib axes object to plot on.
154+
155+
Examples
156+
--------
157+
>>> from RISE.figures.commonFuncs.plotFactors import plot_eigenstate_factors
158+
>>> import matplotlib.pyplot as plt
159+
>>> fig, ax = plt.subplots(figsize=(4, 4))
160+
>>> plot_eigenstate_factors(adata, ax=ax)
161+
>>> ax.set_ylabel("Eigen-state")
162+
>>> plt.tight_layout()
163+
>>> plt.show()
164+
165+
See Also
166+
--------
167+
plot_condition_factors : Visualize condition factors
168+
plot_gene_factors : Visualize gene factors
169+
"""
98170
rank = data.uns["Pf2_B"].shape[1]
99171
xticks = np.arange(1, rank + 1)
100172
X = data.uns["Pf2_B"]
@@ -115,7 +187,40 @@ def plot_eigenstate_factors(data: anndata.AnnData, ax: Axes):
115187

116188

117189
def plot_gene_factors(data: anndata.AnnData, ax: Axes, weight=0.08, trim=True):
118-
"""Plots Pf2 gene factors"""
190+
"""Plot gene factors as a heatmap showing which genes contribute to each component.
191+
192+
This visualization reveals coordinated gene modules by showing which genes (rows)
193+
are highly weighted in each component (columns). The weight parameter filters out
194+
genes with low contributions, focusing on the most important genes for interpretation.
195+
196+
Parameters
197+
----------
198+
data : anndata.AnnData
199+
AnnData object with RISE decomposition results. Must contain:
200+
- data.varm["Pf2_C"]: Gene factors (n_genes, rank)
201+
ax : matplotlib.axes.Axes
202+
Matplotlib axes object to plot on.
203+
weight : float, optional (default: 0.08)
204+
Minimum absolute weight threshold for including genes. Genes with maximum
205+
absolute weight below this value across all components are filtered out.
206+
Higher values show fewer, more important genes.
207+
trim : bool, optional (default: True)
208+
If True, filters genes based on the weight parameter. If False, shows all genes.
209+
210+
Examples
211+
--------
212+
>>> from RISE.figures.commonFuncs.plotFactors import plot_gene_factors
213+
>>> import matplotlib.pyplot as plt
214+
>>> fig, ax = plt.subplots(figsize=(7, 8))
215+
>>> plot_gene_factors(adata, ax=ax, weight=0.2, trim=True)
216+
>>> plt.tight_layout()
217+
>>> plt.show()
218+
219+
See Also
220+
--------
221+
plot_condition_factors : Visualize condition factors
222+
plot_gene_pacmap : Overlay gene expression on PaCMAP
223+
"""
119224
rank = data.varm["Pf2_C"].shape[1]
120225
X = np.array(data.varm["Pf2_C"])
121226
yt = data.var.index.values

RISE/figures/commonFuncs/plotGeneral.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,38 @@
88

99

1010
def plot_r2x(data, rank_vec, ax: Axes):
11-
"""Creates R2X plot for RISE tensor decomposition and pca"""
11+
\"\"\"Plot variance explained (R²X) for RISE and PCA across different ranks.
12+
13+
This visualization helps determine the optimal number of components by showing
14+
how variance explained increases with rank. The elbow point where the curve
15+
flattens indicates a good balance between model complexity and explanatory power.
16+
17+
Parameters
18+
----------
19+
data : anndata.AnnData
20+
Preprocessed AnnData object containing single-cell RNA-seq data.
21+
Must have X.obs[\"condition_unique_idxs\"] for RISE decomposition.
22+
rank_vec : array-like of int
23+
Array of rank values to test (e.g., [1, 5, 10, 15, 20, 25, 30]).
24+
Each rank represents a different number of components.
25+
ax : matplotlib.axes.Axes
26+
Matplotlib axes object to plot on.
27+
28+
Examples
29+
--------
30+
>>> from RISE.figures.commonFuncs.plotGeneral import plot_r2x
31+
>>> import matplotlib.pyplot as plt
32+
>>> fig, ax = plt.subplots(figsize=(5, 5))
33+
>>> ranks = [1, 5, 10, 15, 20, 25, 30]
34+
>>> plot_r2x(adata, ranks, ax)
35+
>>> plt.tight_layout()
36+
>>> plt.show()
37+
38+
See Also
39+
--------
40+
rise_pca_r2x : Underlying function that computes variance explained
41+
plot_fms_diff_ranks : Evaluate factor stability across ranks
42+
\"\"\"
1243
r2xError = rise_pca_r2x(data, rank_vec)
1344
labelNames = ["Fit: RISE", "Fit: PCA"]
1445
colorDecomp = ["r", "b"]

0 commit comments

Comments
 (0)