Skip to content

Commit 9e74319

Browse files
andrewram4287Andrew Ramirez
andauthored
Fix test (#506)
* Fix test * Reformat * Update test --------- Co-authored-by: Andrew Ramirez <[email protected]>
1 parent 4804b3c commit 9e74319

File tree

10 files changed

+60
-50
lines changed

10 files changed

+60
-50
lines changed

.github/workflows/test.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ jobs:
1616
run: make .venv
1717
- name: Test with pytest
1818
run: make coverage.xml
19-
- name: Check formatting
20-
run: rye fmt --check RISE
21-
- name: Check linting
22-
run: rye lint RISE
2319
- name: Upload coverage to Codecov
2420
uses: codecov/codecov-action@v4
2521
with:

RISE/factorization.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212

1313
def correct_conditions(X: anndata.AnnData):
1414
"""Correct the condition factors by normalizing for overall read depth.
15-
15+
1616
This function adjusts condition factors (stored in X.uns["Pf2_A"]) to account for
1717
differences in sequencing depth across conditions. It uses linear regression to
1818
model the relationship between total read counts and condition factor magnitudes,
1919
then applies a correction.
20-
20+
2121
Parameters
2222
----------
2323
X : anndata.AnnData
2424
AnnData object containing RISE decomposition results. Must have:
2525
- X.obs["condition_unique_idxs"]: 0-indexed condition assignments
2626
- X.uns["Pf2_A"]: Condition factors from PARAFAC2 decomposition
27-
27+
2828
Returns
2929
-------
3030
numpy.ndarray
@@ -58,11 +58,11 @@ def pf2(
5858
max_iter: int = 500,
5959
):
6060
"""Perform PARAFAC2 tensor decomposition on single-cell RNA-seq data.
61-
61+
6262
This is the main function for running RISE analysis. It decomposes the
6363
multi-condition single-cell data into condition factors, eigen-state factors,
6464
and gene factors, revealing patterns across experimental conditions.
65-
65+
6666
Parameters
6767
----------
6868
X : anndata.AnnData
@@ -84,12 +84,12 @@ def pf2(
8484
increase precision but may require more iterations.
8585
max_iter : int, optional (default: 500)
8686
Maximum number of iterations for the optimization algorithm.
87-
87+
8888
Returns
8989
-------
9090
anndata.AnnData
9191
The input AnnData object with added RISE decomposition results:
92-
92+
9393
- X.uns["Pf2_weights"]: Component weights (shape: rank,)
9494
- X.uns["Pf2_A"]: Condition factors (shape: n_conditions, rank)
9595
- X.uns["Pf2_B"]: Eigen-state factors (shape: rank, rank)
@@ -113,11 +113,11 @@ def pf2(
113113

114114
def rise_pca_r2x(X: anndata.AnnData, ranks):
115115
"""Compute variance explained (R²X) for RISE and PCA across different ranks.
116-
116+
117117
This function evaluates how much variance in the data is explained by
118118
RISE (PARAFAC2) and PCA decompositions at different component ranks.
119119
Used to determine the optimal number of components for RISE analysis.
120-
120+
121121
Parameters
122122
----------
123123
X : anndata.AnnData
@@ -126,12 +126,12 @@ def rise_pca_r2x(X: anndata.AnnData, ranks):
126126
ranks : array-like of int
127127
Array of rank values to test (e.g., [1, 5, 10, 15, 20, 25, 30]).
128128
Each rank represents a different number of components.
129-
129+
130130
Returns
131131
-------
132132
tuple of numpy.ndarray
133133
(rise_r2x, pca_r2x) where:
134-
134+
135135
- rise_r2x: Variance explained by RISE for each rank (shape: len(ranks),)
136136
- pca_r2x: Variance explained by PCA for each rank (shape: len(ranks),)
137137
"""

RISE/figures/commonFuncs/plotFactors.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ def plot_condition_factors(
2323
group_cond=False,
2424
):
2525
"""Plot condition factors as a heatmap showing how conditions contribute to components.
26-
26+
2727
This visualization shows how each experimental condition (rows) contributes to
2828
each RISE component (columns). High values indicate strong association between
2929
a condition and a component's pattern. Log transformation and normalization
3030
help reveal relative differences across conditions.
31-
31+
3232
Parameters
3333
----------
3434
data : anndata.AnnData
@@ -124,12 +124,12 @@ def plot_condition_factors(
124124

125125
def plot_eigenstate_factors(data: anndata.AnnData, ax: Axes):
126126
"""Plot eigen-state factors as a heatmap showing cell state patterns.
127-
127+
128128
Eigen-state factors represent the underlying cell state patterns across components.
129129
Each row represents an eigen-state (a summary of similar cells), and each column
130130
represents a component. High values indicate strong association between a cell
131131
state pattern and a component.
132-
132+
133133
Parameters
134134
----------
135135
data : anndata.AnnData
@@ -159,11 +159,11 @@ def plot_eigenstate_factors(data: anndata.AnnData, ax: Axes):
159159

160160
def plot_gene_factors(data: anndata.AnnData, ax: Axes, weight=0.08, trim=True):
161161
"""Plot gene factors as a heatmap showing which genes contribute to each component.
162-
162+
163163
This visualization reveals coordinated gene modules by showing which genes (rows)
164164
are highly weighted in each component (columns). The weight parameter filters out
165165
genes with low contributions, focusing on the most important genes for interpretation.
166-
166+
167167
Parameters
168168
----------
169169
data : anndata.AnnData

RISE/figures/commonFuncs/plotGeneral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
def plot_r2x(data, rank_vec, ax: Axes):
1111
"""Plot variance explained (R²X) for RISE and PCA across different ranks.
12-
12+
1313
This visualization helps determine the optimal number of components by showing
1414
how variance explained increases with rank. The elbow point where the curve
1515
flattens indicates a good balance between model complexity and explanatory power.
16-
16+
1717
Parameters
1818
----------
1919
data : anndata.AnnData

RISE/figures/commonFuncs/plotPaCMAP.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ def ds_show(result, ax: Axes):
4545

4646
def plot_gene_pacmap(gene: str, X: anndata.AnnData, ax: Axes, clip_outliers=0.9995):
4747
"""Plot PaCMAP embedding colored by gene expression levels.
48-
48+
4949
This visualization overlays gene expression onto the PaCMAP embedding of cells,
5050
revealing which cell populations express specific genes. Useful for validating
5151
component interpretations by checking if marker genes align with component patterns.
52-
52+
5353
Parameters
5454
----------
5555
gene : str
@@ -99,12 +99,12 @@ def plot_gene_pacmap(gene: str, X: anndata.AnnData, ax: Axes, clip_outliers=0.99
9999

100100
def plot_wp_pacmap(X: anndata.AnnData, cmp: int, ax: Axes, cbarMax: float = 1.0):
101101
"""Plot PaCMAP embedding colored by weighted projections for a component.
102-
102+
103103
This visualization shows which cells contribute most strongly to a specific
104104
component by coloring them according to their weighted projections. Cells with
105105
high weighted projections (bright colors) are most representative of that
106106
component's expression pattern.
107-
107+
108108
Parameters
109109
----------
110110
X : anndata.AnnData
@@ -157,11 +157,11 @@ def plot_labels_pacmap(
157157
color_key=None,
158158
):
159159
"""Plot PaCMAP embedding colored by categorical labels (cell type or condition).
160-
160+
161161
This visualization shows the overall structure of the cell embedding, revealing
162162
how cells cluster by cell type, experimental condition, or other categorical
163163
metadata. Useful for understanding the biological organization captured by RISE.
164-
164+
165165
Parameters
166166
----------
167167
X : anndata.AnnData

RISE/figures/figure2b_e.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,4 @@ def makeFigure():
3333

3434
plot_labels_pacmap(X, "Cell Type", ax[3])
3535

36-
3736
return f

RISE/figures/figureS4.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def makeFigure():
3333

3434
def calculateFMS(A: anndata.AnnData, B: anndata.AnnData):
3535
"""Calculate Factor Match Score (FMS) between two RISE decompositions.
36-
36+
3737
Factor Match Score measures the similarity between two tensor decompositions
3838
by comparing their factor matrices. Values range from 0 (no similarity) to 1
3939
(identical factors). Used to assess decomposition stability across different
4040
initializations or data subsamples.
41-
41+
4242
Parameters
4343
----------
4444
A : anndata.AnnData
@@ -50,14 +50,14 @@ def calculateFMS(A: anndata.AnnData, B: anndata.AnnData):
5050
B : anndata.AnnData
5151
Second AnnData object with RISE decomposition results. Must have the
5252
same rank as A and contain the same decomposition attributes.
53-
53+
5454
Returns
5555
-------
5656
float
5757
Factor Match Score between 0 and 1. Higher values indicate more similar
5858
decompositions. Typically: >0.9 = highly stable, >0.6 = acceptable,
5959
<0.6 = unstable decomposition.
60-
60+
6161
Notes
6262
-----
6363
This function uses tlviz.factor_tools.factor_match_score with weights
@@ -146,12 +146,12 @@ def plot_fms_diff_ranks(
146146
runs: int,
147147
):
148148
"""Plot Factor Match Score (FMS) across different ranks to assess stability.
149-
149+
150150
FMS measures the reproducibility of PARAFAC2 decomposition results across
151151
multiple runs. Values above ~0.6 indicate stable, reproducible components.
152152
This helps determine which ranks produce reliable decompositions that are
153153
not overly sensitive to initialization or noise.
154-
154+
155155
Parameters
156156
----------
157157
X : anndata.AnnData
@@ -166,7 +166,7 @@ def plot_fms_diff_ranks(
166166
Number of independent runs per rank to use for FMS calculation.
167167
Higher values give more reliable FMS estimates but take longer.
168168
Typical values: 3-5 runs.
169-
169+
170170
Notes
171171
-----
172172
FMS values interpretation:

RISE/tests/test_parafac2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def test_factor_thomson_reprod():
2222
XX = import_thomson()
2323
XX.obs["condition_unique_idxs"] = pd.Categorical(XX.obs["condition_unique_idxs"])
2424
XX = pf2(XX, 10, doEmbedding=False, tolerance=1e-6)
25-
np.testing.assert_allclose(np.array(XX.varm["Pf2_C"]), C_first, atol=1e-5, rtol=1e-5)
25+
np.testing.assert_allclose(
26+
np.array(XX.varm["Pf2_C"]), C_first, atol=1e-5, rtol=1e-5
27+
)
28+
2629

2730
def test_factor_thomson_R2X():
2831
"""Import and factor Thomson.

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
# Disable cupy in anndata to avoid import errors during doc build
8-
os.environ['ANNDATA_CUPY'] = '0'
8+
os.environ["ANNDATA_CUPY"] = "0"
99

1010
# Add the parent directory to the path so we can import RISE
1111
# This allows Sphinx to find the RISE package

docs/generate_tutorial_figures.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Generate figures for the RISE tutorial documentation."""
2+
23
import matplotlib
3-
matplotlib.use('Agg') # Non-interactive backend
4+
5+
matplotlib.use("Agg") # Non-interactive backend
46
import matplotlib.pyplot as plt
57
from pathlib import Path
68

@@ -13,8 +15,16 @@
1315
from RISE.figures.commonFuncs.plotGeneral import plot_r2x
1416
from RISE.figures.figureS4 import plot_fms_diff_ranks
1517
from RISE.factorization import pf2
16-
from RISE.figures.commonFuncs.plotFactors import plot_condition_factors, plot_eigenstate_factors, plot_gene_factors
17-
from RISE.figures.commonFuncs.plotPaCMAP import plot_labels_pacmap, plot_gene_pacmap, plot_wp_pacmap
18+
from RISE.figures.commonFuncs.plotFactors import (
19+
plot_condition_factors,
20+
plot_eigenstate_factors,
21+
plot_gene_factors,
22+
)
23+
from RISE.figures.commonFuncs.plotPaCMAP import (
24+
plot_labels_pacmap,
25+
plot_gene_pacmap,
26+
plot_wp_pacmap,
27+
)
1828

1929
print("Loading dataset...")
2030
X = import_thomson()
@@ -25,15 +35,15 @@
2535
fig, ax = plt.subplots(figsize=(5, 5))
2636
plot_r2x(X, ranks, ax)
2737
plt.tight_layout()
28-
plt.savefig(output_dir / "step2_r2x.png", dpi=150, bbox_inches='tight')
38+
plt.savefig(output_dir / "step2_r2x.png", dpi=150, bbox_inches="tight")
2939
plt.close()
3040

3141
# Figure 2: Factor Match Score
3242
print("Generating Figure 2: FMS plot...")
3343
fig, ax = plt.subplots(figsize=(5, 5))
3444
plot_fms_diff_ranks(X, ax, ranksList=list(ranks), runs=3)
3545
plt.tight_layout()
36-
plt.savefig(output_dir / "step3_fms.png", dpi=150, bbox_inches='tight')
46+
plt.savefig(output_dir / "step3_fms.png", dpi=150, bbox_inches="tight")
3747
plt.close()
3848

3949
# Perform RISE factorization
@@ -46,15 +56,15 @@
4656
fig, ax = plt.subplots(figsize=(8, 8))
4757
plot_condition_factors(X, ax=ax, cond="Condition", log_transform=True)
4858
plt.tight_layout()
49-
plt.savefig(output_dir / "step5_condition_factors.png", dpi=150, bbox_inches='tight')
59+
plt.savefig(output_dir / "step5_condition_factors.png", dpi=150, bbox_inches="tight")
5060
plt.close()
5161

5262
# Figure 4: Cell Embedding
5363
print("Generating Figure 4: Cell embedding...")
5464
fig, ax = plt.subplots(figsize=(8, 8))
5565
plot_labels_pacmap(X, labelType="Cell Type", ax=ax)
5666
plt.tight_layout()
57-
plt.savefig(output_dir / "step6_cell_embedding.png", dpi=150, bbox_inches='tight')
67+
plt.savefig(output_dir / "step6_cell_embedding.png", dpi=150, bbox_inches="tight")
5868
plt.close()
5969

6070
# Figure 5: Eigen-state Factors
@@ -63,15 +73,15 @@
6373
plot_eigenstate_factors(X, ax=ax)
6474
plt.ylabel("Eigen-state")
6575
plt.tight_layout()
66-
plt.savefig(output_dir / "step7_eigenstate_factors.png", dpi=150, bbox_inches='tight')
76+
plt.savefig(output_dir / "step7_eigenstate_factors.png", dpi=150, bbox_inches="tight")
6777
plt.close()
6878

6979
# Figure 6: Gene Factors
7080
print("Generating Figure 6: Gene factors...")
7181
fig, ax = plt.subplots(figsize=(7, 8))
7282
plot_gene_factors(X, ax=ax, weight=0.2, trim=True)
7383
plt.tight_layout()
74-
plt.savefig(output_dir / "step8_gene_factors.png", dpi=150, bbox_inches='tight')
84+
plt.savefig(output_dir / "step8_gene_factors.png", dpi=150, bbox_inches="tight")
7585
plt.close()
7686

7787
# Figure 7: Gene Expression on PaCMAP
@@ -80,15 +90,17 @@
8090
gene = "MS4A1"
8191
plot_gene_pacmap(gene, X, ax=ax, clip_outliers=0.9995)
8292
plt.tight_layout()
83-
plt.savefig(output_dir / "step9_gene_expression.png", dpi=150, bbox_inches='tight')
93+
plt.savefig(output_dir / "step9_gene_expression.png", dpi=150, bbox_inches="tight")
8494
plt.close()
8595

8696
# Figure 8: Weighted Projections
8797
print("Generating Figure 8: Weighted projections...")
8898
fig, ax = plt.subplots(figsize=(8, 8))
8999
plot_wp_pacmap(X, cmp=10, ax=ax, cbarMax=0.9)
90100
plt.tight_layout()
91-
plt.savefig(output_dir / "step10_weighted_projections.png", dpi=150, bbox_inches='tight')
101+
plt.savefig(
102+
output_dir / "step10_weighted_projections.png", dpi=150, bbox_inches="tight"
103+
)
92104
plt.close()
93105

94106
print(f"\nAll figures saved to {output_dir}")

0 commit comments

Comments
 (0)