Skip to content

Commit 9005855

Browse files
committed
ADMM deconvolution
1 parent 1988607 commit 9005855

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed

pf2rnaseq/factorization.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,186 @@ def gradient(x):
346346
print(f" Non-zeros: {np.sum(np.abs(H) > 1e-3)}/{H.size}")
347347

348348
return W, H
349+
350+
def deconvolution_cytokine_admm(
351+
A: np.ndarray,
352+
alpha: float = 0.1,
353+
rho: float = 1.0,
354+
max_iter: int = 5000,
355+
tol: float = 1e-4,
356+
random_state: int = 1,
357+
adaptive_rho: bool = True,
358+
) -> tuple[np.ndarray, np.ndarray, dict]:
359+
"""
360+
Decompose cytokine factor matrix using ADMM: A ≈ W @ H
361+
362+
Parameters
363+
----------
364+
A : np.ndarray
365+
Input matrix (n_cytokines, n_components)
366+
alpha : float
367+
L1 regularization strength (applied to both W and H)
368+
rho : float
369+
ADMM penalty parameter
370+
max_iter : int
371+
Maximum ADMM iterations
372+
tol : float
373+
Convergence tolerance
374+
random_state : int
375+
Random seed
376+
adaptive_rho : bool
377+
Whether to adaptively adjust rho
378+
379+
Returns
380+
-------
381+
Z_W : np.ndarray
382+
Cytokine interaction matrix (n_cytokines, n_cytokines)
383+
Z_H : np.ndarray
384+
Effect basis matrix (n_cytokines, n_components)
385+
history : dict
386+
Optimization history
387+
"""
388+
n_cytokines, n_components = A.shape
389+
np.random.seed(random_state)
390+
391+
# Initialize
392+
W = np.eye(n_cytokines)
393+
H = A.copy()
394+
Z_W = W.copy()
395+
Z_H = H.copy()
396+
U_W = np.zeros_like(W)
397+
U_H = np.zeros_like(H)
398+
399+
print("Cytokine deconvolution with ADMM:")
400+
print(f" A shape: {A.shape}")
401+
print(f" Alpha (L1 penalty): {alpha}")
402+
print(f" Rho (ADMM penalty): {rho}")
403+
print(f" Adaptive rho: {adaptive_rho}")
404+
405+
# Create mask for off-diagonal elements
406+
off_diag_mask = ~np.eye(n_cytokines, dtype=bool)
407+
408+
def soft_threshold(X, threshold):
409+
return np.sign(X) * np.maximum(np.abs(X) - threshold, 0)
410+
411+
def update_W(H, Z_W, U_W, rho):
412+
"""Update W: solve (H@H^T + rho*I) W^T = (A@H^T + rho(Z_W - U_W))^T"""
413+
H_HT = H @ H.T
414+
A_HT = A @ H.T
415+
lhs = H_HT + rho * np.eye(n_cytokines)
416+
rhs = A_HT + rho * (Z_W - U_W)
417+
return np.linalg.solve(lhs, rhs.T).T
418+
419+
def update_H(W, Z_H, U_H, rho):
420+
"""Update H: solve (W^T@W + rho*I) H = W^T@A + rho(Z_H - U_H)"""
421+
W_TW = W.T @ W
422+
W_TA = W.T @ A
423+
lhs = W_TW + rho * np.eye(n_cytokines)
424+
rhs = W_TA + rho * (Z_H - U_H)
425+
return np.linalg.solve(lhs, rhs)
426+
427+
def update_Z_W(W, U_W, alpha, rho):
428+
"""Update Z_W: soft-threshold off-diagonal, preserve diagonal"""
429+
X = W + U_W
430+
Z_W_new = soft_threshold(X, alpha / rho)
431+
# Restore diagonal (no L1 penalty on direct effects)
432+
np.fill_diagonal(Z_W_new, np.diag(X))
433+
return Z_W_new
434+
435+
def update_Z_H(H, U_H, alpha, rho):
436+
"""Update Z_H: soft-threshold entire matrix"""
437+
return soft_threshold(H + U_H, alpha / rho)
438+
439+
history = {
440+
'objective': [],
441+
'primal_residual': [],
442+
'dual_residual': [],
443+
'rho': [],
444+
'w_sparsity': [],
445+
'h_sparsity': []
446+
}
447+
448+
print("\nStarting ADMM iterations...")
449+
450+
for iteration in range(max_iter):
451+
# Store old Z values for dual residual computation
452+
Z_W_old = Z_W.copy()
453+
Z_H_old = Z_H.copy()
454+
455+
# ADMM updates
456+
W = update_W(H, Z_W, U_W, rho)
457+
H = update_H(W, Z_H, U_H, rho)
458+
Z_W = update_Z_W(W, U_W, alpha, rho)
459+
Z_H = update_Z_H(H, U_H, alpha, rho)
460+
U_W = U_W + (W - Z_W)
461+
U_H = U_H + (H - Z_H)
462+
463+
# Compute residuals
464+
r_norm = np.sqrt(np.sum((W - Z_W)**2) + np.sum((H - Z_H)**2))
465+
s_norm = rho * np.sqrt(np.sum((Z_W - Z_W_old)**2) +
466+
np.sum((Z_H - Z_H_old)**2))
467+
468+
# Compute objective (off-diagonal penalty for W, full penalty for H)
469+
recon_error = np.sum((A - W @ H) ** 2)
470+
l1_W = alpha * np.sum(np.abs(Z_W[off_diag_mask]))
471+
l1_H = alpha * np.sum(np.abs(Z_H))
472+
objective = recon_error + l1_W + l1_H
473+
474+
# Track sparsity
475+
w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask)
476+
h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size
477+
478+
# Store history
479+
history['objective'].append(objective)
480+
history['primal_residual'].append(r_norm)
481+
history['dual_residual'].append(s_norm)
482+
history['rho'].append(rho)
483+
history['w_sparsity'].append(w_sparsity)
484+
history['h_sparsity'].append(h_sparsity)
485+
486+
# Print progress
487+
if iteration % 10 == 0 or iteration < 10:
488+
print(f" Iter {iteration:4d}: Obj={objective:.4e}, "
489+
f"||r||={r_norm:.4e}, ||s||={s_norm:.4e}, "
490+
f"W_sparse={w_sparsity:.2%}, H_sparse={h_sparsity:.2%}")
491+
492+
# Adaptive rho update
493+
if adaptive_rho and iteration > 0:
494+
if r_norm > 10 * s_norm:
495+
rho = rho * 2
496+
U_W = U_W / 2
497+
U_H = U_H / 2
498+
elif s_norm > 10 * r_norm:
499+
rho = rho / 2
500+
U_W = U_W * 2
501+
U_H = U_H * 2
502+
503+
# Check convergence
504+
eps_primal = tol * np.sqrt(W.size + H.size)
505+
eps_dual = tol * np.sqrt(U_W.size + U_H.size)
506+
507+
if r_norm < eps_primal and s_norm < eps_dual:
508+
print(f"\nConverged at iteration {iteration}")
509+
break
510+
511+
# Final statistics
512+
A_recon = W @ H
513+
rel_error = np.linalg.norm(A - A_recon, "fro") / np.linalg.norm(A, "fro")
514+
515+
w_sparsity = np.sum(np.abs(Z_W[off_diag_mask]) < 1e-3) / np.sum(off_diag_mask)
516+
h_sparsity = np.sum(np.abs(Z_H) < 1e-3) / Z_H.size
517+
518+
print("\nOptimization complete:")
519+
print(f" Iterations: {iteration + 1}/{max_iter}")
520+
print(f" Relative reconstruction error: {rel_error:.4%}")
521+
print(f"\n W (cytokine interactions):")
522+
print(f" Off-diagonal sparsity: {w_sparsity:.2%}")
523+
print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}")
524+
print(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}")
525+
print(f" Diagonal mean: {np.abs(np.diag(Z_W)).mean():.4f}")
526+
print(f"\n H (effect patterns):")
527+
print(f" Sparsity: {h_sparsity:.2%}")
528+
print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}")
529+
print(f" Mean |H|: {np.abs(Z_H).mean():.4f}")
530+
531+
return Z_W, Z_H, history

0 commit comments

Comments
 (0)