Skip to content

Commit 5ddb3a6

Browse files
committed
ADMM application and non-negativity for W
1 parent 355aee8 commit 5ddb3a6

File tree

2 files changed

+184
-56
lines changed

2 files changed

+184
-56
lines changed

pf2rnaseq/factorization.py

Lines changed: 91 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,11 @@ def gradient(x):
288288

289289
# ===== Gradient w.r.t. W =====
290290
# 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W)
291-
grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W)))
291+
grad_W = (
292+
2 * ((W @ H - A) @ H.T)
293+
+ alpha * np.sign(W)
294+
- np.diag(alpha * np.sign(np.diag(W)))
295+
)
292296

293297
# ===== Gradient w.r.t. H =====
294298
# 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H)
@@ -347,19 +351,21 @@ def gradient(x):
347351

348352
return W, H
349353

354+
350355
def deconvolution_cytokine_admm(
351356
A: np.ndarray,
352357
alpha_h: float = 0.1,
353358
alpha_w: float = 0.01,
354359
rho: float = 1.0,
355-
max_iter: int = 5000,
356-
tol: float = 1e-4, # Single tolerance for both primal and dual
360+
max_iter: int = 10000,
361+
tol: float = 1e-4,
357362
random_state: int = 1,
358363
adaptive_rho: bool = True,
364+
non_negative_w: bool = True,
359365
) -> tuple[np.ndarray, np.ndarray, dict]:
360366
"""
361367
Decompose cytokine factor matrix using ADMM: A ≈ W @ H
362-
368+
363369
Parameters
364370
----------
365371
A : np.ndarray
@@ -378,7 +384,9 @@ def deconvolution_cytokine_admm(
378384
Random seed
379385
adaptive_rho : bool
380386
Whether to adaptively adjust rho
381-
387+
non_negative_w : bool
388+
If True, enforce W ≥ 0 (cytokines only activate, not inhibit)
389+
382390
Returns
383391
-------
384392
Z_W : np.ndarray
@@ -390,112 +398,132 @@ def deconvolution_cytokine_admm(
390398
"""
391399
n_cytokines, n_components = A.shape
392400
np.random.seed(random_state)
393-
401+
394402
# Initialize
395403
W = np.eye(n_cytokines)
396404
H = A.copy()
397405
Z_W = W.copy()
398406
Z_H = H.copy()
399407
U_W = np.zeros_like(W)
400408
U_H = np.zeros_like(H)
401-
409+
402410
print("Cytokine deconvolution with ADMM:")
403411
print(f" A shape: {A.shape}")
404412
print(f" Alpha_W: {alpha_w}, Alpha_H: {alpha_h}")
405413
print(f" Rho: {rho}")
406414
print(f" Tolerance: {tol}")
407-
415+
print(f" Non-negative W: {non_negative_w}")
416+
408417
off_diag_mask = ~np.eye(n_cytokines, dtype=bool)
409-
418+
410419
def soft_threshold(X, threshold):
411420
return np.sign(X) * np.maximum(np.abs(X) - threshold, 0)
412-
421+
413422
def update_W(H, Z_W, U_W, rho):
414-
"""Update W: constrain diagonal to 1.0"""
423+
"""Update W: constrain diagonal to 1.0, optional non-negativity"""
415424
H_HT = H @ H.T
416425
A_HT = A @ H.T
417426
lhs = H_HT + rho * np.eye(n_cytokines)
418427
rhs = A_HT + rho * (Z_W - U_W)
419-
428+
420429
W_new = np.linalg.solve(lhs, rhs.T).T
421-
np.fill_diagonal(W_new, 1.0)
422430

