@@ -349,28 +349,31 @@ def gradient(x):
349349
350350def deconvolution_cytokine_admm (
351351 A : np .ndarray ,
352- alpha : float = 0.1 ,
352+ alpha_h : float = 0.1 ,
353+ alpha_w : float = 0.01 ,
353354 rho : float = 1.0 ,
354355 max_iter : int = 5000 ,
355- tol : float = 1e-4 ,
356+ tol : float = 1e-4 , # Single tolerance for both primal and dual
356357 random_state : int = 1 ,
357358 adaptive_rho : bool = True ,
358359) -> tuple [np .ndarray , np .ndarray , dict ]:
359360 """
360- Decompose cytokine factor matrix using ADMM: A ≈ W @ H
361+ Decompose cytokine factor matrix using ADMM: A ≈ W @ H
361362
362363 Parameters
363364 ----------
364365 A : np.ndarray
365366 Input matrix (n_cytokines, n_components)
366- alpha : float
367- L1 regularization strength (applied to both W and H)
367+ alpha_h : float
368+ L1 regularization for H
369+ alpha_w : float
370+ L1 regularization for W (off-diagonal only)
368371 rho : float
369372 ADMM penalty parameter
370373 max_iter : int
371- Maximum ADMM iterations
374+ Maximum iterations
372375 tol : float
373- Convergence tolerance
376+ Convergence tolerance for both primal and dual residuals
374377 random_state : int
375378 Random seed
376379 adaptive_rho : bool
@@ -398,38 +401,40 @@ def deconvolution_cytokine_admm(
398401
399402 print ("Cytokine deconvolution with ADMM:" )
400403 print (f" A shape: { A .shape } " )
401- print (f" Alpha (L1 penalty) : { alpha } " )
402- print (f" Rho (ADMM penalty) : { rho } " )
403- print (f" Adaptive rho : { adaptive_rho } " )
404+ print (f" Alpha_W: { alpha_w } , Alpha_H : { alpha_h } " )
405+ print (f" Rho: { rho } " )
406+ print (f" Tolerance : { tol } " )
404407
405- # Create mask for off-diagonal elements
406408 off_diag_mask = ~ np .eye (n_cytokines , dtype = bool )
407409
def soft_threshold(X, threshold):
    """
    Apply element-wise soft-thresholding to ``X``.

    Every entry is shrunk toward zero by ``threshold``; entries whose
    magnitude does not exceed ``threshold`` become zero.

    Parameters
    ----------
    X : array-like
        Values to shrink.
    threshold : float
        Amount subtracted from each magnitude.  NOTE(review): presumably
        non-negative (callers pass ``alpha / rho``) -- not validated here.

    Returns
    -------
    np.ndarray
        ``sign(X) * max(|X| - threshold, 0)``, same shape as ``X``.
    """
    magnitude = np.maximum(np.abs(X) - threshold, 0)
    return np.sign(X) * magnitude
410412
def update_W(H, Z_W, U_W, rho):
    """
    Update W, then constrain its diagonal to 1.0.

    Solves ``(H @ H.T + rho * I) W^T = (A @ H.T + rho * (Z_W - U_W))^T``
    for W and afterwards overwrites the diagonal of the result with 1.0.

    ``A`` and ``n_cytokines`` are read from the enclosing scope.

    Parameters
    ----------
    H : np.ndarray
        Current H iterate.
    Z_W : np.ndarray
        Current split variable for W.
    U_W : np.ndarray
        Current scaled dual variable for W.
    rho : float
        ADMM penalty parameter.

    Returns
    -------
    np.ndarray
        Updated W with every diagonal entry equal to 1.0.
    """
    H_HT = H @ H.T
    A_HT = A @ H.T
    lhs = H_HT + rho * np.eye(n_cytokines)
    rhs = A_HT + rho * (Z_W - U_W)

    # The system is posed for W^T, so solve against rhs.T and transpose back.
    W_new = np.linalg.solve(lhs, rhs.T).T
    # The diagonal is fixed after the solve rather than inside it.
    np.fill_diagonal(W_new, 1.0)

    return W_new
418424
def update_H(W, Z_H, U_H, rho):
    """
    Update H.

    Solves ``(W.T @ W + rho * I) H = W.T @ A + rho * (Z_H - U_H)`` for H.

    ``A`` and ``n_cytokines`` are read from the enclosing scope.

    Parameters
    ----------
    W : np.ndarray
        Current W iterate.
    Z_H : np.ndarray
        Current split variable for H.
    U_H : np.ndarray
        Current scaled dual variable for H.
    rho : float
        ADMM penalty parameter.

    Returns
    -------
    np.ndarray
        Updated H.
    """
    W_TW = W.T @ W
    W_TA = W.T @ A
    lhs = W_TW + rho * np.eye(n_cytokines)
    rhs = W_TA + rho * (Z_H - U_H)
    return np.linalg.solve(lhs, rhs)
426432
def update_Z_W(W, U_W, alpha, rho):
    """
    Update Z_W: soft-threshold off-diagonal entries only.

    Forms ``X = W + U_W`` and applies ``soft_threshold`` with threshold
    ``alpha / rho`` to the entries selected by ``off_diag_mask``.  Diagonal
    entries of ``X`` are copied through unchanged, so no L1 shrinkage is
    applied to them.

    ``off_diag_mask`` and ``soft_threshold`` are read from the enclosing
    scope.

    Parameters
    ----------
    W : np.ndarray
        Current W iterate.
    U_W : np.ndarray
        Current scaled dual variable for W.
    alpha : float
        L1 regularization strength for W.
    rho : float
        ADMM penalty parameter.

    Returns
    -------
    np.ndarray
        New Z_W, same shape as ``W``.
    """
    X = W + U_W
    # Start from a copy so the diagonal keeps the un-thresholded values.
    Z_W_new = X.copy()
    Z_W_new[off_diag_mask] = soft_threshold(X[off_diag_mask], alpha / rho)
    return Z_W_new
434439
435440 def update_Z_H (H , U_H , alpha , rho ):
@@ -448,27 +453,30 @@ def update_Z_H(H, U_H, alpha, rho):
448453 print ("\n Starting ADMM iterations..." )
449454
450455 for iteration in range (max_iter ):
451- # Store old Z values for dual residual computation
452456 Z_W_old = Z_W .copy ()
453457 Z_H_old = Z_H .copy ()
454458
455459 # ADMM updates
456460 W = update_W (H , Z_W , U_W , rho )
457461 H = update_H (W , Z_H , U_H , rho )
458- Z_W = update_Z_W (W , U_W , alpha , rho )
459- Z_H = update_Z_H (H , U_H , alpha , rho )
462+ Z_W = update_Z_W (W , U_W , alpha_w , rho )
463+ Z_H = update_Z_H (H , U_H , alpha_h , rho )
460464 U_W = U_W + (W - Z_W )
461465 U_H = U_H + (H - Z_H )
462466
463- # Compute residuals
467+ # ===== SIMPLIFIED CONVERGENCE CHECK =====
468+
469+ # Primal residual norm: sqrt(||W - Z_W||² + ||H - Z_H||²)
464470 r_norm = np .sqrt (np .sum ((W - Z_W )** 2 ) + np .sum ((H - Z_H )** 2 ))
465- s_norm = rho * np .sqrt (np .sum ((Z_W - Z_W_old )** 2 ) +
466- np .sum ((Z_H - Z_H_old )** 2 ))
467471
468- # Compute objective (off-diagonal penalty for W, full penalty for H)
472+ # Dual residual norm: sqrt(||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||²)
473+ s_norm = np .sqrt (np .sum ((rho * (Z_W - Z_W_old ))** 2 ) +
474+ np .sum ((rho * (Z_H - Z_H_old ))** 2 ))
475+
476+ # Compute objective
469477 recon_error = np .sum ((A - W @ H ) ** 2 )
470- l1_W = alpha * np .sum (np .abs (Z_W [off_diag_mask ]))
471- l1_H = alpha * np .sum (np .abs (Z_H ))
478+ l1_W = alpha_w * np .sum (np .abs (Z_W [off_diag_mask ]))
479+ l1_H = alpha_h * np .sum (np .abs (Z_H ))
472480 objective = recon_error + l1_W + l1_H
473481
474482 # Track sparsity
@@ -486,26 +494,26 @@ def update_Z_H(H, U_H, alpha, rho):
486494 # Print progress
487495 if iteration % 10 == 0 or iteration < 10 :
488496 print (f" Iter { iteration :4d} : Obj={ objective :.4e} , "
489- f"||r||={ r_norm :.4e} , ||s||={ s_norm :.4e} , "
490- f"W_sparse={ w_sparsity :.2%} , H_sparse={ h_sparsity :.2%} " )
497+ f"r={ r_norm :.3e} , s={ s_norm :.3e} , ρ={ rho :.2f} " )
491498
492499 # Adaptive rho update
493500 if adaptive_rho and iteration > 0 :
494501 if r_norm > 10 * s_norm :
495502 rho = rho * 2
496503 U_W = U_W / 2
497504 U_H = U_H / 2
505+ print (f" Increased ρ → { rho :.2f} " )
498506 elif s_norm > 10 * r_norm :
499507 rho = rho / 2
500508 U_W = U_W * 2
501509 U_H = U_H * 2
510+ print (f" Decreased ρ → { rho :.2f} " )
502511
503- # Check convergence
504- eps_primal = tol * np .sqrt (W .size + H .size )
505- eps_dual = tol * np .sqrt (U_W .size + U_H .size )
506-
507- if r_norm < eps_primal and s_norm < eps_dual :
508- print (f"\n Converged at iteration { iteration } " )
512+ # Simple convergence check
513+ if r_norm < tol and s_norm < tol :
514+ print (f"\n ✓ Converged at iteration { iteration } " )
515+ print (f" Primal residual: { r_norm :.4e} < { tol :.4e} " )
516+ print (f" Dual residual: { s_norm :.4e} < { tol :.4e} " )
509517 break
510518
511519 # Final statistics
@@ -522,7 +530,7 @@ def update_Z_H(H, U_H, alpha, rho):
522530 print (f" Off-diagonal sparsity: { w_sparsity :.2%} " )
523531 print (f" Off-diagonal non-zeros: { np .sum (np .abs (Z_W [off_diag_mask ]) > 1e-3 )} " )
524532 print (f" Mean |W_offdiag|: { np .abs (Z_W [off_diag_mask ]).mean ():.4f} " )
525- print (f" Diagonal mean: { np . abs ( np . diag ( Z_W )). mean ():.4f } " )
533+ print (f" Diagonal: all 1.0 (constrained) " )
526534 print (f"\n H (effect patterns):" )
527535 print (f" Sparsity: { h_sparsity :.2%} " )
528536 print (f" Non-zeros: { np .sum (np .abs (Z_H ) > 1e-3 )} /{ Z_H .size } " )
0 commit comments