Skip to content

Commit bb7de9d

Browse files
authored
Merge pull request #28 from meyer-lab/Parse-data
Parse data
2 parents 04de245 + 1988607 commit bb7de9d

File tree

8 files changed

+384
-141
lines changed

8 files changed

+384
-141
lines changed

pf2rnaseq/factorization.py

Lines changed: 154 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import scipy.sparse as sps
77
from pacmap import PaCMAP
88
from parafac2.parafac2 import parafac2_nd, store_pf2
9+
from scipy.optimize import minimize
910
from scipy.stats import gmean
1011
from sklearn.decomposition import PCA
1112
from sklearn.linear_model import LinearRegression
@@ -17,14 +18,16 @@
1718
def correct_conditions(X: anndata.AnnData):
1819
"""Correct the conditions factors by overall read depth. Ensures that weighting is not affected by cell count difference"""
1920
sgIndex = X.obs["condition_unique_idxs"]
20-
#sgIndex = X.obs["condition_unique_idxs"].cat.codes
21+
# sgIndex = X.obs["condition_unique_idxs"].cat.codes
2122
counts = np.zeros((np.amax(sgIndex) + 1, 1))
2223
min_val = np.min(X.uns["Pf2_A"])
2324
if min_val < 0:
2425
# Add the absolute value of the minimum (plus a small epsilon) to make all values positive
2526
X.uns["Pf2_A"] = X.uns["Pf2_A"] + abs(min_val) + 1e-10
26-
print(f"Warning: Found negative values in Pf2_A (min: {min_val:.6f}). Added {abs(min_val) + 1e-10:.6f} to all values.")
27-
27+
print(
28+
f"Warning: Found negative values in Pf2_A (min: {min_val:.6f}). Added {abs(min_val) + 1e-10:.6f} to all values."
29+
)
30+
2831
cond_mean = gmean(X.uns["Pf2_A"], axis=1)
2932

3033
x_count = X.X.sum(axis=1)
@@ -50,13 +53,11 @@ def pf2(
5053
):
5154
cupy.cuda.Device(0).use()
5255
pf_out, R2X = parafac2_nd(
53-
5456
X,
5557
rank=rank,
5658
random_state=random_state,
5759
tol=tolerance,
5860
n_iter_max=500,
59-
6061
)
6162

6263
X = store_pf2(X, pf_out)
@@ -197,3 +198,151 @@ def fms_diff_ranks(
197198
)
198199

