-
Notifications
You must be signed in to change notification settings - Fork 0
Parse data #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Parse data #30
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -288,7 +288,11 @@ def gradient(x): | |||
|
|
||||
| # ===== Gradient w.r.t. W ===== | ||||
| # 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W) | ||||
| grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W))) | ||||
| grad_W = ( | ||||
| 2 * ((W @ H - A) @ H.T) | ||||
| + alpha * np.sign(W) | ||||
| - np.diag(alpha * np.sign(np.diag(W))) | ||||
| ) | ||||
|
|
||||
| # ===== Gradient w.r.t. H ===== | ||||
| # 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H) | ||||
|
|
@@ -346,3 +350,225 @@ def gradient(x): | |||
| print(f" Non-zeros: {np.sum(np.abs(H) > 1e-3)}/{H.size}") | ||||
|
|
||||
| return W, H | ||||
|
|
||||
|
|
||||
| def deconvolution_cytokine_admm( | ||||
| A: np.ndarray, | ||||
| alpha_h: float = 0.1, | ||||
| alpha_w: float = 0.01, | ||||
| rho: float = 1.0, | ||||
| max_iter: int = 10000, | ||||
| tol: float = 1e-4, | ||||
| random_state: int = 1, | ||||
| adaptive_rho: bool = True, | ||||
| non_negative_w: bool = True, | ||||
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| ) -> tuple[np.ndarray, np.ndarray, dict]: | ||||
| """ | ||||
| Decompose cytokine factor matrix using ADMM: A ≈ W @ H | ||||
|
|
||||
| Parameters | ||||
| ---------- | ||||
| A : np.ndarray | ||||
| Input matrix (n_cytokines, n_components) | ||||
| alpha_h : float | ||||
| L1 regularization for H | ||||
| alpha_w : float | ||||
| L1 regularization for W (off-diagonal only) | ||||
| rho : float | ||||
| ADMM penalty parameter | ||||
| max_iter : int | ||||
| Maximum iterations | ||||
| tol : float | ||||
| Convergence tolerance for both primal and dual residuals | ||||
| random_state : int | ||||
| Random seed | ||||
| adaptive_rho : bool | ||||
| Whether to adaptively adjust rho | ||||
| non_negative_w : bool | ||||
| If True, enforce W ≥ 0 (cytokines only activate, not inhibit) | ||||
|
|
||||
| Returns | ||||
| ------- | ||||
| Z_W : np.ndarray | ||||
| Cytokine interaction matrix (n_cytokines, n_cytokines) | ||||
| Z_H : np.ndarray | ||||
| Effect basis matrix (n_cytokines, n_components) | ||||
| history : dict | ||||
| Optimization history | ||||
| """ | ||||
| n_cytokines, n_components = A.shape | ||||
| np.random.seed(random_state) | ||||
|
|
||||
| # Initialize | ||||
| W = np.eye(n_cytokines) | ||||
| H = A.copy() | ||||
| Z_W = W.copy() | ||||
| Z_H = H.copy() | ||||
| U_W = np.zeros_like(W) | ||||
| U_H = np.zeros_like(H) | ||||
|
|
||||
| print("Cytokine deconvolution with ADMM:") | ||||
| print(f" A shape: {A.shape}") | ||||
| print(f" Alpha_W: {alpha_w}, Alpha_H: {alpha_h}") | ||||
| print(f" Rho: {rho}") | ||||
| print(f" Tolerance: {tol}") | ||||
| print(f" Non-negative W: {non_negative_w}") | ||||
|
|
||||
| off_diag_mask = ~np.eye(n_cytokines, dtype=bool) | ||||
|
|
||||
| def soft_threshold(X, threshold): | ||||
| return np.sign(X) * np.maximum(np.abs(X) - threshold, 0) | ||||
|
|
||||
| def update_W(H, Z_W, U_W, rho): | ||||
| """Update W: constrain diagonal to 1.0, optional non-negativity""" | ||||
| H_HT = H @ H.T | ||||
| A_HT = A @ H.T | ||||
| lhs = H_HT + rho * np.eye(n_cytokines) | ||||
| rhs = A_HT + rho * (Z_W - U_W) | ||||
|
|
||||
| W_new = np.linalg.solve(lhs, rhs.T).T | ||||
|
|
||||
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| # Non-negativity constraint for W | ||||
| if non_negative_w: | ||||
| W_new = np.maximum(W_new, 0) | ||||
|
|
||||
|
||||
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
Copilot
AI
Dec 10, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trailing whitespace detected at the end of the line. Remove the trailing spaces for cleaner code formatting.
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| """ | ||
| Parse data: Plotting factors | ||
| """ | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| import seaborn as sns | ||
| from anndata import read_h5ad | ||
| from matplotlib import pyplot as plt | ||
|
|
||
| from ..factorization import correct_conditions, deconvolution_cytokine_admm | ||
| from .common import getSetup, subplotLabel | ||
| from .commonFuncs.plotFactors import ( | ||
| plot_condition_factors, | ||
| ) | ||
|
|
||
|
|
||
| def samples_only(X) -> pd.DataFrame: | ||
| """Obtain samples once only with corresponding observations""" | ||
| samples = X.obs | ||
| df_samples = samples.drop_duplicates(subset="condition_unique_idxs") | ||
| df_samples = df_samples.sort_values("condition_unique_idxs") | ||
| return df_samples | ||
|
|
||
|
|
||
| def makeFigure(): | ||
| """Get a list of the axis objects and create a figure.""" | ||
| # Get list of axis objects | ||
| ax, f = getSetup((25, 15), (1, 3)) | ||
|
|
||
| # Add subplot labels | ||
| subplotLabel(ax) | ||
|
|
||
| # Load data | ||
| X = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad") | ||
nbedanova marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| X.uns["Pf2_A"] = correct_conditions(X) | ||
| A = X.uns["Pf2_A"] | ||
|
|
||
| # Center A by cytokine medians | ||
| cytokine_medians = np.median(A, axis=1, keepdims=True) | ||
| A_centered = A - cytokine_medians | ||
| X.uns["Pf2_A"] = A_centered | ||
|
|
||
| W, H, _ = deconvolution_cytokine_admm(A_centered, alpha_h=0.05, alpha_w=0.05, rho=2) | ||
|
|
||
| # Get cytokine names in correct order | ||
| samples_df = samples_only(X) | ||
|
|
||
| # Create deconvolved version for plotting | ||
| X_deconv = X.copy() | ||
| X_deconv.uns["Pf2_A"] = H # Use primary effects only | ||
|
|
||
| plot_condition_factors( | ||
| X_deconv, | ||
| ax[0], | ||
| samples_df["cytokine"], | ||
| groupConditions=True, | ||
| cond="cytokine", | ||
| log_scale=False, | ||
| ) | ||
| ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold") | ||
|
|
||
| #Plot original median subtracted factor matrix for reference | ||
nbedanova marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| plot_condition_factors( | ||
| X, | ||
| ax[1], | ||
| samples_df["cytokine"], | ||
| groupConditions=True, | ||
| cond="cytokine", | ||
| log_scale=False, | ||
| ) | ||
| ax[1].set_title("Original Effects (A)", fontsize=12, fontweight="bold") | ||
|
|
||
| cytokine_names = samples_df["cytokine"].values | ||
|
|
||
| # Plot 2: W heatmap (primary effects) | ||
| sns.heatmap( | ||
| W, | ||
| ax=ax[2], | ||
| cmap="YlOrRd", | ||
| robust=False, | ||
| square=True, | ||
| cbar_kws={"label": "Signaling Strength"}, | ||
| xticklabels=cytokine_names, | ||
| yticklabels=cytokine_names, | ||
| ) | ||
| ax[2].set_title("Cytokine Signaling (W)", fontsize=12, fontweight="bold") | ||
| ax[2].set_xlabel("Inducing Cytokine →", fontsize=10) | ||
| ax[2].set_ylabel("← Induced Cytokine", fontsize=10) | ||
| plt.setp(ax[2].get_xticklabels(), rotation=90, ha="center", fontsize=6) | ||
| plt.setp(ax[2].get_yticklabels(), rotation=0, fontsize=6) | ||
|
|
||
| return f | ||
Uh oh!
There was an error while loading. Please reload this page.