@@ -288,7 +288,11 @@ def gradient(x):
288288
289289 # ===== Gradient w.r.t. W =====
290290 # 1. Reconstruction term: ∂/∂W [||A - WH||²] = 2(error @ H^T), L1 penalty: ∂/∂W [α||W||₁] = α * sign(W)
291- grad_W = 2 * ((W @ H - A ) @ H .T ) + alpha * np .sign (W ) - np .diag (alpha * np .sign (np .diag (W )))
291+ grad_W = (
292+ 2 * ((W @ H - A ) @ H .T )
293+ + alpha * np .sign (W )
294+ - np .diag (alpha * np .sign (np .diag (W )))
295+ )
292296
293297 # ===== Gradient w.r.t. H =====
294298 # 1. Reconstruction term: ∂/∂H [||A - WH||²] = 2(W^T @ error), L1 penalty: ∂/∂H [α||H||₁] = α * sign(H)
@@ -347,19 +351,21 @@ def gradient(x):
347351
348352 return W , H
349353
354+
350355def deconvolution_cytokine_admm (
351356 A : np .ndarray ,
352357 alpha_h : float = 0.1 ,
353358 alpha_w : float = 0.01 ,
354359 rho : float = 1.0 ,
355- max_iter : int = 5000 ,
356- tol : float = 1e-4 , # Single tolerance for both primal and dual
360+ max_iter : int = 10000 ,
361+ tol : float = 1e-4 ,
357362 random_state : int = 1 ,
358363 adaptive_rho : bool = True ,
364+ non_negative_w : bool = True ,
359365) -> tuple [np .ndarray , np .ndarray , dict ]:
360366 """
361367 Decompose cytokine factor matrix using ADMM: A ≈ W @ H
362-
368+
363369 Parameters
364370 ----------
365371 A : np.ndarray
@@ -378,7 +384,9 @@ def deconvolution_cytokine_admm(
378384 Random seed
379385 adaptive_rho : bool
380386 Whether to adaptively adjust rho
381-
387+ non_negative_w : bool
388+ If True, enforce W ≥ 0 (cytokines only activate, not inhibit)
389+
382390 Returns
383391 -------
384392 Z_W : np.ndarray
@@ -390,112 +398,132 @@ def deconvolution_cytokine_admm(
390398 """
391399 n_cytokines , n_components = A .shape
392400 np .random .seed (random_state )
393-
401+
394402 # Initialize
395403 W = np .eye (n_cytokines )
396404 H = A .copy ()
397405 Z_W = W .copy ()
398406 Z_H = H .copy ()
399407 U_W = np .zeros_like (W )
400408 U_H = np .zeros_like (H )
401-
409+
402410 print ("Cytokine deconvolution with ADMM:" )
403411 print (f" A shape: { A .shape } " )
404412 print (f" Alpha_W: { alpha_w } , Alpha_H: { alpha_h } " )
405413 print (f" Rho: { rho } " )
406414 print (f" Tolerance: { tol } " )
407-
415+ print (f" Non-negative W: { non_negative_w } " )
416+
408417 off_diag_mask = ~ np .eye (n_cytokines , dtype = bool )
409-
418+
410419 def soft_threshold (X , threshold ):
411420 return np .sign (X ) * np .maximum (np .abs (X ) - threshold , 0 )
412-
421+
413422 def update_W (H , Z_W , U_W , rho ):
414- """Update W: constrain diagonal to 1.0"""
423+ """Update W: constrain diagonal to 1.0, optional non-negativity """
415424 H_HT = H @ H .T
416425 A_HT = A @ H .T
417426 lhs = H_HT + rho * np .eye (n_cytokines )
418427 rhs = A_HT + rho * (Z_W - U_W )
419-
428+
420429 W_new = np .linalg .solve (lhs , rhs .T ).T
421- np .fill_diagonal (W_new , 1.0 )
422430
431+ # Non-negativity constraint for W
432+ if non_negative_w :
433+ W_new = np .maximum (W_new , 0 )
434+
435+ # Diagonal constraint
436+ np .fill_diagonal (W_new , 1.0 )
437+
423438 return W_new
424-
439+
425440 def update_H (W , Z_H , U_H , rho ):
426- """Update H"""
441+ """Update H: NO non-negativity constraint """
427442 W_TW = W .T @ W
428443 W_TA = W .T @ A
429444 lhs = W_TW + rho * np .eye (n_cytokines )
430445 rhs = W_TA + rho * (Z_H - U_H )
446+
431447 return np .linalg .solve (lhs , rhs )
432-
448+
433449 def update_Z_W (W , U_W , alpha , rho ):
434- """Update Z_W: soft-threshold off-diagonal only """
450+ """Update Z_W: soft-threshold off-diagonal, optional non-negativity """
435451 X = W + U_W
436452 Z_W_new = X .copy ()
453+
454+ # Soft-threshold off-diagonal
437455 Z_W_new [off_diag_mask ] = soft_threshold (X [off_diag_mask ], alpha / rho )
456+
457+ # Non-negativity constraint for W
458+ if non_negative_w :
459+ Z_W_new = np .maximum (Z_W_new , 0 )
460+
461+ # Diagonal constraint
462+ np .fill_diagonal (Z_W_new , 1.0 )
463+
438464 return Z_W_new
439-
465+
440466 def update_Z_H (H , U_H , alpha , rho ):
441- """Update Z_H: soft-threshold entire matrix"""
467+ """Update Z_H: soft-threshold, NO non-negativity"""
468+ # H can be negative
442469 return soft_threshold (H + U_H , alpha / rho )
443-
470+
444471 history = {
445- ' objective' : [],
446- ' primal_residual' : [],
447- ' dual_residual' : [],
448- ' rho' : [],
449- ' w_sparsity' : [],
450- ' h_sparsity' : []
472+ " objective" : [],
473+ " primal_residual" : [],
474+ " dual_residual" : [],
475+ " rho" : [],
476+ " w_sparsity" : [],
477+ " h_sparsity" : [],
451478 }
452-
479+
453480 print ("\n Starting ADMM iterations..." )
454-
481+
455482 for iteration in range (max_iter ):
456483 Z_W_old = Z_W .copy ()
457484 Z_H_old = Z_H .copy ()
458-
485+
459486 # ADMM updates
460487 W = update_W (H , Z_W , U_W , rho )
461488 H = update_H (W , Z_H , U_H , rho )
462489 Z_W = update_Z_W (W , U_W , alpha_w , rho )
463490 Z_H = update_Z_H (H , U_H , alpha_h , rho )
464491 U_W = U_W + (W - Z_W )
465492 U_H = U_H + (H - Z_H )
466-
467- # ===== SIMPLIFIED CONVERGENCE CHECK =====
468-
493+
469494 # Primal residual: ||W - Z_W||² + ||H - Z_H||²
470- r_norm = np .sqrt (np .sum ((W - Z_W )** 2 ) + np .sum ((H - Z_H )** 2 ))
471-
495+ r_norm = np .sqrt (np .sum ((W - Z_W ) ** 2 ) + np .sum ((H - Z_H ) ** 2 ))
496+
472497 # Dual residual: ||ρ(Z_W - Z_W_old)||² + ||ρ(Z_H - Z_H_old)||²
473- s_norm = np .sqrt (np .sum ((rho * (Z_W - Z_W_old ))** 2 ) +
474- np .sum ((rho * (Z_H - Z_H_old ))** 2 ))
475-
498+ s_norm = np .sqrt (
499+ np .sum ((rho * (Z_W - Z_W_old )) ** 2 ) + np .sum ((rho * (Z_H - Z_H_old )) ** 2 )
500+ )
501+
476502 # Compute objective
477503 recon_error = np .sum ((A - W @ H ) ** 2 )
478504 l1_W = alpha_w * np .sum (np .abs (Z_W [off_diag_mask ]))
479505 l1_H = alpha_h * np .sum (np .abs (Z_H ))
480506 objective = recon_error + l1_W + l1_H
481-
507+
482508 # Track sparsity
483509 w_sparsity = np .sum (np .abs (Z_W [off_diag_mask ]) < 1e-3 ) / np .sum (off_diag_mask )
484510 h_sparsity = np .sum (np .abs (Z_H ) < 1e-3 ) / Z_H .size
485-
511+
486512 # Store history
487- history [' objective' ].append (objective )
488- history [' primal_residual' ].append (r_norm )
489- history [' dual_residual' ].append (s_norm )
490- history [' rho' ].append (rho )
491- history [' w_sparsity' ].append (w_sparsity )
492- history [' h_sparsity' ].append (h_sparsity )
493-
513+ history [" objective" ].append (objective )
514+ history [" primal_residual" ].append (r_norm )
515+ history [" dual_residual" ].append (s_norm )
516+ history [" rho" ].append (rho )
517+ history [" w_sparsity" ].append (w_sparsity )
518+ history [" h_sparsity" ].append (h_sparsity )
519+
494520 # Print progress
495521 if iteration % 10 == 0 or iteration < 10 :
496- print (f" Iter { iteration :4d} : Obj={ objective :.4e} , "
497- f"r={ r_norm :.3e} , s={ s_norm :.3e} , ρ={ rho :.2f} " )
498-
522+ print (
523+ f" Iter { iteration :4d} : Obj={ objective :.4e} , "
524+ f"r={ r_norm :.3e} , s={ s_norm :.3e} , ρ={ rho :.2f} "
525+ )
526+
499527 # Adaptive rho update
500528 if adaptive_rho and iteration > 0 :
501529 if r_norm > 10 * s_norm :
@@ -508,32 +536,39 @@ def update_Z_H(H, U_H, alpha, rho):
508536 U_W = U_W * 2
509537 U_H = U_H * 2
510538 print (f" Decreased ρ → { rho :.2f} " )
511-
539+
512540 # Simple convergence check
513541 if r_norm < tol and s_norm < tol :
514542 print (f"\n ✓ Converged at iteration { iteration } " )
515543 print (f" Primal residual: { r_norm :.4e} < { tol :.4e} " )
516544 print (f" Dual residual: { s_norm :.4e} < { tol :.4e} " )
517545 break
518-
546+
519547 # Final statistics
520548 A_recon = W @ H
521549 rel_error = np .linalg .norm (A - A_recon , "fro" ) / np .linalg .norm (A , "fro" )
522-
550+
523551 w_sparsity = np .sum (np .abs (Z_W [off_diag_mask ]) < 1e-3 ) / np .sum (off_diag_mask )
524552 h_sparsity = np .sum (np .abs (Z_H ) < 1e-3 ) / Z_H .size
525-
553+
526554 print ("\n Optimization complete:" )
527555 print (f" Iterations: { iteration + 1 } /{ max_iter } " )
528556 print (f" Relative reconstruction error: { rel_error :.4%} " )
529- print (f"\n W (cytokine interactions):" )
557+
558+ print ("\n W (cytokine interactions):" )
530559 print (f" Off-diagonal sparsity: { w_sparsity :.2%} " )
531560 print (f" Off-diagonal non-zeros: { np .sum (np .abs (Z_W [off_diag_mask ]) > 1e-3 )} " )
532561 print (f" Mean |W_offdiag|: { np .abs (Z_W [off_diag_mask ]).mean ():.4f} " )
562+ print (f" Min value: { W .min ():.4f} " ) # Check non-negativity
563+ print (f" Max value: { W .max ():.4f} " )
533564 print (f" Diagonal: all 1.0 (constrained)" )
534- print (f"\n H (effect patterns):" )
565+
566+ print ("\n H (effect patterns):" )
535567 print (f" Sparsity: { h_sparsity :.2%} " )
536568 print (f" Non-zeros: { np .sum (np .abs (Z_H ) > 1e-3 )} /{ Z_H .size } " )
537569 print (f" Mean |H|: { np .abs (Z_H ).mean ():.4f} " )
538-
570+ print (f" Min value: { H .min ():.4f} " ) # Can be negative
571+ print (f" Max value: { H .max ():.4f} " )
572+ print (f" Negative values: { np .sum (H < 0 )} ({ 100 * np .sum (H < 0 )/ H .size :.1f} %)" )
573+
539574 return Z_W , Z_H , history
0 commit comments