@@ -346,3 +346,186 @@ def gradient(x):
346346 print (f" Non-zeros: { np .sum (np .abs (H ) > 1e-3 )} /{ H .size } " )
347347
348348 return W , H
349+
350+ def deconvolution_cytokine_admm (
351+ A : np .ndarray ,
352+ alpha : float = 0.1 ,
353+ rho : float = 1.0 ,
354+ max_iter : int = 5000 ,
355+ tol : float = 1e-4 ,
356+ random_state : int = 1 ,
357+ adaptive_rho : bool = True ,
358+ ) -> tuple [np .ndarray , np .ndarray , dict ]:
359+ """
360+ Decompose cytokine factor matrix using ADMM: A ≈ W @ H
361+
362+ Parameters
363+ ----------
364+ A : np.ndarray
365+ Input matrix (n_cytokines, n_components)
366+ alpha : float
367+ L1 regularization strength (applied to both W and H)
368+ rho : float
369+ ADMM penalty parameter
370+ max_iter : int
371+ Maximum ADMM iterations
372+ tol : float
373+ Convergence tolerance
374+ random_state : int
375+ Random seed
376+ adaptive_rho : bool
377+ Whether to adaptively adjust rho
378+
379+ Returns
380+ -------
381+ Z_W : np.ndarray
382+ Cytokine interaction matrix (n_cytokines, n_cytokines)
383+ Z_H : np.ndarray
384+ Effect basis matrix (n_cytokines, n_components)
385+ history : dict
386+ Optimization history
387+ """
388+ n_cytokines , n_components = A .shape
389+ np .random .seed (random_state )
390+
391+ # Initialize
392+ W = np .eye (n_cytokines )
393+ H = A .copy ()
394+ Z_W = W .copy ()
395+ Z_H = H .copy ()
396+ U_W = np .zeros_like (W )
397+ U_H = np .zeros_like (H )
398+
399+ print ("Cytokine deconvolution with ADMM:" )
400+ print (f" A shape: { A .shape } " )
401+ print (f" Alpha (L1 penalty): { alpha } " )
402+ print (f" Rho (ADMM penalty): { rho } " )
403+ print (f" Adaptive rho: { adaptive_rho } " )
404+
405+ # Create mask for off-diagonal elements
406+ off_diag_mask = ~ np .eye (n_cytokines , dtype = bool )
407+
408+ def soft_threshold (X , threshold ):
409+ return np .sign (X ) * np .maximum (np .abs (X ) - threshold , 0 )
410+
411+ def update_W (H , Z_W , U_W , rho ):
412+ """Update W: solve (H@H^T + rho*I) W^T = (A@H^T + rho(Z_W - U_W))^T"""
413+ H_HT = H @ H .T
414+ A_HT = A @ H .T
415+ lhs = H_HT + rho * np .eye (n_cytokines )
416+ rhs = A_HT + rho * (Z_W - U_W )
417+ return np .linalg .solve (lhs , rhs .T ).T
418+
419+ def update_H (W , Z_H , U_H , rho ):
420+ """Update H: solve (W^T@W + rho*I) H = W^T@A + rho(Z_H - U_H)"""
421+ W_TW = W .T @ W
422+ W_TA = W .T @ A
423+ lhs = W_TW + rho * np .eye (n_cytokines )
424+ rhs = W_TA + rho * (Z_H - U_H )
425+ return np .linalg .solve (lhs , rhs )
426+
427+ def update_Z_W (W , U_W , alpha , rho ):
428+ """Update Z_W: soft-threshold off-diagonal, preserve diagonal"""
429+ X = W + U_W
430+ Z_W_new = soft_threshold (X , alpha / rho )
431+ # Restore diagonal (no L1 penalty on direct effects)
432+ np .fill_diagonal (Z_W_new , np .diag (X ))
433+ return Z_W_new
434+
435+ def update_Z_H (H , U_H , alpha , rho ):
436+ """Update Z_H: soft-threshold entire matrix"""
437+ return soft_threshold (H + U_H , alpha / rho )
438+
439+ history = {
440+ 'objective' : [],
441+ 'primal_residual' : [],
442+ 'dual_residual' : [],
443+ 'rho' : [],
444+ 'w_sparsity' : [],
445+ 'h_sparsity' : []
446+ }
447+
448+ print ("\n Starting ADMM iterations..." )
449+
450+ for iteration in range (max_iter ):
451+ # Store old Z values for dual residual computation
452+ Z_W_old = Z_W .copy ()
453+ Z_H_old = Z_H .copy ()
454+
455+ # ADMM updates
456+ W = update_W (H , Z_W , U_W , rho )
457+ H = update_H (W , Z_H , U_H , rho )
458+ Z_W = update_Z_W (W , U_W , alpha , rho )
459+ Z_H = update_Z_H (H , U_H , alpha , rho )
460+ U_W = U_W + (W - Z_W )
461+ U_H = U_H + (H - Z_H )
462+
463+ # Compute residuals
464+ r_norm = np .sqrt (np .sum ((W - Z_W )** 2 ) + np .sum ((H - Z_H )** 2 ))
465+ s_norm = rho * np .sqrt (np .sum ((Z_W - Z_W_old )** 2 ) +
466+ np .sum ((Z_H - Z_H_old )** 2 ))
467+
468+ # Compute objective (off-diagonal penalty for W, full penalty for H)
469+ recon_error = np .sum ((A - W @ H ) ** 2 )
470+ l1_W = alpha * np .sum (np .abs (Z_W [off_diag_mask ]))
471+ l1_H = alpha * np .sum (np .abs (Z_H ))
472+ objective = recon_error + l1_W + l1_H
473+
474+ # Track sparsity
475+ w_sparsity = np .sum (np .abs (Z_W [off_diag_mask ]) < 1e-3 ) / np .sum (off_diag_mask )
476+ h_sparsity = np .sum (np .abs (Z_H ) < 1e-3 ) / Z_H .size
477+
478+ # Store history
479+ history ['objective' ].append (objective )
480+ history ['primal_residual' ].append (r_norm )
481+ history ['dual_residual' ].append (s_norm )
482+ history ['rho' ].append (rho )
483+ history ['w_sparsity' ].append (w_sparsity )
484+ history ['h_sparsity' ].append (h_sparsity )
485+
486+ # Print progress
487+ if iteration % 10 == 0 or iteration < 10 :
488+ print (f" Iter { iteration :4d} : Obj={ objective :.4e} , "
489+ f"||r||={ r_norm :.4e} , ||s||={ s_norm :.4e} , "
490+ f"W_sparse={ w_sparsity :.2%} , H_sparse={ h_sparsity :.2%} " )
491+
492+ # Adaptive rho update
493+ if adaptive_rho and iteration > 0 :
494+ if r_norm > 10 * s_norm :
495+ rho = rho * 2
496+ U_W = U_W / 2
497+ U_H = U_H / 2
498+ elif s_norm > 10 * r_norm :
499+ rho = rho / 2
500+ U_W = U_W * 2
501+ U_H = U_H * 2
502+
503+ # Check convergence
504+ eps_primal = tol * np .sqrt (W .size + H .size )
505+ eps_dual = tol * np .sqrt (U_W .size + U_H .size )
506+
507+ if r_norm < eps_primal and s_norm < eps_dual :
508+ print (f"\n Converged at iteration { iteration } " )
509+ break
510+
511+ # Final statistics
512+ A_recon = W @ H
513+ rel_error = np .linalg .norm (A - A_recon , "fro" ) / np .linalg .norm (A , "fro" )
514+
515+ w_sparsity = np .sum (np .abs (Z_W [off_diag_mask ]) < 1e-3 ) / np .sum (off_diag_mask )
516+ h_sparsity = np .sum (np .abs (Z_H ) < 1e-3 ) / Z_H .size
517+
518+ print ("\n Optimization complete:" )
519+ print (f" Iterations: { iteration + 1 } /{ max_iter } " )
520+ print (f" Relative reconstruction error: { rel_error :.4%} " )
521+ print (f"\n W (cytokine interactions):" )
522+ print (f" Off-diagonal sparsity: { w_sparsity :.2%} " )
523+ print (f" Off-diagonal non-zeros: { np .sum (np .abs (Z_W [off_diag_mask ]) > 1e-3 )} " )
524+ print (f" Mean |W_offdiag|: { np .abs (Z_W [off_diag_mask ]).mean ():.4f} " )
525+ print (f" Diagonal mean: { np .abs (np .diag (Z_W )).mean ():.4f} " )
526+ print (f"\n H (effect patterns):" )
527+ print (f" Sparsity: { h_sparsity :.2%} " )
528+ print (f" Non-zeros: { np .sum (np .abs (Z_H ) > 1e-3 )} /{ Z_H .size } " )
529+ print (f" Mean |H|: { np .abs (Z_H ).mean ():.4f} " )
530+
531+ return Z_W , Z_H , history
0 commit comments