Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 228 additions & 1 deletion pf2rnaseq/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,11 @@ def gradient(x):

# ===== Gradient w.r.t. W =====
# 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W)
grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W)))
grad_W = (
2 * ((W @ H - A) @ H.T)
+ alpha * np.sign(W)
- np.diag(alpha * np.sign(np.diag(W)))
)

# ===== Gradient w.r.t. H =====
# 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H)
Expand Down Expand Up @@ -346,3 +350,226 @@ def gradient(x):
print(f" Non-zeros: {np.sum(np.abs(H) > 1e-3)}/{H.size}")

return W, H


def deconvolution_cytokine_admm(
    A: np.ndarray,
    alpha_h: float = 0.1,
    alpha_w: float = 0.01,
    rho: float = 1.0,
    max_iter: int = 10000,
    tol: float = 1e-4,
    random_state: int = 1,
    adaptive_rho: bool = True,
    non_negative_w: bool = True,
    *,
    verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """
    Decompose cytokine factor matrix using ADMM: A ≈ W @ H

    Scaled-dual ADMM with the splitting W = Z_W and H = Z_H. W and H are
    updated by regularized linear solves; Z_W and Z_H carry the L1 penalties
    through soft-thresholding. The diagonal of W and Z_W is pinned to 1.0 and
    is excluded from the L1 penalty.

    Parameters
    ----------
    A : np.ndarray
        Input matrix (n_cytokines, n_components)
    alpha_h : float
        L1 regularization for H
    alpha_w : float
        L1 regularization for W (off-diagonal only)
    rho : float
        ADMM penalty parameter (initial value when ``adaptive_rho`` is True)
    max_iter : int
        Maximum iterations. ``0`` returns the seeded initial state.
    tol : float
        Convergence tolerance for both primal and dual residuals
    random_state : int
        Random seed. A private generator is used, so NumPy's global random
        state is left untouched.
    adaptive_rho : bool
        Whether to adaptively adjust rho (doubling/halving when the primal
        and dual residuals differ by more than a factor of 10)
    non_negative_w : bool
        If True, enforce W ≥ 0 (cytokines only activate, not inhibit)
    verbose : bool
        If True (default), print progress and summary statistics.

    Returns
    -------
    Z_W : np.ndarray
        Cytokine interaction matrix (n_cytokines, n_cytokines)
    Z_H : np.ndarray
        Effect basis matrix (n_cytokines, n_components)
    history : dict
        Optimization history: per-iteration lists under the keys
        ``objective``, ``primal_residual``, ``dual_residual``, ``rho``,
        ``w_sparsity`` and ``h_sparsity``.

    Raises
    ------
    ValueError
        If ``A`` is not two-dimensional.
    """
    A = np.asarray(A)
    if A.ndim != 2:
        raise ValueError(
            f"A must be a 2-D array (n_cytokines, n_components), got ndim={A.ndim}"
        )
    n_cytokines, n_components = A.shape

    def log(message: str) -> None:
        if verbose:
            print(message)

    # Legacy RandomState reproduces the exact stream of np.random.seed(...) +
    # np.random.rand(...), but without mutating the global generator.
    rng = np.random.RandomState(random_state)

    # Initialize
    # W starts as identity plus small positive noise; H starts as uniform
    # noise scaled by the mean absolute entry of A.
    W = rng.rand(n_cytokines, n_cytokines) * 0.1 + np.eye(n_cytokines)
    H = rng.rand(n_cytokines, n_components) * np.mean(np.abs(A))
    Z_W = W.copy()
    Z_H = H.copy()
    U_W = np.zeros_like(W)
    U_H = np.zeros_like(H)

    log("Cytokine deconvolution with ADMM:")
    log(f" A shape: {A.shape}")
    log(f" Alpha_W: {alpha_w}, Alpha_H: {alpha_h}")
    log(f" Rho: {rho}")
    log(f" Tolerance: {tol}")
    log(f" Non-negative W: {non_negative_w}")

    # Loop invariants, computed once.
    identity = np.eye(n_cytokines)
    off_diag_mask = ~np.eye(n_cytokines, dtype=bool)
    n_off_diag = np.sum(off_diag_mask)

    def soft_threshold(X, threshold):
        return np.sign(X) * np.maximum(np.abs(X) - threshold, 0)

    def update_W(H, Z_W, U_W, rho):
        """Update W: constrain diagonal to 1.0, optional non-negativity"""
        lhs = H @ H.T + rho * identity
        rhs = A @ H.T + rho * (Z_W - U_W)

        # Solve W @ lhs = rhs. lhs is symmetric, so lhs @ W.T = rhs.T.
        W_new = np.linalg.solve(lhs, rhs.T).T

        # Constraints are applied as a projection after the solve.
        if non_negative_w:
            W_new = np.maximum(W_new, 0)
        np.fill_diagonal(W_new, 1.0)

        return W_new

    def update_H(W, Z_H, U_H, rho):
        """Update H: NO non-negativity constraint"""
        lhs = W.T @ W + rho * identity
        rhs = W.T @ A + rho * (Z_H - U_H)
        return np.linalg.solve(lhs, rhs)

    def update_Z_W(W, U_W, alpha, rho):
        """Update Z_W: soft-threshold off-diagonal, optional non-negativity"""
        X = W + U_W
        Z_W_new = X.copy()

        # Only off-diagonal entries are L1-penalized.
        Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho)

        if non_negative_w:
            Z_W_new = np.maximum(Z_W_new, 0)
        np.fill_diagonal(Z_W_new, 1.0)

        return Z_W_new

    def update_Z_H(H, U_H, alpha, rho):
        """Update Z_H: soft-threshold, NO non-negativity"""
        return soft_threshold(H + U_H, alpha / rho)

    history = {
        "objective": [],
        "primal_residual": [],
        "dual_residual": [],
        "rho": [],
        "w_sparsity": [],
        "h_sparsity": [],
    }

    log("\nStarting ADMM iterations...")

    # Tracked separately from the loop variable so that max_iter == 0 does not
    # leave the iteration count undefined in the final report.
    n_iter = 0

    for iteration in range(max_iter):
        n_iter = iteration + 1
        Z_W_old = Z_W.copy()
        Z_H_old = Z_H.copy()

        # ADMM updates
        W = update_W(H, Z_W, U_W, rho)
        H = update_H(W, Z_H, U_H, rho)
        Z_W = update_Z_W(W, U_W, alpha_w, rho)
        Z_H = update_Z_H(H, U_H, alpha_h, rho)
        U_W = U_W + (W - Z_W)
        U_H = U_H + (H - Z_H)

        # Primal residual: sqrt(||W - Z_W||² + ||H - Z_H||²)
        r_norm = np.sqrt(np.sum((W - Z_W) ** 2) + np.sum((H - Z_H) ** 2))

        # Dual residual: sqrt(||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||²)
        s_norm = np.sqrt(
            np.sum((rho * (Z_W - Z_W_old)) ** 2)
            + np.sum((rho * (Z_H - Z_H_old)) ** 2)
        )

        # Compute objective
        # NOTE(review): the W/H solves above correspond to a ½||A - WH||² data
        # term, while this logged value uses ||A - WH||² (no ½). Kept as-is so
        # recorded history values do not change; confirm which is intended.
        off_diag_abs = np.abs(Z_W[off_diag_mask])
        recon_error = np.sum((A - W @ H) ** 2)
        l1_W = alpha_w * np.sum(off_diag_abs)
        l1_H = alpha_h * np.sum(np.abs(Z_H))
        objective = recon_error + l1_W + l1_H

        # Track sparsity (fraction of entries with magnitude below 1e-3)
        w_sparsity = np.sum(off_diag_abs < 1e-3) / n_off_diag
        h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size

        # Store history
        history["objective"].append(objective)
        history["primal_residual"].append(r_norm)
        history["dual_residual"].append(s_norm)
        history["rho"].append(rho)
        history["w_sparsity"].append(w_sparsity)
        history["h_sparsity"].append(h_sparsity)

        # Print progress
        if iteration % 10 == 0 or iteration < 10:
            log(
                f" Iter {iteration:4d}: Obj={objective:.4e}, "
                f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}"
            )

        # Adaptive rho update. U is the scaled dual variable, so it is rescaled
        # inversely whenever rho changes.
        if adaptive_rho and iteration > 0:
            if r_norm > 10 * s_norm:
                rho = rho * 2
                U_W = U_W / 2
                U_H = U_H / 2
                log(f" Increased ρ → {rho:.2f}")
            elif s_norm > 10 * r_norm:
                rho = rho / 2
                U_W = U_W * 2
                U_H = U_H * 2
                log(f" Decreased ρ → {rho:.2f}")

        # Simple convergence check
        if r_norm < tol and s_norm < tol:
            log(f"\n✓ Converged at iteration {iteration}")
            log(f" Primal residual: {r_norm:.4e} < {tol:.4e}")
            log(f" Dual residual: {s_norm:.4e} < {tol:.4e}")
            break

    # Final statistics
    A_recon = W @ H
    rel_error = np.linalg.norm(A - A_recon, "fro") / np.linalg.norm(A, "fro")

    w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / n_off_diag
    h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size

    log("\nOptimization complete:")
    log(f" Iterations: {n_iter}/{max_iter}")
    log(f" Relative reconstruction error: {rel_error:.4%}")

    log("\n W (cytokine interactions):")
    log(f" Off-diagonal sparsity: {w_sparsity:.2%}")
    log(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}")
    log(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}")
    log(f" Min value: {W.min():.4f}")  # Check non-negativity
    log(f" Max value: {W.max():.4f}")
    log(" Diagonal: all 1.0 (constrained)")

    log("\n H (effect patterns):")
    log(f" Sparsity: {h_sparsity:.2%}")
    log(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}")
    log(f" Mean |H|: {np.abs(Z_H).mean():.4f}")
    log(f" Min value: {H.min():.4f}")  # Can be negative
    log(f" Max value: {H.max():.4f}")
    log(f" Negative values: {np.sum(H < 0)} ({100 * np.sum(H < 0) / H.size:.1f}%)")

    return Z_W, Z_H, history
