66import scipy .sparse as sps
77from pacmap import PaCMAP
88from parafac2 .parafac2 import parafac2_nd , store_pf2
9+ from scipy .optimize import minimize
910from scipy .stats import gmean
1011from sklearn .decomposition import PCA
1112from sklearn .linear_model import LinearRegression
1718def correct_conditions (X : anndata .AnnData ):
1819 """Correct the conditions factors by overall read depth. Ensures that weighting is not affected by cell count difference"""
1920 sgIndex = X .obs ["condition_unique_idxs" ]
20- #sgIndex = X.obs["condition_unique_idxs"].cat.codes
21+ # sgIndex = X.obs["condition_unique_idxs"].cat.codes
2122 counts = np .zeros ((np .amax (sgIndex ) + 1 , 1 ))
2223 min_val = np .min (X .uns ["Pf2_A" ])
2324 if min_val < 0 :
2425 # Add the absolute value of the minimum (plus a small epsilon) to make all values positive
2526 X .uns ["Pf2_A" ] = X .uns ["Pf2_A" ] + abs (min_val ) + 1e-10
26- print (f"Warning: Found negative values in Pf2_A (min: { min_val :.6f} ). Added { abs (min_val ) + 1e-10 :.6f} to all values." )
27-
27+ print (
28+ f"Warning: Found negative values in Pf2_A (min: { min_val :.6f} ). Added { abs (min_val ) + 1e-10 :.6f} to all values."
29+ )
30+
2831 cond_mean = gmean (X .uns ["Pf2_A" ], axis = 1 )
2932
3033 x_count = X .X .sum (axis = 1 )
@@ -50,13 +53,11 @@ def pf2(
5053):
5154 cupy .cuda .Device (0 ).use ()
5255 pf_out , R2X = parafac2_nd (
53-
5456 X ,
5557 rank = rank ,
5658 random_state = random_state ,
5759 tol = tolerance ,
5860 n_iter_max = 500 ,
59-
6061 )
6162
6263 X = store_pf2 (X , pf_out )
@@ -197,3 +198,151 @@ def fms_diff_ranks(
197198 )
198199
199200 return df
201+
202+
def deconvolution_cytokine(
    A: np.ndarray,
    alpha: float = 0.1,
    max_iter: int = 5000,
    random_state: int = 1,
    *,
    verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Decompose a cytokine factor matrix as A ≈ W @ H with L1 regularization.

    The loss minimized with L-BFGS-B is

        ||A - W @ H||_F^2 + alpha * sum(|W_ij| for i != j) + alpha * sum(|H|)

    i.e. the diagonal of W is excluded from the sparsity penalty. The
    optimization starts from W = identity and H = A (zero reconstruction
    error), so the result is fully deterministic.

    Parameters
    ----------
    A : np.ndarray
        Input matrix (n_cytokines, n_components)
        Example: (91 cytokines, 100 Parafac2 components)
    alpha : float
        Regularization strength
    max_iter : int
        Maximum optimization iterations
    random_state : int
        Retained for backward compatibility only. The initialization is
        deterministic and no random numbers are drawn, so this value is
        ignored (and the global NumPy RNG is no longer reseeded).
    verbose : bool
        When True (default) print progress and summary statistics.

    Returns
    -------
    W : np.ndarray
        Cytokine interaction matrix (n_cytokines, n_cytokines)
        W[i, j] = total contribution of cytokine j to observed effect of i
        Diagonal W[i,i] = direct effect of cytokine i
    H : np.ndarray
        Effect basis matrix (n_cytokines, n_components)
        H[:, j] = cytokine effects for component j without indirect contributions

    Raises
    ------
    ValueError
        If A is not a non-empty 2-D array.
    """
    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.size == 0:
        raise ValueError(
            f"A must be a non-empty 2-D array (n_cytokines, n_components); got shape {A.shape}"
        )

    n_cytokines, n_components = A.shape
    w_size = n_cytokines * n_cytokines

    # 1 everywhere except the diagonal: selects the entries of W that are penalized.
    off_diagonal = 1.0 - np.eye(n_cytokines)

    def log(message: str) -> None:
        if verbose:
            print(message)

    def unpack(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Split the flat parameter vector back into (W, H)."""
        W = x[:w_size].reshape(n_cytokines, n_cytokines)
        H = x[w_size:].reshape(n_cytokines, n_components)
        return W, H

    # W initialized as identity, H is original A
    x0 = np.concatenate([np.eye(n_cytokines).ravel(), A.ravel()])

    log("Cytokine deconvolution:")
    log(f"  A shape: {A.shape} (cytokines × components)")
    log(f"  W shape: ({n_cytokines}, {n_cytokines}) (cytokine interactions)")
    log(f"  H shape: ({n_cytokines}, {n_components}) (effect basis)")

    n_evaluations = 0

    def loss_and_gradient(x: np.ndarray) -> tuple[float, np.ndarray]:
        """Return the loss and its (sub)gradient, sharing the residual between both."""
        nonlocal n_evaluations
        W, H = unpack(x)

        residual = W @ H - A
        mse = np.sum(residual**2)

        # L1 penalty on H and on the off-diagonal entries of W only.
        l1_W = alpha * np.sum(np.abs(W) * off_diagonal)
        l1_H = alpha * np.sum(np.abs(H))
        total_loss = mse + l1_W + l1_H

        # d/dW ||WH - A||^2 = 2 (WH - A) H^T ; d/dW alpha*|W| = alpha*sign(W) (off-diagonal)
        grad_W = 2 * (residual @ H.T) + alpha * np.sign(W) * off_diagonal
        # d/dH ||WH - A||^2 = 2 W^T (WH - A) ; d/dH alpha*|H| = alpha*sign(H)
        grad_H = 2 * (W.T @ residual) + alpha * np.sign(H)

        # The counter tracks objective evaluations, which is what the
        # progress line has always reported.
        n_evaluations += 1
        if n_evaluations % 10 == 0:
            log(
                f"  Iter {n_evaluations}: Loss={total_loss:.4f} "
                f"(MSE={mse:.4f}, L1_W={l1_W:.4f}, L1_H={l1_H:.4f})"
            )

        return total_loss, np.concatenate([grad_W.ravel(), grad_H.ravel()])

    log("\nStarting optimization...")

    # The `disp` option is intentionally not passed: it is deprecated for
    # L-BFGS-B in recent SciPy releases, and progress is already reported above.
    result = minimize(
        fun=loss_and_gradient,
        x0=x0,
        method="L-BFGS-B",
        jac=True,
        options={"maxiter": max_iter},
    )

    W, H = unpack(result.x)

    if verbose:
        # Evaluate the fit
        recon_error = np.linalg.norm(A - W @ H, "fro")
        a_norm = np.linalg.norm(A, "fro")
        if a_norm > 0:
            rel_error = recon_error / a_norm
        else:
            # All-zero input: avoid 0/0 and report an exact fit as 0 error.
            rel_error = 0.0 if recon_error == 0 else np.inf

        print("\nOptimization complete:")
        print(f"  Success: {result.success}")
        print(f"  Iterations: {result.nit}")
        print(f"  Relative reconstruction error: {rel_error:.4%}")

        for label, title, M in (("W", "cytokine interactions", W), ("H", "effect patterns", H)):
            magnitude = np.abs(M)
            sparsity = np.sum(magnitude < 1e-3) / M.size
            print(f"\n{label} ({title}):")
            print(f"  Shape: {M.shape}")
            print(f"  Sparsity: {sparsity:.2%} (near-zero elements)")
            print(f"  Mean |{label}|: {magnitude.mean():.4f}")
            print(f"  Max |{label}|: {magnitude.max():.4f}")
            print(f"  Non-zeros: {np.sum(magnitude > 1e-3)}/{M.size}")

    return W, H
0 commit comments