Skip to content

Commit 355aee8

Browse files
committed
Update ADMM
1 parent 9005855 commit 355aee8

File tree

1 file changed

+44
-36
lines changed

1 file changed

+44
-36
lines changed

pf2rnaseq/factorization.py

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -349,28 +349,31 @@ def gradient(x):
349349

350350
def deconvolution_cytokine_admm(
351351
A: np.ndarray,
352-
alpha: float = 0.1,
352+
alpha_h: float = 0.1,
353+
alpha_w: float = 0.01,
353354
rho: float = 1.0,
354355
max_iter: int = 5000,
355-
tol: float = 1e-4,
356+
tol: float = 1e-4, # Single tolerance for both primal and dual
356357
random_state: int = 1,
357358
adaptive_rho: bool = True,
358359
) -> tuple[np.ndarray, np.ndarray, dict]:
359360
"""
360-
Decompose cytokine factor matrix using ADMM: A ≈ W @ H
361+
Decompose cytokine factor matrix using ADMM: A ≈ W @ H
361362
362363
Parameters
363364
----------
364365
A : np.ndarray
365366
Input matrix (n_cytokines, n_components)
366-
alpha : float
367-
L1 regularization strength (applied to both W and H)
367+
alpha_h : float
368+
L1 regularization for H
369+
alpha_w : float
370+
L1 regularization for W (off-diagonal only)
368371
rho : float
369372
ADMM penalty parameter
370373
max_iter : int
371-
Maximum ADMM iterations
374+
Maximum iterations
372375
tol : float
373-
Convergence tolerance
376+
Convergence tolerance for both primal and dual residuals
374377
random_state : int
375378
Random seed
376379
adaptive_rho : bool
@@ -398,38 +401,40 @@ def deconvolution_cytokine_admm(
398401

399402
print("Cytokine deconvolution with ADMM:")
400403
print(f" A shape: {A.shape}")
401-
print(f" Alpha (L1 penalty): {alpha}")
402-
print(f" Rho (ADMM penalty): {rho}")
403-
print(f" Adaptive rho: {adaptive_rho}")
404+
print(f" Alpha_W: {alpha_w}, Alpha_H: {alpha_h}")
405+
print(f" Rho: {rho}")
406+
print(f" Tolerance: {tol}")
404407

405-
# Create mask for off-diagonal elements
406408
off_diag_mask = ~np.eye(n_cytokines, dtype=bool)
407409

408410
def soft_threshold(X, threshold):
409411
return np.sign(X) * np.maximum(np.abs(X) - threshold, 0)
410412

411413
def update_W(H, Z_W, U_W, rho):
412-
"""Update W: solve (H@H^T + rho*I) W^T = (A@H^T + rho(Z_W - U_W))^T"""
414+
"""Update W: constrain diagonal to 1.0"""
413415
H_HT = H @ H.T
414416
A_HT = A @ H.T
415417
lhs = H_HT + rho * np.eye(n_cytokines)
416418
rhs = A_HT + rho * (Z_W - U_W)
417-
return np.linalg.solve(lhs, rhs.T).T
419+
420+
W_new = np.linalg.solve(lhs, rhs.T).T
421+
np.fill_diagonal(W_new, 1.0)
422+
423+
return W_new
418424

419425
def update_H(W, Z_H, U_H, rho):
420-
"""Update H: solve (W^T@W + rho*I) H = W^T@A + rho(Z_H - U_H)"""
426+
"""Update H"""
421427
W_TW = W.T @ W
422428
W_TA = W.T @ A
423429
lhs = W_TW + rho * np.eye(n_cytokines)
424430
rhs = W_TA + rho * (Z_H - U_H)
425431
return np.linalg.solve(lhs, rhs)
426432

427433
def update_Z_W(W, U_W, alpha, rho):
428-
"""Update Z_W: soft-threshold off-diagonal, preserve diagonal"""
434+
"""Update Z_W: soft-threshold off-diagonal only"""
429435
X = W + U_W
430-
Z_W_new = soft_threshold(X, alpha / rho)
431-
# Restore diagonal (no L1 penalty on direct effects)
432-
np.fill_diagonal(Z_W_new, np.diag(X))
436+
Z_W_new = X.copy()
437+
Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho)
433438
return Z_W_new
434439

435440
def update_Z_H(H, U_H, alpha, rho):
@@ -448,27 +453,30 @@ def update_Z_H(H, U_H, alpha, rho):
448453
print("\nStarting ADMM iterations...")
449454

450455
for iteration in range(max_iter):
451-
# Store old Z values for dual residual computation
452456
Z_W_old = Z_W.copy()
453457
Z_H_old = Z_H.copy()
454458

455459
# ADMM updates
456460
W = update_W(H, Z_W, U_W, rho)
457461
H = update_H(W, Z_H, U_H, rho)
458-
Z_W = update_Z_W(W, U_W, alpha, rho)
459-
Z_H = update_Z_H(H, U_H, alpha, rho)
462+
Z_W = update_Z_W(W, U_W, alpha_w, rho)
463+
Z_H = update_Z_H(H, U_H, alpha_h, rho)
460464
U_W = U_W + (W - Z_W)
461465
U_H = U_H + (H - Z_H)
462466

463-
# Compute residuals
467+
# ===== SIMPLIFIED CONVERGENCE CHECK =====
468+
469+
# Primal residual: ||W - Z_W||² + ||H - Z_H||²
464470
r_norm = np.sqrt(np.sum((W - Z_W)**2) + np.sum((H - Z_H)**2))
465-
s_norm = rho * np.sqrt(np.sum((Z_W - Z_W_old)**2) +
466-
np.sum((Z_H - Z_H_old)**2))
467471

468-
# Compute objective (off-diagonal penalty for W, full penalty for H)
472+
# Dual residual: ||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||²
473+
s_norm = np.sqrt(np.sum((rho * (Z_W - Z_W_old))**2) +
474+
np.sum((rho * (Z_H - Z_H_old))**2))
475+
476+
# Compute objective
469477
recon_error = np.sum((A - W @ H) ** 2)
470-
l1_W = alpha * np.sum(np.abs(Z_W[off_diag_mask]))
471-
l1_H = alpha * np.sum(np.abs(Z_H))
478+
l1_W = alpha_w * np.sum(np.abs(Z_W[off_diag_mask]))
479+
l1_H = alpha_h * np.sum(np.abs(Z_H))
472480
objective = recon_error + l1_W + l1_H
473481

474482
# Track sparsity
@@ -486,26 +494,26 @@ def update_Z_H(H, U_H, alpha, rho):
486494
# Print progress
487495
if iteration % 10 == 0 or iteration < 10:
488496
print(f" Iter {iteration:4d}: Obj={objective:.4e}, "
489-
f"||r||={r_norm:.4e}, ||s||={s_norm:.4e}, "
490-
f"W_sparse={w_sparsity:.2%}, H_sparse={h_sparsity:.2%}")
497+
f"r={r_norm:.3e}, s={s_norm:.3e}, ρ={rho:.2f}")
491498

492499
# Adaptive rho update
493500
if adaptive_rho and iteration > 0:
494501
if r_norm > 10 * s_norm:
495502
rho = rho * 2
496503
U_W = U_W / 2
497504
U_H = U_H / 2
505+
print(f" Increased ρ → {rho:.2f}")
498506
elif s_norm > 10 * r_norm:
499507
rho = rho / 2
500508
U_W = U_W * 2
501509
U_H = U_H * 2
510+
print(f" Decreased ρ → {rho:.2f}")
502511

503-
# Check convergence
504-
eps_primal = tol * np.sqrt(W.size + H.size)
505-
eps_dual = tol * np.sqrt(U_W.size + U_H.size)
506-
507-
if r_norm < eps_primal and s_norm < eps_dual:
508-
print(f"\nConverged at iteration {iteration}")
512+
# Simple convergence check
513+
if r_norm < tol and s_norm < tol:
514+
print(f"\n✓ Converged at iteration {iteration}")
515+
print(f" Primal residual: {r_norm:.4e} < {tol:.4e}")
516+
print(f" Dual residual: {s_norm:.4e} < {tol:.4e}")
509517
break
510518

511519
# Final statistics
@@ -522,7 +530,7 @@ def update_Z_H(H, U_H, alpha, rho):
522530
print(f" Off-diagonal sparsity: {w_sparsity:.2%}")
523531
print(f" Off-diagonal non-zeros: {np.sum(np.abs(Z_W[off_diag_mask]) > 1e-3)}")
524532
print(f" Mean |W_offdiag|: {np.abs(Z_W[off_diag_mask]).mean():.4f}")
525-
print(f" Diagonal mean: {np.abs(np.diag(Z_W)).mean():.4f}")
533+
print(f" Diagonal: all 1.0 (constrained)")
526534
print(f"\n H (effect patterns):")
527535
print(f" Sparsity: {h_sparsity:.2%}")
528536
print(f" Non-zeros: {np.sum(np.abs(Z_H) > 1e-3)}/{Z_H.size}")

0 commit comments

Comments
 (0)