199200
return df
201+
202+
203+
def deconvolution_cytokine(
204+
A: np.ndarray,
205+
alpha: float = 0.1,
206+
max_iter: int = 5000,
207+
random_state: int = 1,
208+
) -> tuple[np.ndarray, np.ndarray]:
209+
"""
210+
Decompose cytokine factor matrix: A ≈ W @ H
211+
212+
This decomposes observed cytokine effects into:
213+
1. Direct primary effects (H)
214+
2. Induced effects via other cytokines (W)
215+
216+
Parameters
217+
----------
218+
A : np.ndarray
219+
Input matrix (n_cytokines, n_components)
220+
Example: (91 cytokines, 100 Parafac2 components)
221+
alpha : float
222+
Regularization strength
223+
max_iter : int
224+
Maximum optimization iterations
225+
random_state : int
226+
Random seed
227+
228+
Returns
229+
-------
230+
W : np.ndarray
231+
Cytokine interaction matrix (n_cytokines, n_cytokines)
232+
W[i, j] = total contribution of cytokine j to observed effect of i
233+
Diagonal W[i,i] = direct effect of cytokine i
234+
H : np.ndarray
235+
Effect basis matrix (n_cytokines, n_components)
236+
H[:, j] = cytokine effects for component j without indirect contributions
237+
"""
238+
n_cytokines, n_components = A.shape
239+
240+
np.random.seed(random_state)
241+
242+
# W initialized as identity, H is original A
243+
W_init = np.eye(n_cytokines)
244+
H_init = A.copy()
245+
246+
x0 = np.concatenate([W_init.ravel(), H_init.ravel()])
247+
248+
print("Cytokine deconvolution:")
249+
print(f" A shape: {A.shape} (cytokines × components)")
250+
print(f" W shape: ({n_cytokines}, {n_cytokines}) (cytokine interactions)")
251+
print(f" H shape: ({n_cytokines}, {n_components}) (effect basis)")
252+
253+
w_size = n_cytokines * n_cytokines
254+
iteration_counter = [0]
255+
best_loss = [np.inf]
256+
257+
def objective(x):
258+
W = x[:w_size].reshape(n_cytokines, n_cytokines)
259+
H = x[w_size:].reshape(n_cytokines, n_components)
260+
261+
# Reconstruction:A ≈ W @ H
262+
263+
reconstruction = W @ H
264+
mse = np.sum((A - reconstruction) ** 2)
265+
266+
# Regularization: L1 penalty on both W and H
267+
# Exclude diagonal of W from L1 penalty
268+
l1_W = alpha * np.sum(np.abs(W)) - alpha * np.diag(np.abs(W)).sum()
269+
l1_H = alpha * np.sum(np.abs(H))
270+
271+
total_loss = mse + l1_W + l1_H
272+
273+
iteration_counter[0] += 1
274+
if total_loss < best_loss[0]:
275+
best_loss[0] = total_loss
276+
277+
if iteration_counter[0] % 10 == 0:
278+
print(
279+
f" Iter {iteration_counter[0]}: Loss={total_loss:.4f} "
280+
f"(MSE={mse:.4f}, L1_W={l1_W:.4f}, L1_H={l1_H:.4f})"
281+
)
282+
283+
return total_loss
284+
285+
def gradient(x):
286+
W = x[:w_size].reshape(n_cytokines, n_cytokines)
287+
H = x[w_size:].reshape(n_cytokines, n_components)
288+
289+
# ===== Gradient w.r.t. W =====
290+
# 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W)
291+
grad_W = 2 * ((W @ H - A) @ H.T) + alpha * np.sign(W) - np.diag(alpha * np.sign(np.diag(W)))
292+
293+
# ===== Gradient w.r.t. H =====
294+
# 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H)
295+
grad_H = 2 * (W.T @ (W @ H - A)) + alpha * np.sign(H)
296+
297+
return np.concatenate([grad_W.ravel(), grad_H.ravel()])
298+
299+
print("\nStarting optimization...")
300+
301+
result = minimize(
302+
fun=objective,
303+
x0=x0,
304+
method="L-BFGS-B",
305+
jac=gradient,
306+
options={"maxiter": max_iter, "disp": True},
307+
)
308+
309+
W = result.x[:w_size].reshape(n_cytokines, n_cytokines)
310+
H = result.x[w_size:].reshape(n_cytokines, n_components)
311+
312+
# Evaluate
313+
314+
A_recon = W @ H
315+
316+
recon_error = np.linalg.norm(A - A_recon, "fro")
317+
rel_error = recon_error / np.linalg.norm(A, "fro")
318+
319+
# Statistics for W
320+
w_sparsity = np.sum(np.abs(W) < 1e-3) / W.size
321+
w_mean = np.abs(W).mean()
322+
w_max = np.abs(W).max()
323+
324+
# Statistics for H
325+
h_sparsity = np.sum(np.abs(H) < 1e-3) / H.size
326+
h_mean = np.abs(H).mean()
327+
h_max = np.abs(H).max()
328+
329+
print("\nOptimization complete:")
330+
print(f" Success: {result.success}")
331+
print(f" Iterations: {result.nit}")
332+
print(f" Relative reconstruction error: {rel_error:.4%}")
333+
334+
print("\n W (cytokine interactions):")
335+
print(f" Shape: {W.shape}")
336+
print(f" Sparsity: {w_sparsity:.2%} (near-zero elements)")
337+
print(f" Mean |W|: {w_mean:.4f}")
338+
print(f" Max |W|: {w_max:.4f}")
339+
print(f" Non-zeros: {np.sum(np.abs(W) > 1e-3)}/{W.size}")
340+
341+
print("\n H (effect patterns):")
342+
print(f" Shape: {H.shape}")
343+
print(f" Sparsity: {h_sparsity:.2%} (near-zero elements)")
344+
print(f" Mean |H|: {h_mean:.4f}")
345+
print(f" Max |H|: {h_max:.4f}")
346+
print(f" Non-zeros: {np.sum(np.abs(H) > 1e-3)}/{H.size}")
347+
348+
return W, H

pf2rnaseq/figures/commonFuncs/plotFactors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def plot_condition_factors(
2828
X = np.log10(X)
2929

3030
X -= np.median(X, axis=0)
31-
X /= np.std(X, axis=0)
31+
X /= np.std(X, axis=0) + 1e-3
32+
ind = reorder_table(X + 1e-3)
3233

33-
ind = reorder_table(X)
3434
X = X[ind]
3535
yt = yt.iloc[ind]
3636

@@ -67,7 +67,7 @@ def plot_condition_factors(
6767
)
6868
)
6969
# add a little legend
70-
ax.legend(handles=legend_elements, bbox_to_anchor=(0, 1.3))
70+
# ax.legend(handles=legend_elements, bbox_to_anchor=(0, 1.3))
7171

7272
xticks = np.arange(1, X.shape[1] + 1)
7373

