From 900585502a1bca0be60efb91a9d46c14aacb25de Mon Sep 17 00:00:00 2001 From: Nicole Bedanova Date: Thu, 4 Dec 2025 11:58:47 -0800 Subject: [PATCH 1/5] ADMM deconvolution --- pf2rnaseq/factorization.py | 183 +++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index f26b8da..b635c81 100644 --- a/pf2rnaseq/factorization.py +++ b/pf2rnaseq/factorization.py @@ -346,3 +346,186 @@ def gradient(x): print(f" Non-zeros: {np.sum(np.abs(H) > 1e-3)}/{H.size}") return W, H + +def deconvolution_cytokine_admm( + A: np.ndarray, + alpha: float = 0.1, + rho: float = 1.0, + max_iter: int = 5000, + tol: float = 1e-4, + random_state: int = 1, + adaptive_rho: bool = True, +) -> tuple[np.ndarray, np.ndarray, dict]: + """ + Decompose cytokine factor matrix using ADMM: A ≈ W @ H + + Parameters + ---------- + A : np.ndarray + Input matrix (n_cytokines, n_components) + alpha : float + L1 regularization strength (applied to both W and H) + rho : float + ADMM penalty parameter + max_iter : int + Maximum ADMM iterations + tol : float + Convergence tolerance + random_state : int + Random seed + adaptive_rho : bool + Whether to adaptively adjust rho + + Returns + ------- + Z_W : np.ndarray + Cytokine interaction matrix (n_cytokines, n_cytokines) + Z_H : np.ndarray + Effect basis matrix (n_cytokines, n_components) + history : dict + Optimization history + """ + n_cytokines, n_components = A.shape + np.random.seed(random_state) + + # Initialize + W = np.eye(n_cytokines) + H = A.copy() + Z_W = W.copy() + Z_H = H.copy() + U_W = np.zeros_like(W) + U_H = np.zeros_like(H) + + print("Cytokine deconvolution with ADMM:") + print(f" A shape: {A.shape}") + print(f" Alpha (L1 penalty): {alpha}") + print(f" Rho (ADMM penalty): {rho}") + print(f" Adaptive rho: {adaptive_rho}") + + # Create mask for off-diagonal elements + off_diag_mask = ~np.eye(n_cytokines, dtype=bool) + + def soft_threshold(X, threshold): + return 
np.sign(X) * np.maximum(np.abs(X) - threshold, 0) + + def update_W(H, Z_W, U_W, rho): + """Update W: solve (H@H^T + rho*I) W^T = (A@H^T + rho(Z_W - U_W))^T""" + H_HT = H @ H.T + A_HT = A @ H.T + lhs = H_HT + rho * np.eye(n_cytokines) + rhs = A_HT + rho * (Z_W - U_W) + return np.linalg.solve(lhs, rhs.T).T + + def update_H(W, Z_H, U_H, rho): + """Update H: solve (W^T@W + rho*I) H = W^T@A + rho(Z_H - U_H)""" + W_TW = W.T @ W + W_TA = W.T @ A + lhs = W_TW + rho * np.eye(n_cytokines) + rhs = W_TA + rho * (Z_H - U_H) + return np.linalg.solve(lhs, rhs) + + def update_Z_W(W, U_W, alpha, rho): + """Update Z_W: soft-threshold off-diagonal, preserve diagonal""" + X = W + U_W + Z_W_new = soft_threshold(X, alpha / rho) + # Restore diagonal (no L1 penalty on direct effects) + np.fill_diagonal(Z_W_new, np.diag(X)) + return Z_W_new + + def update_Z_H(H, U_H, alpha, rho): + """Update Z_H: soft-threshold entire matrix""" + return soft_threshold(H + U_H, alpha / rho) + + history = { + 'objective': [], + 'primal_residual': [], + 'dual_residual': [], + 'rho': [], + 'w_sparsity': [], + 'h_sparsity': [] + } + + print("\nStarting ADMM iterations...") + + for iteration in range(max_iter): + # Store old Z values for dual residual computation + Z_W_old = Z_W.copy() + Z_H_old = Z_H.copy() + + # ADMM updates + W = update_W(H, Z_W, U_W, rho) + H = update_H(W, Z_H, U_H, rho) + Z_W = update_Z_W(W, U_W, alpha, rho) + Z_H = update_Z_H(H, U_H, alpha, rho) + U_W = U_W + (W - Z_W) + U_H = U_H + (H - Z_H) + + # Compute residuals + r_norm = np.sqrt(np.sum((W - Z_W)**2) + np.sum((H - Z_H)**2)) + s_norm = rho * np.sqrt(np.sum((Z_W - Z_W_old)**2) + + np.sum((Z_H - Z_H_old)**2)) + + # Compute objective (off-diagonal penalty for W, full penalty for H) + recon_error = np.sum((A - W @ H) ** 2) + l1_W = alpha * np.sum(np.abs(Z_W[off_diag_mask])) + l1_H = alpha * np.sum(np.abs(Z_H)) + objective = recon_error + l1_W + l1_H + + # Track sparsity + w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / 
np.sum(off_diag_mask) + h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size + + # Store history + history['objective'].append(objective) + history['primal_residual'].append(r_norm) + history['dual_residual'].append(s_norm) + history['rho'].append(rho) + history['w_sparsity'].append(w_sparsity) + history['h_sparsity'].append(h_sparsity) + + # Print progress + if iteration % 10 == 0 or iteration < 10: + print(f" Iter {iteration:4d}: Obj={objective:.4e}, " + f"||r||={r_norm:.4e}, ||s||={s_norm:.4e}, " + f"W_sparse={w_sparsity:.2%}, H_sparse={h_sparsity:.2%}") + + # Adaptive rho update + if adaptive_rho and iteration > 0: + if r_norm > 10 * s_norm: + rho = rho * 2 + U_W = U_W / 2 + U_H = U_H / 2 + elif s_norm > 10 * r_norm: + rho = rho / 2 + U_W = U_W * 2 + U_H = U_H * 2 + + # Check convergence + eps_primal = tol * np.sqrt(W.size + H.size) + eps_dual = tol * np.sqrt(U_W.size + U_H.size) + + if r_norm < eps_primal and s_norm < eps_dual: + print(f"\nConverged at iteration {iteration}") + break + + # Final statistics + A_recon = W @ H + rel_error = np.linalg.norm(A - A_recon, "fro") / np.linalg.norm(A, "fro") + + w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask) + h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size + + print("\nOptimization complete:") + print(f" Iterations: {iteration + 1}/{max_iter}") + print(f" Relative reconstruction error: {rel_error:.4%}") + print(f"\n W (cytokine interactions):") + print(f" Off-diagonal sparsity: {w_sparsity:.2%}") + print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}") + print(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}") + print(f" Diagonal mean: {np.abs(np.diag(Z_W)).mean():.4f}") + print(f"\n H (effect patterns):") + print(f" Sparsity: {h_sparsity:.2%}") + print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}") + print(f" Mean |H|: {np.abs(Z_H).mean():.4f}") + + return Z_W, Z_H, history \ No newline at end of file From 
355aee8091bbac8a89b2f8435ee4b7a6c816753a Mon Sep 17 00:00:00 2001 From: Nicole Bedanova Date: Tue, 9 Dec 2025 12:20:07 -0800 Subject: [PATCH 2/5] Update ADMM --- pf2rnaseq/factorization.py | 80 +++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index b635c81..3c22fd0 100644 --- a/pf2rnaseq/factorization.py +++ b/pf2rnaseq/factorization.py @@ -349,28 +349,31 @@ def gradient(x): def deconvolution_cytokine_admm( A: np.ndarray, - alpha: float = 0.1, + alpha_h: float = 0.1, + alpha_w: float = 0.01, rho: float = 1.0, max_iter: int = 5000, - tol: float = 1e-4, + tol: float = 1e-4, # Single tolerance for both primal and dual random_state: int = 1, adaptive_rho: bool = True, ) -> tuple[np.ndarray, np.ndarray, dict]: """ - Decompose cytokine factor matrix using ADMM: A ≈ W @ H + Decompose cytokine factor matrix using ADMM: A ≈ W @ H Parameters ---------- A : np.ndarray Input matrix (n_cytokines, n_components) - alpha : float - L1 regularization strength (applied to both W and H) + alpha_h : float + L1 regularization for H + alpha_w : float + L1 regularization for W (off-diagonal only) rho : float ADMM penalty parameter max_iter : int - Maximum ADMM iterations + Maximum iterations tol : float - Convergence tolerance + Convergence tolerance for both primal and dual residuals random_state : int Random seed adaptive_rho : bool @@ -398,26 +401,29 @@ def deconvolution_cytokine_admm( print("Cytokine deconvolution with ADMM:") print(f" A shape: {A.shape}") - print(f" Alpha (L1 penalty): {alpha}") - print(f" Rho (ADMM penalty): {rho}") - print(f" Adaptive rho: {adaptive_rho}") + print(f" Alpha_W: {alpha_w}, Alpha_H: {alpha_h}") + print(f" Rho: {rho}") + print(f" Tolerance: {tol}") - # Create mask for off-diagonal elements off_diag_mask = ~np.eye(n_cytokines, dtype=bool) def soft_threshold(X, threshold): return np.sign(X) * np.maximum(np.abs(X) - threshold, 0) def update_W(H, 
Z_W, U_W, rho): - """Update W: solve (H@H^T + rho*I) W^T = (A@H^T + rho(Z_W - U_W))^T""" + """Update W: constrain diagonal to 1.0""" H_HT = H @ H.T A_HT = A @ H.T lhs = H_HT + rho * np.eye(n_cytokines) rhs = A_HT + rho * (Z_W - U_W) - return np.linalg.solve(lhs, rhs.T).T + + W_new = np.linalg.solve(lhs, rhs.T).T + np.fill_diagonal(W_new, 1.0) + + return W_new def update_H(W, Z_H, U_H, rho): - """Update H: solve (W^T@W + rho*I) H = W^T@A + rho(Z_H - U_H)""" + """Update H""" W_TW = W.T @ W W_TA = W.T @ A lhs = W_TW + rho * np.eye(n_cytokines) @@ -425,11 +431,10 @@ def update_H(W, Z_H, U_H, rho): return np.linalg.solve(lhs, rhs) def update_Z_W(W, U_W, alpha, rho): - """Update Z_W: soft-threshold off-diagonal, preserve diagonal""" + """Update Z_W: soft-threshold off-diagonal only""" X = W + U_W - Z_W_new = soft_threshold(X, alpha / rho) - # Restore diagonal (no L1 penalty on direct effects) - np.fill_diagonal(Z_W_new, np.diag(X)) + Z_W_new = X.copy() + Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho) return Z_W_new def update_Z_H(H, U_H, alpha, rho): @@ -448,27 +453,30 @@ def update_Z_H(H, U_H, alpha, rho): print("\nStarting ADMM iterations...") for iteration in range(max_iter): - # Store old Z values for dual residual computation Z_W_old = Z_W.copy() Z_H_old = Z_H.copy() # ADMM updates W = update_W(H, Z_W, U_W, rho) H = update_H(W, Z_H, U_H, rho) - Z_W = update_Z_W(W, U_W, alpha, rho) - Z_H = update_Z_H(H, U_H, alpha, rho) + Z_W = update_Z_W(W, U_W, alpha_w, rho) + Z_H = update_Z_H(H, U_H, alpha_h, rho) U_W = U_W + (W - Z_W) U_H = U_H + (H - Z_H) - # Compute residuals + # ===== SIMPLIFIED CONVERGENCE CHECK ===== + + # Primal residual: ||W - Z_W||² + ||H - Z_H||² r_norm = np.sqrt(np.sum((W - Z_W)**2) + np.sum((H - Z_H)**2)) - s_norm = rho * np.sqrt(np.sum((Z_W - Z_W_old)**2) + - np.sum((Z_H - Z_H_old)**2)) - # Compute objective (off-diagonal penalty for W, full penalty for H) + # Dual residual: ||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||² + 
s_norm = np.sqrt(np.sum((rho * (Z_W - Z_W_old))**2) + + np.sum((rho * (Z_H - Z_H_old))**2)) + + # Compute objective recon_error = np.sum((A - W @ H) ** 2) - l1_W = alpha * np.sum(np.abs(Z_W[off_diag_mask])) - l1_H = alpha * np.sum(np.abs(Z_H)) + l1_W = alpha_w * np.sum(np.abs(Z_W[off_diag_mask])) + l1_H = alpha_h * np.sum(np.abs(Z_H)) objective = recon_error + l1_W + l1_H # Track sparsity @@ -486,8 +494,7 @@ def update_Z_H(H, U_H, alpha, rho): # Print progress if iteration % 10 == 0 or iteration < 10: print(f" Iter {iteration:4d}: Obj={objective:.4e}, " - f"||r||={r_norm:.4e}, ||s||={s_norm:.4e}, " - f"W_sparse={w_sparsity:.2%}, H_sparse={h_sparsity:.2%}") + f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}") # Adaptive rho update if adaptive_rho and iteration > 0: @@ -495,17 +502,18 @@ def update_Z_H(H, U_H, alpha, rho): rho = rho * 2 U_W = U_W / 2 U_H = U_H / 2 + print(f" Increased ρ → {rho:.2f}") elif s_norm > 10 * r_norm: rho = rho / 2 U_W = U_W * 2 U_H = U_H * 2 + print(f" Decreased ρ → {rho:.2f}") - # Check convergence - eps_primal = tol * np.sqrt(W.size + H.size) - eps_dual = tol * np.sqrt(U_W.size + U_H.size) - - if r_norm < eps_primal and s_norm < eps_dual: - print(f"\nConverged at iteration {iteration}") + # Simple convergence check + if r_norm < tol and s_norm < tol: + print(f"\n✓ Converged at iteration {iteration}") + print(f" Primal residual: {r_norm:.4e} < {tol:.4e}") + print(f" Dual residual: {s_norm:.4e} < {tol:.4e}") break # Final statistics @@ -522,7 +530,7 @@ def update_Z_H(H, U_H, alpha, rho): print(f" Off-diagonal sparsity: {w_sparsity:.2%}") print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}") print(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}") - print(f" Diagonal mean: {np.abs(np.diag(Z_W)).mean():.4f}") + print(f" Diagonal: all 1.0 (constrained)") print(f"\n H (effect patterns):") print(f" Sparsity: {h_sparsity:.2%}") print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}") From 
5ddb3a64edc90ad665a17b2d388f34b1e73d6a89 Mon Sep 17 00:00:00 2001 From: Nicole Bedanova Date: Wed, 10 Dec 2025 14:35:57 -0800 Subject: [PATCH 3/5] ADMM application and non-negativity for W --- pf2rnaseq/factorization.py | 147 +++++++++++++++++---------- pf2rnaseq/figures/figureParseADMM.py | 93 +++++++++++++++++ 2 files changed, 184 insertions(+), 56 deletions(-) create mode 100644 pf2rnaseq/figures/figureParseADMM.py diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index 3c22fd0..bf60680 100644 --- a/pf2rnaseq/factorization.py +++ b/pf2rnaseq/factorization.py @@ -288,7 +288,11 @@ def gradient(x): # ===== Gradient w.r.t. W ===== # 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W) - grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W))) + grad_W = ( + 2 * ((W @ H - A) @ H.T) + + alpha * np.sign(W) + - np.diag(alpha * np.sign(np.diag(W))) + ) # ===== Gradient w.r.t. H ===== # 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H) @@ -347,19 +351,21 @@ def gradient(x): return W, H + def deconvolution_cytokine_admm( A: np.ndarray, alpha_h: float = 0.1, alpha_w: float = 0.01, rho: float = 1.0, - max_iter: int = 5000, - tol: float = 1e-4, # Single tolerance for both primal and dual + max_iter: int = 10000, + tol: float = 1e-4, random_state: int = 1, adaptive_rho: bool = True, + non_negative_w: bool = True, ) -> tuple[np.ndarray, np.ndarray, dict]: """ Decompose cytokine factor matrix using ADMM: A ≈ W @ H - + Parameters ---------- A : np.ndarray @@ -378,7 +384,9 @@ def deconvolution_cytokine_admm( Random seed adaptive_rho : bool Whether to adaptively adjust rho - + non_negative_w : bool + If True, enforce W ≥ 0 (cytokines only activate, not inhibit) + Returns ------- Z_W : np.ndarray @@ -390,7 +398,7 @@ def deconvolution_cytokine_admm( """ n_cytokines, n_components = A.shape np.random.seed(random_state) - + # 
Initialize W = np.eye(n_cytokines) H = A.copy() @@ -398,64 +406,83 @@ def deconvolution_cytokine_admm( Z_H = H.copy() U_W = np.zeros_like(W) U_H = np.zeros_like(H) - + print("Cytokine deconvolution with ADMM:") print(f" A shape: {A.shape}") print(f" Alpha_W: {alpha_w}, Alpha_H: {alpha_h}") print(f" Rho: {rho}") print(f" Tolerance: {tol}") - + print(f" Non-negative W: {non_negative_w}") + off_diag_mask = ~np.eye(n_cytokines, dtype=bool) - + def soft_threshold(X, threshold): return np.sign(X) * np.maximum(np.abs(X) - threshold, 0) - + def update_W(H, Z_W, U_W, rho): - """Update W: constrain diagonal to 1.0""" + """Update W: constrain diagonal to 1.0, optional non-negativity""" H_HT = H @ H.T A_HT = A @ H.T lhs = H_HT + rho * np.eye(n_cytokines) rhs = A_HT + rho * (Z_W - U_W) - + W_new = np.linalg.solve(lhs, rhs.T).T - np.fill_diagonal(W_new, 1.0) + # Non-negativity constraint for W + if non_negative_w: + W_new = np.maximum(W_new, 0) + + # Diagonal constraint + np.fill_diagonal(W_new, 1.0) + return W_new - + def update_H(W, Z_H, U_H, rho): - """Update H""" + """Update H: NO non-negativity constraint""" W_TW = W.T @ W W_TA = W.T @ A lhs = W_TW + rho * np.eye(n_cytokines) rhs = W_TA + rho * (Z_H - U_H) + return np.linalg.solve(lhs, rhs) - + def update_Z_W(W, U_W, alpha, rho): - """Update Z_W: soft-threshold off-diagonal only""" + """Update Z_W: soft-threshold off-diagonal, optional non-negativity""" X = W + U_W Z_W_new = X.copy() + + # Soft-threshold off-diagonal Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho) + + # Non-negativity constraint for W + if non_negative_w: + Z_W_new = np.maximum(Z_W_new, 0) + + # Diagonal constraint + np.fill_diagonal(Z_W_new, 1.0) + return Z_W_new - + def update_Z_H(H, U_H, alpha, rho): - """Update Z_H: soft-threshold entire matrix""" + """Update Z_H: soft-threshold, NO non-negativity""" + # H can be negative return soft_threshold(H + U_H, alpha / rho) - + history = { - 'objective': [], - 'primal_residual': [], - 
'dual_residual': [], - 'rho': [], - 'w_sparsity': [], - 'h_sparsity': [] + "objective": [], + "primal_residual": [], + "dual_residual": [], + "rho": [], + "w_sparsity": [], + "h_sparsity": [], } - + print("\nStarting ADMM iterations...") - + for iteration in range(max_iter): Z_W_old = Z_W.copy() Z_H_old = Z_H.copy() - + # ADMM updates W = update_W(H, Z_W, U_W, rho) H = update_H(W, Z_H, U_H, rho) @@ -463,39 +490,40 @@ def update_Z_H(H, U_H, alpha, rho): Z_H = update_Z_H(H, U_H, alpha_h, rho) U_W = U_W + (W - Z_W) U_H = U_H + (H - Z_H) - - # ===== SIMPLIFIED CONVERGENCE CHECK ===== - + # Primal residual: ||W - Z_W||² + ||H - Z_H||² - r_norm = np.sqrt(np.sum((W - Z_W)**2) + np.sum((H - Z_H)**2)) - + r_norm = np.sqrt(np.sum((W - Z_W) ** 2) + np.sum((H - Z_H) ** 2)) + # Dual residual: ||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||² - s_norm = np.sqrt(np.sum((rho * (Z_W - Z_W_old))**2) + - np.sum((rho * (Z_H - Z_H_old))**2)) - + s_norm = np.sqrt( + np.sum((rho * (Z_W - Z_W_old)) ** 2) + np.sum((rho * (Z_H - Z_H_old)) ** 2) + ) + # Compute objective recon_error = np.sum((A - W @ H) ** 2) l1_W = alpha_w * np.sum(np.abs(Z_W[off_diag_mask])) l1_H = alpha_h * np.sum(np.abs(Z_H)) objective = recon_error + l1_W + l1_H - + # Track sparsity w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask) h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size - + # Store history - history['objective'].append(objective) - history['primal_residual'].append(r_norm) - history['dual_residual'].append(s_norm) - history['rho'].append(rho) - history['w_sparsity'].append(w_sparsity) - history['h_sparsity'].append(h_sparsity) - + history["objective"].append(objective) + history["primal_residual"].append(r_norm) + history["dual_residual"].append(s_norm) + history["rho"].append(rho) + history["w_sparsity"].append(w_sparsity) + history["h_sparsity"].append(h_sparsity) + # Print progress if iteration % 10 == 0 or iteration < 10: - print(f" Iter {iteration:4d}: Obj={objective:.4e}, " - 
f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}") - + print( + f" Iter {iteration:4d}: Obj={objective:.4e}, " + f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}" + ) + # Adaptive rho update if adaptive_rho and iteration > 0: if r_norm > 10 * s_norm: @@ -508,32 +536,39 @@ def update_Z_H(H, U_H, alpha, rho): U_W = U_W * 2 U_H = U_H * 2 print(f" Decreased ρ → {rho:.2f}") - + # Simple convergence check if r_norm < tol and s_norm < tol: print(f"\n✓ Converged at iteration {iteration}") print(f" Primal residual: {r_norm:.4e} < {tol:.4e}") print(f" Dual residual: {s_norm:.4e} < {tol:.4e}") break - + # Final statistics A_recon = W @ H rel_error = np.linalg.norm(A - A_recon, "fro") / np.linalg.norm(A, "fro") - + w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask) h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size - + print("\nOptimization complete:") print(f" Iterations: {iteration + 1}/{max_iter}") print(f" Relative reconstruction error: {rel_error:.4%}") - print(f"\n W (cytokine interactions):") + + print("\n W (cytokine interactions):") print(f" Off-diagonal sparsity: {w_sparsity:.2%}") print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}") print(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}") + print(f" Min value: {W.min():.4f}") # Check non-negativity + print(f" Max value: {W.max():.4f}") print(f" Diagonal: all 1.0 (constrained)") - print(f"\n H (effect patterns):") + + print("\n H (effect patterns):") print(f" Sparsity: {h_sparsity:.2%}") print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}") print(f" Mean |H|: {np.abs(Z_H).mean():.4f}") - + print(f" Min value: {H.min():.4f}") # Can be negative + print(f" Max value: {H.max():.4f}") + print(f" Negative values: {np.sum(H < 0)} ({100*np.sum(H < 0)/H.size:.1f}%)") + return Z_W, Z_H, history \ No newline at end of file diff --git a/pf2rnaseq/figures/figureParseADMM.py b/pf2rnaseq/figures/figureParseADMM.py new file mode 100644 index 0000000..958da4b --- 
/dev/null +++ b/pf2rnaseq/figures/figureParseADMM.py @@ -0,0 +1,93 @@ +""" +Parse data: Plotting factors +""" + +import numpy as np +import pandas as pd +import seaborn as sns +from anndata import read_h5ad +from matplotlib import pyplot as plt + +from ..factorization import correct_conditions, deconvolution_cytokine_admm +from .common import getSetup, subplotLabel +from .commonFuncs.plotFactors import ( + plot_condition_factors, +) + + +def samples_only(X) -> pd.DataFrame: + """Obtain samples once only with corresponding observations""" + samples = X.obs + df_samples = samples.drop_duplicates(subset="condition_unique_idxs") + df_samples = df_samples.sort_values("condition_unique_idxs") + return df_samples + + +def makeFigure(): + """Get a list of the axis objects and create a figure.""" + # Get list of axis objects + ax, f = getSetup((25, 15), (1, 3)) + + # Add subplot labels + subplotLabel(ax) + + # Load data + X = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad") + X.uns["Pf2_A"] = correct_conditions(X) + A = X.uns["Pf2_A"] + + # Center A by cytokine medians + cytokine_medians = np.median(A, axis=1, keepdims=True) + A_centered = A - cytokine_medians + X.uns["Pf2_A"] = A_centered + + W, H, _ = deconvolution_cytokine_admm(A_centered, alpha_h=0.05, alpha_w=0.05, rho=2) + + # Get cytokine names in correct order + samples_df = samples_only(X) + + # Create deconvolved version for plotting + X_deconv = X.copy() + X_deconv.uns["Pf2_A"] = H # Use primary effects only + + plot_condition_factors( + X_deconv, + ax[0], + samples_df["cytokine"], + groupConditions=True, + cond="cytokine", + log_scale=False, + ) + ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold") + + #Plot original median subtracted factor matrix for reference + plot_condition_factors( + X, + ax[1], + samples_df["cytokine"], + groupConditions=True, + cond="cytokine", + log_scale=False, + ) + ax[1].set_title("Original Effects (A)", fontsize=12, fontweight="bold") + + cytokine_names = 
samples_df["cytokine"].values + + # Plot 2: W heatmap (primary effects) + sns.heatmap( + W, + ax=ax[2], + cmap="YlOrRd", + robust=False, + square=True, + cbar_kws={"label": "Signaling Strength"}, + xticklabels=cytokine_names, + yticklabels=cytokine_names, + ) + ax[2].set_title("Cytokine Signaling (W)", fontsize=12, fontweight="bold") + ax[2].set_xlabel("Inducing Cytokine →", fontsize=10) + ax[2].set_ylabel("← Induced Cytokine", fontsize=10) + plt.setp(ax[2].get_xticklabels(), rotation=90, ha="center", fontsize=6) + plt.setp(ax[2].get_yticklabels(), rotation=0, fontsize=6) + + return f From 83f200348e96729a22f75d51ee5ba218398e13b5 Mon Sep 17 00:00:00 2001 From: Aaron Meyer Date: Wed, 10 Dec 2025 20:45:13 -0800 Subject: [PATCH 4/5] ruff format --- pf2rnaseq/factorization.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index bf60680..8016699 100644 --- a/pf2rnaseq/factorization.py +++ b/pf2rnaseq/factorization.py @@ -358,10 +358,10 @@ def deconvolution_cytokine_admm( alpha_w: float = 0.01, rho: float = 1.0, max_iter: int = 10000, - tol: float = 1e-4, + tol: float = 1e-4, random_state: int = 1, adaptive_rho: bool = True, - non_negative_w: bool = True, + non_negative_w: bool = True, ) -> tuple[np.ndarray, np.ndarray, dict]: """ Decompose cytokine factor matrix using ADMM: A ≈ W @ H @@ -427,11 +427,11 @@ def update_W(H, Z_W, U_W, rho): rhs = A_HT + rho * (Z_W - U_W) W_new = np.linalg.solve(lhs, rhs.T).T - + # Non-negativity constraint for W if non_negative_w: W_new = np.maximum(W_new, 0) - + # Diagonal constraint np.fill_diagonal(W_new, 1.0) @@ -443,24 +443,24 @@ def update_H(W, Z_H, U_H, rho): W_TA = W.T @ A lhs = W_TW + rho * np.eye(n_cytokines) rhs = W_TA + rho * (Z_H - U_H) - + return np.linalg.solve(lhs, rhs) def update_Z_W(W, U_W, alpha, rho): """Update Z_W: soft-threshold off-diagonal, optional non-negativity""" X = W + U_W Z_W_new = X.copy() - + # 
Soft-threshold off-diagonal Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho) - + # Non-negativity constraint for W if non_negative_w: Z_W_new = np.maximum(Z_W_new, 0) - + # Diagonal constraint np.fill_diagonal(Z_W_new, 1.0) - + return Z_W_new def update_Z_H(H, U_H, alpha, rho): @@ -554,7 +554,7 @@ def update_Z_H(H, U_H, alpha, rho): print("\nOptimization complete:") print(f" Iterations: {iteration + 1}/{max_iter}") print(f" Relative reconstruction error: {rel_error:.4%}") - + print("\n W (cytokine interactions):") print(f" Off-diagonal sparsity: {w_sparsity:.2%}") print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}") @@ -562,13 +562,13 @@ def update_Z_H(H, U_H, alpha, rho): print(f" Min value: {W.min():.4f}") # Check non-negativity print(f" Max value: {W.max():.4f}") print(f" Diagonal: all 1.0 (constrained)") - + print("\n H (effect patterns):") print(f" Sparsity: {h_sparsity:.2%}") print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}") print(f" Mean |H|: {np.abs(Z_H).mean():.4f}") print(f" Min value: {H.min():.4f}") # Can be negative print(f" Max value: {H.max():.4f}") - print(f" Negative values: {np.sum(H < 0)} ({100*np.sum(H < 0)/H.size:.1f}%)") + print(f" Negative values: {np.sum(H < 0)} ({100 * np.sum(H < 0) / H.size:.1f}%)") - return Z_W, Z_H, history \ No newline at end of file + return Z_W, Z_H, history From aba71d4566c8152f5fe37e117ca143abd2bdac0b Mon Sep 17 00:00:00 2001 From: Aaron Meyer Date: Wed, 10 Dec 2025 21:00:44 -0800 Subject: [PATCH 5/5] Add two simple tests for ADMM approach --- pf2rnaseq/factorization.py | 5 +- pf2rnaseq/figures/figureParseADMM.py | 2 +- .../tests/test_cytokine_deconvolution.py | 149 ++++++++++++++++++ 3 files changed, 153 insertions(+), 3 deletions(-) create mode 100644 pf2rnaseq/tests/test_cytokine_deconvolution.py diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index 8016699..ec60b1f 100644 --- a/pf2rnaseq/factorization.py +++ 
b/pf2rnaseq/factorization.py @@ -400,8 +400,9 @@ def deconvolution_cytokine_admm( np.random.seed(random_state) # Initialize - W = np.eye(n_cytokines) - H = A.copy() + # W initialized as identity plus small uniform noise; H random, scaled by mean |A| + W = np.random.rand(n_cytokines, n_cytokines) * 0.1 + np.eye(n_cytokines) + H = np.random.rand(n_cytokines, n_components) * np.mean(np.abs(A)) Z_W = W.copy() Z_H = H.copy() U_W = np.zeros_like(W) diff --git a/pf2rnaseq/figures/figureParseADMM.py b/pf2rnaseq/figures/figureParseADMM.py index 958da4b..1e1d941 100644 --- a/pf2rnaseq/figures/figureParseADMM.py +++ b/pf2rnaseq/figures/figureParseADMM.py @@ -60,7 +60,7 @@ def makeFigure(): ) ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold") - #Plot original median subtracted factor matrix for reference + # Plot original median subtracted factor matrix for reference plot_condition_factors( X, ax[1], diff --git a/pf2rnaseq/tests/test_cytokine_deconvolution.py b/pf2rnaseq/tests/test_cytokine_deconvolution.py new file mode 100644 index 0000000..9a4f1a5 --- /dev/null +++ b/pf2rnaseq/tests/test_cytokine_deconvolution.py @@ -0,0 +1,149 @@ +""" +Test the cytokine deconvolution method. +""" + +import numpy as np +import pytest + +from ..factorization import deconvolution_cytokine_admm + + +def test_deconvolution_cytokine_admm_sparse(): + """ + Test deconvolution_cytokine_admm with sparse ground truth matrices. + + This test generates sparse W (cytokine interaction) and H (effect basis) matrices, + computes A = W @ H, and verifies that the deconvolution recovers the structure. 
+ """ + np.random.seed(42) + + # Dimensions + n_cytokines = 8 + n_components = 12 + + # Generate sparse ground truth W (cytokine interaction matrix) + # W should have 1s on diagonal and sparse off-diagonal elements + W_true = np.eye(n_cytokines) + + # Add sparse off-diagonal interactions (only 20% of off-diagonal elements) + off_diag_mask = ~np.eye(n_cytokines, dtype=bool) + n_off_diag = np.sum(off_diag_mask) + n_nonzero_w = int(0.2 * n_off_diag) + + # Randomly select positions for non-zero off-diagonal elements + off_diag_positions = np.where(off_diag_mask) + nonzero_indices = np.random.choice(n_off_diag, n_nonzero_w, replace=False) + + for idx in nonzero_indices: + i, j = off_diag_positions[0][idx], off_diag_positions[1][idx] + # Use small positive values for cytokine interactions + W_true[i, j] = np.random.uniform(0.1, 0.5) + + # Generate sparse ground truth H (effect basis matrix) + # H should have about 30% non-zero elements + H_true = np.zeros((n_cytokines, n_components)) + n_nonzero_h = int(0.3 * n_cytokines * n_components) + + for _ in range(n_nonzero_h): + i = np.random.randint(0, n_cytokines) + j = np.random.randint(0, n_components) + # H can have both positive and negative values + H_true[i, j] = np.random.uniform(-2.0, 2.0) + + # Compute the observed matrix A + A = W_true @ H_true + + # Add small noise + noise_level = 0.01 + A_noisy = A + noise_level * np.random.randn(n_cytokines, n_components) + + # Run deconvolution + W_recovered, H_recovered, history = deconvolution_cytokine_admm( + A_noisy, + alpha_h=0.1, + alpha_w=0.05, + rho=1.0, + max_iter=1000, + tol=1e-6, + random_state=42, + adaptive_rho=True, + non_negative_w=True, + ) + + # Verify shapes + assert W_recovered.shape == (n_cytokines, n_cytokines) + assert H_recovered.shape == (n_cytokines, n_components) + + # Verify diagonal of W is constrained to 1 + np.testing.assert_allclose(np.diag(W_recovered), np.ones(n_cytokines), atol=1e-10) + + # Verify non-negativity of W + assert np.all(W_recovered 
>= -1e-10), "W should be non-negative" + + # Verify reconstruction quality + A_reconstructed = W_recovered @ H_recovered + reconstruction_error = np.linalg.norm( + A_noisy - A_reconstructed, "fro" + ) / np.linalg.norm(A_noisy, "fro") + assert reconstruction_error < 0.1, ( + f"Reconstruction error too high: {reconstruction_error}" + ) + + # Verify sparsity of W (off-diagonal should be sparse) + w_sparsity = np.sum(np.abs(W_recovered[off_diag_mask]) < 1e-3) / np.sum( + off_diag_mask + ) + assert w_sparsity > 0.5, f"W should be sparse, but sparsity is only {w_sparsity}" + + # Verify sparsity of H + h_sparsity = np.sum(np.abs(H_recovered) < 1e-3) / H_recovered.size + assert h_sparsity > 0.3, f"H should be sparse, but sparsity is only {h_sparsity}" + + # Verify history contains expected keys + assert "objective" in history + assert "primal_residual" in history + assert "dual_residual" in history + assert "rho" in history + assert "w_sparsity" in history + assert "h_sparsity" in history + + # Verify objective decreases (generally) + assert len(history["objective"]) > 0 + # Check that final objective is lower than initial (with some tolerance for fluctuations) + initial_obj = history["objective"][0] + final_obj = history["objective"][-1] + assert final_obj < initial_obj * 1.1, "Objective should generally decrease" + + print("\nTest passed!") + print(f"Reconstruction error: {reconstruction_error:.4f}") + print(f"W off-diagonal sparsity: {w_sparsity:.2%}") + print(f"H sparsity: {h_sparsity:.2%}") + print(f"Converged in {len(history['objective'])} iterations") + + +def test_deconvolution_cytokine_admm_small(): + """ + Test with a small problem to ensure basic functionality. 
+ """ + np.random.seed(999) + + n_cytokines = 3 + n_components = 5 + + # Simple test matrix + A = np.random.randn(n_cytokines, n_components) + + # Run with default parameters + W, H, history = deconvolution_cytokine_admm( + A, max_iter=100, tol=1e-6, random_state=999 + ) + + # Basic checks + assert W.shape == (n_cytokines, n_cytokines) + assert H.shape == (n_cytokines, n_components) + assert len(history["objective"]) > 0 + + # Verify diagonal constraint + np.testing.assert_allclose(np.diag(W), np.ones(n_cytokines), atol=1e-10) + + print("\nSmall problem test passed!")