93 changes: 93 additions & 0 deletions pf2rnaseq/figures/figureParseADMM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Parse data: Plotting factors
"""

import numpy as np
import pandas as pd
import seaborn as sns
from anndata import read_h5ad
from matplotlib import pyplot as plt

from ..factorization import correct_conditions, deconvolution_cytokine_admm
from .common import getSetup, subplotLabel
from .commonFuncs.plotFactors import (
plot_condition_factors,
)


def samples_only(X) -> pd.DataFrame:
    """Return one observation row per unique condition, ordered by condition index.

    Reads ``X.obs``, keeps the first row seen for each distinct value of the
    ``condition_unique_idxs`` column, and sorts the result by that column.
    ``X.obs`` itself is not modified.
    """
    condition_col = "condition_unique_idxs"
    return X.obs.drop_duplicates(subset=condition_col).sort_values(condition_col)


def makeFigure():
    """Build the three-panel ADMM deconvolution figure and return it.

    Panels, left to right: condition factors after deconvolution (H), the
    median-centered condition factors (A), and a heatmap of the cytokine
    interaction matrix (W).
    """
    axes, fig = getSetup((25, 15), (1, 3))
    subplotLabel(axes)

    # Load the dataset and store the corrected condition factors on it.
    adata = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad")
    adata.uns["Pf2_A"] = correct_conditions(adata)
    factors = adata.uns["Pf2_A"]

    # Subtract each row's median (axis=1) and keep the centered matrix on the
    # AnnData object so the reference panel shows what was deconvolved.
    centered = factors - np.median(factors, axis=1, keepdims=True)
    adata.uns["Pf2_A"] = centered

    W, H, _ = deconvolution_cytokine_admm(centered, alpha_h=0.05, alpha_w=0.05, rho=2)

    # One row per condition, in condition-index order.
    sample_table = samples_only(adata)

    # Copy made after centering; only the factor matrix is swapped for H.
    adata_deconv = adata.copy()
    adata_deconv.uns["Pf2_A"] = H

    title_style = {"fontsize": 12, "fontweight": "bold"}

    factor_panels = (
        (adata_deconv, axes[0], "Deconvolved matrix (H)"),
        (adata, axes[1], "Original Effects (A)"),
    )
    for panel_data, panel_ax, panel_title in factor_panels:
        plot_condition_factors(
            panel_data,
            panel_ax,
            sample_table["cytokine"],
            groupConditions=True,
            cond="cytokine",
            log_scale=False,
        )
        panel_ax.set_title(panel_title, **title_style)

    # Heatmap of W, labelled with the cytokine of each condition.
    # NOTE(review): with A ≈ W @ H, row i of W mixes rows of H into row i of A;
    # confirm the "Inducing"/"Induced" axis labels match that orientation.
    cytokine_names = sample_table["cytokine"].values
    heatmap_ax = axes[2]
    heatmap_options = {
        "ax": heatmap_ax,
        "cmap": "YlOrRd",
        "robust": False,
        "square": True,
        "cbar_kws": {"label": "Signaling Strength"},
        "xticklabels": cytokine_names,
        "yticklabels": cytokine_names,
    }
    sns.heatmap(W, **heatmap_options)

    heatmap_ax.set_title("Cytokine Signaling (W)", **title_style)
    heatmap_ax.set_xlabel("Inducing Cytokine →", fontsize=10)
    heatmap_ax.set_ylabel("← Induced Cytokine", fontsize=10)
    plt.setp(heatmap_ax.get_xticklabels(), rotation=90, ha="center", fontsize=6)
    plt.setp(heatmap_ax.get_yticklabels(), rotation=0, fontsize=6)

    return fig
Loading
Loading