@@ -584,7 +584,6 @@ def plot_comp_weights(
584584

585585
# Add legend for color coding (only if lowest are included)
586586
if include_lowest:
587-
588587
legend_elements = [
589588
Patch(facecolor="darkred", label=f"Top {top_n} Highest"),
590589
Patch(facecolor="darkblue", label=f"Top {top_n} Lowest"),

pf2rnaseq/figures/figureHeiserCompPac.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Weighted projections per component in PaCMAP and boxplot of cell types
33
"""
44

5-
65
import numpy as np
76

87
from ..factorization import correct_conditions, pf2
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Parse data: Plotting factors
3+
"""
4+
5+
import numpy as np
6+
import pandas as pd
7+
import seaborn as sns
8+
from anndata import read_h5ad
9+
from matplotlib import pyplot as plt
10+
11+
from ..factorization import correct_conditions, deconvolution_cytokine
12+
from .common import getSetup, subplotLabel
13+
from .commonFuncs.plotFactors import (
14+
plot_condition_factors,
15+
)
16+
17+
18+
def samples_only(X) -> pd.DataFrame:
19+
"""Obtain samples once only with corresponding observations"""
20+
samples = X.obs
21+
df_samples = samples.drop_duplicates(subset="condition_unique_idxs")
22+
df_samples = df_samples.sort_values("condition_unique_idxs")
23+
return df_samples
24+
25+
26+
def makeFigure():
27+
"""Get a list of the axis objects and create a figure."""
28+
# Get list of axis objects
29+
ax, f = getSetup((22, 15), (1, 3))
30+
31+
# Add subplot labels
32+
subplotLabel(ax)
33+
34+
# Load data
35+
X = read_h5ad("/home/nicoleb/ParsePf2_100_D11_filt.h5ad")
36+
X.uns["Pf2_A"] = correct_conditions(X)
37+
38+
W, H = deconvolution_cytokine(X.uns["Pf2_A"], alpha=1e-1, max_iter=5000)
39+
40+
# Get cytokine names in correct order
41+
samples_df = samples_only(X)
42+
43+
# Create deconvolved version for plotting
44+
X_deconv = X.copy()
45+
X_deconv.uns["Pf2_A"] = H # Use primary effects only
46+
47+
plot_condition_factors(
48+
X_deconv,
49+
ax[0],
50+
samples_df["cytokine"],
51+
groupConditions=True,
52+
cond="cytokine",
53+
log_scale=False,
54+
)
55+
ax[0].set_title("Deconvolved matrix (H)", fontsize=12, fontweight="bold")
56+
57+
plot_condition_factors(
58+
X,
59+
ax[1],
60+
samples_df["cytokine"],
61+
groupConditions=True,
62+
cond="cytokine",
63+
log_scale=False,
64+
)
65+
ax[1].set_title("Original Effects (A)", fontsize=12, fontweight="bold")
66+
67+
cytokine_names = samples_df["cytokine"].values
68+
69+
# Plot 2: W heatmap (primary effects)
70+
sns.heatmap(
71+
W,
72+
ax=ax[2],
73+
cmap="YlOrRd",
74+
robust=True,
75+
square=True,
76+
cbar_kws={"label": "Signaling Strength"},
77+
xticklabels=cytokine_names,
78+
yticklabels=cytokine_names,
79+
)
80+
ax[2].set_title("Cytokine Signaling (W)", fontsize=12, fontweight="bold")
81+
ax[2].set_xlabel("Inducing Cytokine →", fontsize=10)
82+
ax[2].set_ylabel("← Induced Cytokine", fontsize=10)
83+
plt.setp(ax[2].get_xticklabels(), rotation=90, ha="center", fontsize=6)
84+
plt.setp(ax[2].get_yticklabels(), rotation=0, fontsize=6)
85+
86+
return f

pf2rnaseq/imports.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,8 @@ def import_Parse(geneThreshold=0.1, doublet=False) -> anndata.AnnData:
9292
X = anndata.read_h5ad("/home/nicoleb/Pf2-scRNAseq-1/pf2rnaseq/Parse_Donor11.h5ad")
9393
if doublet:
9494
doubletDF = pd.read_csv(
95-
path_here / "pf2rnaseq/Data/DN11Doublets.csv.gz",
96-
index_col=0
97-
)
95+
path_here / "pf2rnaseq/Data/DN11Doublets.csv.gz", index_col=0
96+
)
9897
X.obs = X.obs.join(doubletDF.reindex(X.obs.index))
9998
singlet_mask = X.obs["doublet"] == 0
10099
X = X[singlet_mask, :].copy()

pf2rnaseq/top_bot_genes_export.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
exports csv of top 30 and bottom 30 genes per component
33
"""
44

5-
65
import numpy as np
76
import pandas as pd
87
from anndata import read_h5ad

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ dependencies = [
1818
"anndata>=0.10.3",
1919
"datashader>=0.18",
2020
"gseapy>=1.1",
21-
"scanpy @ git+https://github.com/scverse/scanpy.git@c2a7a4b7ec3203121a8d75aa05fbeb602ceecbd4",
21+
"scanpy>=1.10",
2222
"pacmap>=0.8",
2323
"leidenalg>=0.10.1",
2424
"tqdm>=4.66.1",
2525
"tlviz>=0.1.1",
26-
"statsmodels>=0.14.1",
26+
"statsmodels>=0.14.4",
2727
"dask[dataframe]>=2025",
2828
"ipykernel>=6.29.5",
2929
"parafac2 @ git+https://github.com/meyer-lab/parafac2.git",
@@ -36,6 +36,7 @@ dev = [
3636
"pytest>=8.0",
3737
"pytest-cov>=6.0",
3838
"pyright>=1.1",
39+
"ruff>=0.14.4",
3940
]
4041

4142
[project.scripts]

0 commit comments

Comments
 (0)