431+
# Non-negativity constraint for W
432+
if non_negative_w:
433+
W_new = np.maximum(W_new, 0)
434+
435+
# Diagonal constraint
436+
np.fill_diagonal(W_new, 1.0)
437+
423438
return W_new
424-
439+
425440
def update_H(W, Z_H, U_H, rho):
426-
"""Update H"""
441+
"""Update H: NO non-negativity constraint"""
427442
W_TW = W.T @ W
428443
W_TA = W.T @ A
429444
lhs = W_TW + rho * np.eye(n_cytokines)
430445
rhs = W_TA + rho * (Z_H - U_H)
446+
431447
return np.linalg.solve(lhs, rhs)
432-
448+
433449
def update_Z_W(W, U_W, alpha, rho):
434-
"""Update Z_W: soft-threshold off-diagonal only"""
450+
"""Update Z_W: soft-threshold off-diagonal, optional non-negativity"""
435451
X = W + U_W
436452
Z_W_new = X.copy()
453+
454+
# Soft-threshold off-diagonal
437455
Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho)
456+
457+
# Non-negativity constraint for W
458+
if non_negative_w:
459+
Z_W_new = np.maximum(Z_W_new, 0)
460+
461+
# Diagonal constraint
462+
np.fill_diagonal(Z_W_new, 1.0)
463+
438464
return Z_W_new
439-
465+
440466
def update_Z_H(H, U_H, alpha, rho):
441-
"""Update Z_H: soft-threshold entire matrix"""
467+
"""Update Z_H: soft-threshold, NO non-negativity"""
468+
# H can be negative
442469
return soft_threshold(H + U_H, alpha / rho)
443-
470+
444471
history = {
445-
'objective': [],
446-
'primal_residual': [],
447-
'dual_residual': [],
448-
'rho': [],
449-
'w_sparsity': [],
450-
'h_sparsity': []
472+
"objective": [],
473+
"primal_residual": [],
474+
"dual_residual": [],
475+
"rho": [],
476+
"w_sparsity": [],
477+
"h_sparsity": [],
451478
}
452-
479+
453480
print("\nStarting ADMM iterations...")
454-
481+
455482
for iteration in range(max_iter):
456483
Z_W_old = Z_W.copy()
457484
Z_H_old = Z_H.copy()
458-
485+
459486
# ADMM updates
460487
W = update_W(H, Z_W, U_W, rho)
461488
H = update_H(W, Z_H, U_H, rho)
462489
Z_W = update_Z_W(W, U_W, alpha_w, rho)
463490
Z_H = update_Z_H(H, U_H, alpha_h, rho)
464491
U_W = U_W + (W - Z_W)
465492
U_H = U_H + (H - Z_H)
466-
467-
# ===== SIMPLIFIED CONVERGENCE CHECK =====
468-
493+
469494
# Primal residual: ||W - Z_W||² + ||H - Z_H||²
470-
r_norm = np.sqrt(np.sum((W - Z_W)**2) + np.sum((H - Z_H)**2))
471-
495+
r_norm = np.sqrt(np.sum((W - Z_W) ** 2) + np.sum((H - Z_H) ** 2))
496+
472497
# Dual residual: ||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||²
473-
s_norm = np.sqrt(np.sum((rho * (Z_W - Z_W_old))**2) +
474-
np.sum((rho * (Z_H - Z_H_old))**2))
475-
498+
s_norm = np.sqrt(
499+
np.sum((rho * (Z_W - Z_W_old)) ** 2) + np.sum((rho * (Z_H - Z_H_old)) ** 2)
500+
)
501+
476502
# Compute objective
477503
recon_error = np.sum((A - W @ H) ** 2)
478504
l1_W = alpha_w * np.sum(np.abs(Z_W[off_diag_mask]))
479505
l1_H = alpha_h * np.sum(np.abs(Z_H))
480506
objective = recon_error + l1_W + l1_H
481-
507+
482508
# Track sparsity
483509
w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask)
484510
h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size
485-
511+
486512
# Store history
487-
history['objective'].append(objective)
488-
history['primal_residual'].append(r_norm)
489-
history['dual_residual'].append(s_norm)
490-
history['rho'].append(rho)
491-
history['w_sparsity'].append(w_sparsity)
492-
history['h_sparsity'].append(h_sparsity)
493-
513+
history["objective"].append(objective)
514+
history["primal_residual"].append(r_norm)
515+
history["dual_residual"].append(s_norm)
516+
history["rho"].append(rho)
517+
history["w_sparsity"].append(w_sparsity)
518+
history["h_sparsity"].append(h_sparsity)
519+
494520
# Print progress
495521
if iteration % 10 == 0 or iteration < 10:
496-
print(f" Iter {iteration:4d}: Obj={objective:.4e}, "
497-
f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}")
498-
522+
print(
523+
f" Iter {iteration:4d}: Obj={objective:.4e}, "
524+
f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}"
525+
)
526+
499527
# Adaptive rho update
500528
if adaptive_rho and iteration > 0:
501529
if r_norm > 10 * s_norm:
@@ -508,32 +536,39 @@ def update_Z_H(H, U_H, alpha, rho):
508536
U_W = U_W * 2
509537
U_H = U_H * 2
510538
print(f" Decreased ρ → {rho:.2f}")
511-
539+
512540
# Simple convergence check
513541
if r_norm < tol and s_norm < tol:
514542
print(f"\n✓ Converged at iteration {iteration}")
515543
print(f" Primal residual: {r_norm:.4e} < {tol:.4e}")
516544
print(f" Dual residual: {s_norm:.4e} < {tol:.4e}")
517545
break
518-
546+
519547
# Final statistics
520548
A_recon = W @ H
521549
rel_error = np.linalg.norm(A - A_recon, "fro") / np.linalg.norm(A, "fro")
522-
550+
523551
w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask)
524552
h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size
525-
553+
526554
print("\nOptimization complete:")
527555
print(f" Iterations: {iteration + 1}/{max_iter}")
528556
print(f" Relative reconstruction error: {rel_error:.4%}")
529-
print(f"\n W (cytokine interactions):")
557+
558+
print("\n W (cytokine interactions):")
530559
print(f" Off-diagonal sparsity: {w_sparsity:.2%}")
531560
print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}")
532561
print(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}")
562+
print(f" Min value: {W.min():.4f}") # Check non-negativity
563+
print(f" Max value: {W.max():.4f}")
533564
print(f" Diagonal: all 1.0 (constrained)")
534-
print(f"\n H (effect patterns):")
565+
566+
print("\n H (effect patterns):")
535567
print(f" Sparsity: {h_sparsity:.2%}")
536568
print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}")
537569
print(f" Mean |H|: {np.abs(Z_H).mean():.4f}")
538-
570+
print(f" Min value: {H.min():.4f}") # Can be negative
571+
print(f" Max value: {H.max():.4f}")
572+
print(f" Negative values: {np.sum(H < 0)} ({100*np.sum(H < 0)/H.size:.1f}%)")
573+
539574
return Z_W, Z_H, history
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Parse data: Plotting factors
3+
"""
4+
5+
import numpy as np
6+
import pandas as pd
7+
import seaborn as sns
8+
from anndata import read_h5ad
9+
from matplotlib import pyplot as plt
10+
11+
from ..factorization import correct_conditions, deconvolution_cytokine_admm
12+
from .common import getSetup, subplotLabel
13+
from .commonFuncs.plotFactors import (
14+
plot_condition_factors,
15+
)
16+
17+
18+
def samples_only(X) -> pd.DataFrame:
19+
"""Obtain samples once only with corresponding observations"""
20+
samples = X.obs
21+
df_samples = samples.drop_duplicates(subset="condition_unique_idxs")
22+
df_samples = df_samples.sort_values("condition_unique_idxs")
23+
return df_samples
24+
25+
26+
def makeFigure():
27+
"""Get a list of the axis objects and create a figure."""
28+
# Get list of axis objects
29+
ax, f = getSetup((25, 15), (1, 3))
30+
31+
# Add subplot labels
32+
subplotLabel(ax)
33+
34+
# Load data
35+
X = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad")
36+
X.uns["Pf2_A"] = correct_conditions(X)
37+
A = X.uns["Pf2_A"]
38+
39+
# Center A by cytokine medians
40+
cytokine_medians = np.median(A, axis=1, keepdims=True)
41+
A_centered = A - cytokine_medians
42+
X.uns["Pf2_A"] = A_centered
43+
44+
W, H, _ = deconvolution_cytokine_admm(A_centered, alpha_h=0.05, alpha_w=0.05, rho=2)
45+
46+
# Get cytokine names in correct order
47+
samples_df = samples_only(X)
48+
49+
# Create deconvolved version for plotting
50+
X_deconv = X.copy()
51+
X_deconv.uns["Pf2_A"] = H # Use primary effects only
52+
53+
plot_condition_factors(
54+
X_deconv,
55+
ax[0],
56+
samples_df["cytokine"],
57+
groupConditions=True,
58+
cond="cytokine",
59+
log_scale=False,
60+
)
61+
ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold")
62+
63+
#Plot original median subtracted factor matrix for reference
64+
plot_condition_factors(
65+
X,
66+
ax[1],
67+
samples_df["cytokine"],
68+
groupConditions=True,
69+
cond="cytokine",
70+
log_scale=False,
71+
)
72+
ax[1].set_title("Original Effects (A)", fontsize=12, fontweight="bold")
73+
74+
cytokine_names = samples_df["cytokine"].values
75+
76+
# Plot 2: W heatmap (primary effects)
77+
sns.heatmap(
78+
W,
79+
ax=ax[2],
80+
cmap="YlOrRd",
81+
robust=False,
82+
square=True,
83+
cbar_kws={"label": "Signaling Strength"},
84+
xticklabels=cytokine_names,
85+
yticklabels=cytokine_names,
86+
)
87+
ax[2].set_title("Cytokine Signaling (W)", fontsize=12, fontweight="bold")
88+
ax[2].set_xlabel("Inducing Cytokine →", fontsize=10)
89+
ax[2].set_ylabel("← Induced Cytokine", fontsize=10)
90+
plt.setp(ax[2].get_xticklabels(), rotation=90, ha="center", fontsize=6)
91+
plt.setp(ax[2].get_yticklabels(), rotation=0, fontsize=6)
92+
93+
return f

0 commit comments

Comments
 (0)