Skip to content

Commit aba71d4

Browse files
committed
Add two simple tests for ADMM approach
1 parent 83f2003 commit aba71d4

File tree

3 files changed

+153
-3
lines changed

3 files changed

+153
-3
lines changed

pf2rnaseq/factorization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,9 @@ def deconvolution_cytokine_admm(
400400
np.random.seed(random_state)
401401

402402
# Initialize
403-
W = np.eye(n_cytokines)
404-
H = A.copy()
403+
# W initialized as identity, H is original A
404+
W = np.random.rand(n_cytokines, n_cytokines) * 0.1 + np.eye(n_cytokines)
405+
H = np.random.rand(n_cytokines, n_components) * np.mean(np.abs(A))
405406
Z_W = W.copy()
406407
Z_H = H.copy()
407408
U_W = np.zeros_like(W)

pf2rnaseq/figures/figureParseADMM.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def makeFigure():
6060
)
6161
ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold")
6262

63-
#Plot original median subtracted factor matrix for reference
63+
# Plot original median subtracted factor matrix for reference
6464
plot_condition_factors(
6565
X,
6666
ax[1],
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
Test the cytokine deconvolution method.
3+
"""
4+
5+
import numpy as np
6+
import pytest
7+
8+
from ..factorization import deconvolution_cytokine_admm
9+
10+
11+
def test_deconvolution_cytokine_admm_sparse():
12+
"""
13+
Test deconvolution_cytokine_admm with sparse ground truth matrices.
14+
15+
This test generates sparse W (cytokine interaction) and H (effect basis) matrices,
16+
computes A = W @ H, and verifies that the deconvolution recovers the structure.
17+
"""
18+
np.random.seed(42)
19+
20+
# Dimensions
21+
n_cytokines = 8
22+
n_components = 12
23+
24+
# Generate sparse ground truth W (cytokine interaction matrix)
25+
# W should have 1s on diagonal and sparse off-diagonal elements
26+
W_true = np.eye(n_cytokines)
27+
28+
# Add sparse off-diagonal interactions (only 20% of off-diagonal elements)
29+
off_diag_mask = ~np.eye(n_cytokines, dtype=bool)
30+
n_off_diag = np.sum(off_diag_mask)
31+
n_nonzero_w = int(0.2 * n_off_diag)
32+
33+
# Randomly select positions for non-zero off-diagonal elements
34+
off_diag_positions = np.where(off_diag_mask)
35+
nonzero_indices = np.random.choice(n_off_diag, n_nonzero_w, replace=False)
36+
37+
for idx in nonzero_indices:
38+
i, j = off_diag_positions[0][idx], off_diag_positions[1][idx]
39+
# Use small positive values for cytokine interactions
40+
W_true[i, j] = np.random.uniform(0.1, 0.5)
41+
42+
# Generate sparse ground truth H (effect basis matrix)
43+
# H should have about 30% non-zero elements
44+
H_true = np.zeros((n_cytokines, n_components))
45+
n_nonzero_h = int(0.3 * n_cytokines * n_components)
46+
47+
for _ in range(n_nonzero_h):
48+
i = np.random.randint(0, n_cytokines)
49+
j = np.random.randint(0, n_components)
50+
# H can have both positive and negative values
51+
H_true[i, j] = np.random.uniform(-2.0, 2.0)
52+
53+
# Compute the observed matrix A
54+
A = W_true @ H_true
55+
56+
# Add small noise
57+
noise_level = 0.01
58+
A_noisy = A + noise_level * np.random.randn(n_cytokines, n_components)
59+
60+
# Run deconvolution
61+
W_recovered, H_recovered, history = deconvolution_cytokine_admm(
62+
A_noisy,
63+
alpha_h=0.1,
64+
alpha_w=0.05,
65+
rho=1.0,
66+
max_iter=1000,
67+
tol=1e-6,
68+
random_state=42,
69+
adaptive_rho=True,
70+
non_negative_w=True,
71+
)
72+
73+
# Verify shapes
74+
assert W_recovered.shape == (n_cytokines, n_cytokines)
75+
assert H_recovered.shape == (n_cytokines, n_components)
76+
77+
# Verify diagonal of W is constrained to 1
78+
np.testing.assert_allclose(np.diag(W_recovered), np.ones(n_cytokines), atol=1e-10)
79+
80+
# Verify non-negativity of W
81+
assert np.all(W_recovered >= -1e-10), "W should be non-negative"
82+
83+
# Verify reconstruction quality
84+
A_reconstructed = W_recovered @ H_recovered
85+
reconstruction_error = np.linalg.norm(
86+
A_noisy - A_reconstructed, "fro"
87+
) / np.linalg.norm(A_noisy, "fro")
88+
assert reconstruction_error < 0.1, (
89+
f"Reconstruction error too high: {reconstruction_error}"
90+
)
91+
92+
# Verify sparsity of W (off-diagonal should be sparse)
93+
w_sparsity = np.sum(np.abs(W_recovered[off_diag_mask]) < 1e-3) / np.sum(
94+
off_diag_mask
95+
)
96+
assert w_sparsity > 0.5, f"W should be sparse, but sparsity is only {w_sparsity}"
97+
98+
# Verify sparsity of H
99+
h_sparsity = np.sum(np.abs(H_recovered) < 1e-3) / H_recovered.size
100+
assert h_sparsity > 0.3, f"H should be sparse, but sparsity is only {h_sparsity}"
101+
102+
# Verify history contains expected keys
103+
assert "objective" in history
104+
assert "primal_residual" in history
105+
assert "dual_residual" in history
106+
assert "rho" in history
107+
assert "w_sparsity" in history
108+
assert "h_sparsity" in history
109+
110+
# Verify objective decreases (generally)
111+
assert len(history["objective"]) > 0
112+
# Check that final objective is lower than initial (with some tolerance for fluctuations)
113+
initial_obj = history["objective"][0]
114+
final_obj = history["objective"][-1]
115+
assert final_obj < initial_obj * 1.1, "Objective should generally decrease"
116+
117+
print("\nTest passed!")
118+
print(f"Reconstruction error: {reconstruction_error:.4f}")
119+
print(f"W off-diagonal sparsity: {w_sparsity:.2%}")
120+
print(f"H sparsity: {h_sparsity:.2%}")
121+
print(f"Converged in {len(history['objective'])} iterations")
122+
123+
124+
def test_deconvolution_cytokine_admm_small():
125+
"""
126+
Test with a small problem to ensure basic functionality.
127+
"""
128+
np.random.seed(999)
129+
130+
n_cytokines = 3
131+
n_components = 5
132+
133+
# Simple test matrix
134+
A = np.random.randn(n_cytokines, n_components)
135+
136+
# Run with default parameters
137+
W, H, history = deconvolution_cytokine_admm(
138+
A, max_iter=100, tol=1e-6, random_state=999
139+
)
140+
141+
# Basic checks
142+
assert W.shape == (n_cytokines, n_cytokines)
143+
assert H.shape == (n_cytokines, n_components)
144+
assert len(history["objective"]) > 0
145+
146+
# Verify diagonal constraint
147+
np.testing.assert_allclose(np.diag(W), np.ones(n_cytokines), atol=1e-10)
148+
149+
print("\nSmall problem test passed!")

0 commit comments

Comments
 (0)