159 changes: 154 additions & 5 deletions pf2rnaseq/factorization.py
@@ -6,6 +6,7 @@
import scipy.sparse as sps
from pacmap import PaCMAP
from parafac2.parafac2 import parafac2_nd, store_pf2
from scipy.optimize import minimize
from scipy.stats import gmean
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
@@ -17,14 +18,16 @@
def correct_conditions(X: anndata.AnnData):
"""Correct the conditions factors by overall read depth. Ensures that weighting is not affected by cell count difference"""
sgIndex = X.obs["condition_unique_idxs"]
#sgIndex = X.obs["condition_unique_idxs"].cat.codes
# sgIndex = X.obs["condition_unique_idxs"].cat.codes
counts = np.zeros((np.amax(sgIndex) + 1, 1))
min_val = np.min(X.uns["Pf2_A"])
if min_val < 0:
# Add the absolute value of the minimum (plus a small epsilon) to make all values positive
X.uns["Pf2_A"] = X.uns["Pf2_A"] + abs(min_val) + 1e-10
print(f"Warning: Found negative values in Pf2_A (min: {min_val:.6f}). Added {abs(min_val) + 1e-10:.6f} to all values.")

print(
f"Warning: Found negative values in Pf2_A (min: {min_val:.6f}). Added {abs(min_val) + 1e-10:.6f} to all values."
)

cond_mean = gmean(X.uns["Pf2_A"], axis=1)

x_count = X.X.sum(axis=1)
@@ -50,13 +53,11 @@ def pf2(
):
cupy.cuda.Device(0).use()
pf_out, R2X = parafac2_nd(

X,
rank=rank,
random_state=random_state,
tol=tolerance,
n_iter_max=500,

)

X = store_pf2(X, pf_out)
@@ -197,3 +198,151 @@ def fms_diff_ranks(
)

return df


def deconvolution_cytokine(
A: np.ndarray,
alpha: float = 0.1,
max_iter: int = 5000,
random_state: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
"""
Decompose cytokine factor matrix: A ≈ W @ H

This decomposes observed cytokine effects into:
1. Direct primary effects (H)
2. Induced effects via other cytokines (W)

Parameters
----------
A : np.ndarray
Input matrix (n_cytokines, n_components)
Example: (91 cytokines, 100 Parafac2 components)
alpha : float
Regularization strength
max_iter : int
Maximum optimization iterations
random_state : int
Random seed

Returns
-------
W : np.ndarray
Cytokine interaction matrix (n_cytokines, n_cytokines)
W[i, j] = total contribution of cytokine j to observed effect of i
Diagonal W[i,i] = direct effect of cytokine i
H : np.ndarray
Effect basis matrix (n_cytokines, n_components)
H[:, j] = cytokine effects for component j without indirect contributions
"""
n_cytokines, n_components = A.shape

np.random.seed(random_state)

# W initialized as identity, H is original A
W_init = np.eye(n_cytokines)
H_init = A.copy()

x0 = np.concatenate([W_init.ravel(), H_init.ravel()])

print("Cytokine deconvolution:")
print(f" A shape: {A.shape} (cytokines × components)")
print(f" W shape: ({n_cytokines}, {n_cytokines}) (cytokine interactions)")
print(f" H shape: ({n_cytokines}, {n_components}) (effect basis)")

w_size = n_cytokines * n_cytokines
iteration_counter = [0]
best_loss = [np.inf]

def objective(x):
W = x[:w_size].reshape(n_cytokines, n_cytokines)
H = x[w_size:].reshape(n_cytokines, n_components)

# Reconstruction: A ≈ W @ H

reconstruction = W @ H
mse = np.sum((A - reconstruction) ** 2)

# Regularization: L1 penalty on both W and H
# Exclude diagonal of W from L1 penalty
l1_W = alpha * np.sum(np.abs(W)) - alpha * np.diag(np.abs(W)).sum()
l1_H = alpha * np.sum(np.abs(H))

total_loss = mse + l1_W + l1_H

iteration_counter[0] += 1
if total_loss < best_loss[0]:
best_loss[0] = total_loss

if iteration_counter[0] % 10 == 0:
print(
f" Iter {iteration_counter[0]}: Loss={total_loss:.4f} "
f"(MSE={mse:.4f}, L1_W={l1_W:.4f}, L1_H={l1_H:.4f})"
)

return total_loss

def gradient(x):
W = x[:w_size].reshape(n_cytokines, n_cytokines)
H = x[w_size:].reshape(n_cytokines, n_components)

# ===== Gradient w.r.t. W =====
# 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W)
grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W)))

# ===== Gradient w.r.t. H =====
# 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H)
grad_H = 2 * (W.T @ (W @ H - A)) + alpha * np.sign(H)

return np.concatenate([grad_W.ravel(), grad_H.ravel()])

print("\nStarting optimization...")

result = minimize(
fun=objective,
x0=x0,
method="L-BFGS-B",
jac=gradient,
options={"maxiter": max_iter, "disp": True},
)

W = result.x[:w_size].reshape(n_cytokines, n_cytokines)
H = result.x[w_size:].reshape(n_cytokines, n_components)

# Evaluate

A_recon = W @ H

recon_error = np.linalg.norm(A - A_recon, "fro")
rel_error = recon_error / np.linalg.norm(A, "fro")

# Statistics for W
w_sparsity = np.sum(np.abs(W) < 1e-3) / W.size
w_mean = np.abs(W).mean()
w_max = np.abs(W).max()

# Statistics for H
h_sparsity = np.sum(np.abs(H) < 1e-3) / H.size
h_mean = np.abs(H).mean()
h_max = np.abs(H).max()

print("\nOptimization complete:")
print(f" Success: {result.success}")
print(f" Iterations: {result.nit}")
print(f" Relative reconstruction error: {rel_error:.4%}")

print("\n W (cytokine interactions):")
print(f" Shape: {W.shape}")
print(f" Sparsity: {w_sparsity:.2%} (near-zero elements)")
print(f" Mean |W|: {w_mean:.4f}")
print(f" Max |W|: {w_max:.4f}")
print(f" Non-zeros: {np.sum(np.abs(W) > 1e-3)}/{W.size}")

print("\n H (effect patterns):")
print(f" Shape: {H.shape}")
print(f" Sparsity: {h_sparsity:.2%} (near-zero elements)")
print(f" Mean |H|: {h_mean:.4f}")
print(f" Max |H|: {h_max:.4f}")
print(f" Non-zeros: {np.sum(np.abs(H) > 1e-3)}/{H.size}")

return W, H
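

A quick, hedged sketch of how the returned matrices might be inspected downstream (assuming A is the corrected condition-factor matrix X.uns["Pf2_A"]): the diagonal of W carries the direct effects, while off-diagonal entries are candidate induced, cytokine-to-cytokine contributions.

import numpy as np

W, H = deconvolution_cytokine(A, alpha=0.1)

direct = np.diag(W)              # direct effect scale of each cytokine
indirect = W - np.diag(direct)   # off-diagonal: candidate induced effects

# Largest inferred interaction: cytokine j contributing to the observed effect of cytokine i
i, j = np.unravel_index(np.argmax(np.abs(indirect)), indirect.shape)
print(f"Largest off-diagonal |W| entry: W[{i}, {j}] = {W[i, j]:.4f}")

# H holds the per-component effect patterns with induced contributions removed
print("H value range:", H.min(), H.max())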
7 changes: 3 additions & 4 deletions pf2rnaseq/figures/commonFuncs/plotFactors.py
@@ -28,9 +28,9 @@ def plot_condition_factors(
X = np.log10(X)

X -= np.median(X, axis=0)
X /= np.std(X, axis=0)
X /= np.std(X, axis=0) + 1e-3
ind = reorder_table(X + 1e-3)

ind = reorder_table(X)
X = X[ind]
yt = yt.iloc[ind]

@@ -67,7 +67,7 @@
)
)
# add a little legend
ax.legend(handles=legend_elements, bbox_to_anchor=(0, 1.3))
# ax.legend(handles=legend_elements, bbox_to_anchor=(0, 1.3))

xticks = np.arange(1, X.shape[1] + 1)

@@ -584,7 +584,6 @@ def plot_comp_weights(

# Add legend for color coding (only if lowest are included)
if include_lowest:

legend_elements = [
Patch(facecolor="darkred", label=f"Top {top_n} Highest"),
Patch(facecolor="darkblue", label=f"Top {top_n} Lowest"),
1 change: 0 additions & 1 deletion pf2rnaseq/figures/figureHeiserCompPac.py
@@ -2,7 +2,6 @@
Weighted projections per component in PaCMAP and boxplot of cell types
"""


import numpy as np

from ..factorization import correct_conditions, pf2
86 changes: 86 additions & 0 deletions pf2rnaseq/figures/figureParseFactorsDeconv.py
@aarmey I have just been building this figure to test out the regularization. I have a rank-100 decomposition saved as "/home/nicoleb/ParsePf2_100_D11_filt.h5ad"; the figure loads it and calls deconvolution_cytokine, where alpha (the regularization strength) can be adjusted. The deconvolved matrix (H), the original matrix (A), and the cytokine interaction matrix (W) are plotted. It should also print the MSE every 100 iterations and the final sparsity.
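
A rough sketch of how alpha could be swept to pick the regularization strength, assuming the same saved decomposition; the 1e-3 cutoff for "near-zero" matches the sparsity statistics deconvolution_cytokine already prints.

import numpy as np
from anndata import read_h5ad
from pf2rnaseq.factorization import correct_conditions, deconvolution_cytokine

X = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad")
X.uns["Pf2_A"] = correct_conditions(X)
A = X.uns["Pf2_A"]

for alpha in (1e-2, 1e-1, 1.0):
    W, H = deconvolution_cytokine(A, alpha=alpha, max_iter=5000)
    # Relative Frobenius error of the reconstruction A ≈ W @ H
    rel_error = np.linalg.norm(A - W @ H, "fro") / np.linalg.norm(A, "fro")
    # Fraction of near-zero interaction entries in W
    w_sparsity = np.mean(np.abs(W) < 1e-3)
    print(f"alpha={alpha:g}: rel_error={rel_error:.4f}, W sparsity={w_sparsity:.2%}")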

@@ -0,0 +1,86 @@
"""
Parse data: Plotting factors
"""

import numpy as np
import pandas as pd
import seaborn as sns
from anndata import read_h5ad
from matplotlib import pyplot as plt

from ..factorization import correct_conditions, deconvolution_cytokine
from .common import getSetup, subplotLabel
from .commonFuncs.plotFactors import (
plot_condition_factors,
)


def samples_only(X) -> pd.DataFrame:
"""Obtain samples once only with corresponding observations"""
samples = X.obs
df_samples = samples.drop_duplicates(subset="condition_unique_idxs")
df_samples = df_samples.sort_values("condition_unique_idxs")
return df_samples


def makeFigure():
"""Get a list of the axis objects and create a figure."""
# Get list of axis objects
ax, f = getSetup((22, 15), (1, 3))

# Add subplot labels
subplotLabel(ax)

# Load data
X = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad")
X.uns["Pf2_A"] = correct_conditions(X)

W, H = deconvolution_cytokine(X.uns["Pf2_A"], alpha=1e-1, max_iter=5000)

# Get cytokine names in correct order
samples_df = samples_only(X)

# Create deconvolved version for plotting
X_deconv = X.copy()
X_deconv.uns["Pf2_A"] = H # Use primary effects only

plot_condition_factors(
X_deconv,
ax[0],
samples_df["cytokine"],
groupConditions=True,
cond="cytokine",
log_scale=False,
)
ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold")

plot_condition_factors(
X,
ax[1],
samples_df["cytokine"],
groupConditions=True,
cond="cytokine",
log_scale=False,
)
ax[1].set_title("Original Effects (A)", fontsize=12, fontweight="bold")

cytokine_names = samples_df["cytokine"].values

# Plot 2: W heatmap (primary effects)
sns.heatmap(
W,
ax=ax[2],
cmap="YlOrRd",
robust=True,
square=True,
cbar_kws={"label": "Signaling Strength"},
xticklabels=cytokine_names,
yticklabels=cytokine_names,
)
ax[2].set_title("Cytokine Signaling (W)", fontsize=12, fontweight="bold")
ax[2].set_xlabel("Inducing Cytokine →", fontsize=10)
ax[2].set_ylabel("← Induced Cytokine", fontsize=10)
plt.setp(ax[2].get_xticklabels(), rotation=90, ha="center", fontsize=6)
plt.setp(ax[2].get_yticklabels(), rotation=0, fontsize=6)

return f
5 changes: 2 additions & 3 deletions pf2rnaseq/imports.py
@@ -92,9 +92,8 @@ def import_Parse(geneThreshold=0.1, doublet=False) -> anndata.AnnData:
X = anndata.read_h5ad("/home/nicoleb/Pf2-scRNAseq-1/pf2rnaseq/Parse_Donor11.h5ad")
if doublet:
doubletDF = pd.read_csv(
path_here / "pf2rnaseq/Data/DN11Doublets.csv.gz",
index_col=0
)
path_here / "pf2rnaseq/Data/DN11Doublets.csv.gz", index_col=0
)
X.obs = X.obs.join(doubletDF.reindex(X.obs.index))
singlet_mask = X.obs["doublet"] == 0
X = X[singlet_mask, :].copy()
1 change: 0 additions & 1 deletion pf2rnaseq/top_bot_genes_export.py
@@ -2,7 +2,6 @@
exports csv of top 30 and bottom 30 genes per component
"""


import numpy as np
import pandas as pd
from anndata import read_h5ad
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -18,12 +18,12 @@ dependencies = [
"anndata>=0.10.3",
"datashader>=0.18",
"gseapy>=1.1",
"scanpy @ git+https://github.com/scverse/scanpy.git@c2a7a4b7ec3203121a8d75aa05fbeb602ceecbd4",
"scanpy>=1.10",
"pacmap>=0.8",
"leidenalg>=0.10.1",
"tqdm>=4.66.1",
"tlviz>=0.1.1",
"statsmodels>=0.14.1",
"statsmodels>=0.14.4",
"dask[dataframe]>=2025",
"ipykernel>=6.29.5",
"parafac2 @ git+https://github.com/meyer-lab/parafac2.git",
@@ -36,6 +36,7 @@ dev = [
"pytest>=8.0",
"pytest-cov>=6.0",
"pyright>=1.1",
"ruff>=0.14.4",
]

[project.scripts]