66import pandas as pd
77import torch
88import anndata as ad
9+ import scanpy as sc
910import tqdm
1011from scipy .sparse import csr_matrix
1112from scipy .stats import ttest_rel , spearmanr , pearsonr , wilcoxon
@@ -378,40 +379,63 @@ def main(par):
378379 j = gene_dict [target ]
379380 A [i , j ] = float (weight )
380381
381- # Only consider the genes that are actually present in the inferred GRN,
382- # and keep only the most-connected genes (for speed).
383- gene_mask = np .logical_or (np .any (A , axis = 1 ), np .any (A , axis = 0 ))
382+ # Compute HVGs from full evaluation data (for HVG-based evaluation)
383+ print ("\n ======== Computing HVGs from full evaluation data ========" )
384+ n_top_hvg = par ['n_top_genes' ]
385+ sc .pp .highly_variable_genes (adata , n_top_genes = n_top_hvg , flavor = 'seurat' , layer = layer )
386+ hvg_mask_full = adata .var ['highly_variable' ].values
387+ hvg_genes = gene_names [hvg_mask_full ]
388+ print (f"Total HVGs identified: { hvg_mask_full .sum ()} " )
389+
390+ # For GRN-based evaluation: keep only most-connected genes in the GRN
391+ print ("\n ======== Filtering genes for GRN-based evaluation ========" )
392+ gene_mask_grn = np .logical_or (np .any (A , axis = 1 ), np .any (A , axis = 0 ))
384393 in_degrees = np .sum (A != 0 , axis = 0 )
385394 out_degrees = np .sum (A != 0 , axis = 1 )
386- # n_genes = par['n_top_genes']
387- n_genes = 3000
388- idx = np .argsort (np .maximum (out_degrees , in_degrees ))[:- n_genes ]
389- gene_mask [idx ] = False
390- X = X [:, gene_mask ]
391- X = X .toarray () if isinstance (X , csr_matrix ) else X
392- A = A [gene_mask , :][:, gene_mask ]
393- gene_names = gene_names [gene_mask ]
395+ n_genes_grn = par ['n_top_genes' ]
396+ idx = np .argsort (np .maximum (out_degrees , in_degrees ))[:- n_genes_grn ]
397+ gene_mask_grn [idx ] = False
398+
399+ X_grn = X [:, gene_mask_grn ]
400+ X_grn = X_grn .toarray () if isinstance (X_grn , csr_matrix ) else X_grn
401+ A_grn = A [gene_mask_grn , :][:, gene_mask_grn ]
402+ gene_names_grn = gene_names [gene_mask_grn ]
403+ print (f"Genes for GRN-based evaluation: { len (gene_names_grn )} " )
394404
395405 # Remove self-regulations
396406 np .fill_diagonal (A , 0 )
407+ np .fill_diagonal (A_grn , 0 )
397408
398409 # Check whether the inferred GRN contains signed predictions
399410 if False :
400411 use_signs = np .any (A < 0 )
401412 else :
402413 use_signs = False
403414
404- # Center and scale dataset
405- scaler = StandardScaler ()
406- scaler .fit (X [are_controls , :]) # Use controls only to infer statistics (to avoid data leakage)
407- X = scaler .transform (X )
408-
409- # Get negative controls
410- X_controls = X [are_controls , :]
411- delta_X = compute_perturbations (X , are_controls , match_groups , loose_match_groups )
415+ # Center and scale dataset for GRN-based evaluation
416+ scaler_grn = StandardScaler ()
417+ X_grn_controls = X_grn [are_controls , :]
418+ scaler_grn .fit (X_grn_controls )
419+ X_grn_scaled = scaler_grn .transform (X_grn )
420+ X_grn_controls_scaled = X_grn_scaled [are_controls , :]
421+ delta_X_grn = compute_perturbations (X_grn_scaled , are_controls , match_groups , loose_match_groups )
422+ delta_X_grn = delta_X_grn [~ are_controls , :]
423+
424+ # Center and scale dataset for HVG-based evaluation (use all HVG genes, even if not in GRN)
425+ X_hvg = X [:, hvg_mask_full ]
426+ X_hvg = X_hvg .toarray () if isinstance (X_hvg , csr_matrix ) else X_hvg
427+ A_hvg = A [hvg_mask_full , :][:, hvg_mask_full ]
428+ gene_names_hvg = gene_names [hvg_mask_full ]
429+ scaler_hvg = StandardScaler ()
430+ X_hvg_controls = X_hvg [are_controls , :]
431+ scaler_hvg .fit (X_hvg_controls )
432+ X_hvg_scaled = scaler_hvg .transform (X_hvg )
433+ X_hvg_controls_scaled = X_hvg_scaled [are_controls , :]
434+ delta_X_hvg = compute_perturbations (X_hvg_scaled , are_controls , match_groups , loose_match_groups )
435+ delta_X_hvg = delta_X_hvg [~ are_controls , :]
436+ print (f"Genes for HVG-based evaluation: { len (gene_names_hvg )} " )
412437
413438 # Remove negative controls from downstream analysis
414- delta_X = delta_X [~ are_controls , :]
415439 cv_groups = cv_groups [~ are_controls ]
416440 match_groups = match_groups [~ are_controls ]
417441 loose_match_groups = loose_match_groups [~ are_controls ]
@@ -420,94 +444,76 @@ def main(par):
420444 # Make sure that no compound ends up in both sets.
421445 try :
422446 splitter = GroupShuffleSplit (test_size = 0.5 , n_splits = 2 , random_state = seed ) # Use consistent seed
423- train_idx , _ = next (splitter .split (delta_X , groups = cv_groups ))
447+ train_idx , _ = next (splitter .split (delta_X_grn , groups = cv_groups ))
424448 except ValueError :
425449 print ("Group k-fold failed. Using k-fold CV instead." )
426450 splitter = KFold (n_splits = 2 , random_state = seed , shuffle = True ) # Use consistent seed
427- train_idx , _ = next (splitter .split (delta_X ))
428- is_train = np .zeros (len (delta_X ), dtype = bool )
451+ train_idx , _ = next (splitter .split (delta_X_grn ))
452+ is_train = np .zeros (len (delta_X_grn ), dtype = bool )
429453 is_train [train_idx ] = True
430454
431- # Create a split between genes: reporter genes and evaluation genes.
432- # All TFs and IEGs should be included in the reporter gene set.
433- n_genes = A .shape [1 ]
434- reg_mask = np .asarray (A != 0 ).any (axis = 1 )
435- ieg_mask = np .asarray ([gene_name in IEG for gene_name in gene_names ], dtype = bool )
436- is_reporter = np .logical_or (reg_mask , ieg_mask )
437- print (f"Proportion of reporter genes: { np .mean (is_reporter )} " )
438- print (f"Use regulatory modes/signs: { use_signs } " )
455+ # ========== GRN-based evaluation ==========
456+ print ("\n ======== Evaluate inferred GRN (GRN-based: most connected genes) ========" )
457+ n_genes_grn = A_grn .shape [1 ]
458+ reg_mask_grn = np .asarray (A_grn != 0 ).any (axis = 1 )
459+ ieg_mask_grn = np .asarray ([gene_name in IEG for gene_name in gene_names_grn ], dtype = bool )
460+ is_reporter_grn = np .logical_or (reg_mask_grn , ieg_mask_grn )
461+ print (f"Proportion of reporter genes (GRN): { np .mean (is_reporter_grn )} " )
439462
440- # Create baseline model
441- try :
442- A_baseline = create_grn_baseline (A )
443- except :
444- print ("Failed to create baseline GRN. Using zero baseline." )
445- raise ValueError ("Failed to create baseline GRN." )
446-
447- # Evaluate inferred GRN
448- print ("\n ======== Evaluate inferred GRN ========" )
449- scores = evaluate_grn (X_controls , delta_X , is_train , is_reporter , A , signed = use_signs )
463+
464+ scores_grn = evaluate_grn (X_grn_controls_scaled , delta_X_grn , is_train , is_reporter_grn , A_grn , signed = use_signs )
465+ valid_scores_grn = scores_grn [~ np .isnan (scores_grn )]
466+
467+ if len (valid_scores_grn ) == 0 :
468+ print ("WARNING: No valid genes to evaluate for GRN-based!" )
469+ sem_grn_score = 0.0
470+ else :
471+ sem_grn_score = float (np .mean (valid_scores_grn ))
472+ print (f"SEM GRN score (mean R²): { sem_grn_score :.4f} " )
473+ print (f"Valid genes evaluated: { len (valid_scores_grn )} /{ len (scores_grn )} " )
474+ print (f"SEM GRN score (min): { np .min (valid_scores_grn ):.4f} " )
475+ print (f"SEM GRN score (max): { np .max (valid_scores_grn ):.4f} " )
476+
477+ # ========== HVG-based evaluation ==========
478+ print ("\n ======== Evaluate inferred GRN (HVG-based: highly variable genes) ========" )
479+ n_genes_hvg = A_hvg .shape [1 ]
480+ reg_mask_hvg = np .asarray (A_hvg != 0 ).any (axis = 1 )
481+ ieg_mask_hvg = np .asarray ([gene_name in IEG for gene_name in gene_names_hvg ], dtype = bool )
482+ is_reporter_hvg = np .logical_or (reg_mask_hvg , ieg_mask_hvg )
483+ print (f"Proportion of reporter genes (HVG): { np .mean (is_reporter_hvg )} " )
484+
485+ scores_hvg = evaluate_grn (X_hvg_controls_scaled , delta_X_hvg , is_train , is_reporter_hvg , A_hvg , signed = use_signs )
486+
487+ # For HVGs: genes with no GRN connections get score of 0 (penalize missing connections)
488+ has_parent_hvg = (np .asarray (A_hvg != 0 ).any (axis = 0 ))
489+ eval_mask_hvg = ~ is_reporter_hvg
490+ scores_hvg_penalized = scores_hvg .copy ()
491+ for j in range (len (scores_hvg_penalized )):
492+ if eval_mask_hvg [j ]:
493+ if not has_parent_hvg [j ]: # Gene has no connections in GRN
494+ scores_hvg_penalized [j ] = 0.0 # Penalize by setting score to 0
495+ elif np .isnan (scores_hvg_penalized [j ]):
496+ scores_hvg_penalized [j ] = 0.0 # Also set NaN to 0
450497
451- # Keep only valid scores (non-NaN)
452- valid_scores = scores [~ np .isnan (scores )]
498+ valid_scores_hvg = scores_hvg_penalized [eval_mask_hvg ]
453499
454- if len (valid_scores ) == 0 :
455- # No valid genes to evaluate
456- print ("WARNING: No valid genes to evaluate!" )
457- results = {'sem' : [0.0 ]}
500+ if len (valid_scores_hvg ) == 0 :
501+ print ("WARNING: No valid HVG genes to evaluate!" )
502+ sem_hvg_score = 0.0
458503 else :
459- # Final score is mean of valid R² scores
460- final_score = float (np .mean (valid_scores ))
461-
462- print (f"\n Method: { method_id } " )
463- print (f"SEM score (mean R²): { final_score :.4f} " )
464- print (f"Valid genes evaluated: { len (valid_scores )} /{ len (scores )} " )
465- print (f"SEM score (min): { np .min (valid_scores ):.4f} " )
466- print (f"SEM score (max): { np .max (valid_scores ):.4f} " )
467-
468- results = {'sem' : [float (final_score )]}
504+ sem_hvg_score = float (np .mean (valid_scores_hvg ))
505+ n_missing = np .sum (~ has_parent_hvg [eval_mask_hvg ])
506+ print (f"SEM HVG score (mean R²): { sem_hvg_score :.4f} " )
507+ print (f"HVG genes evaluated: { len (valid_scores_hvg )} " )
508+ print (f"HVG genes missing in GRN (penalized with 0): { n_missing } " )
509+ print (f"SEM HVG score (min): { np .min (valid_scores_hvg ):.4f} " )
510+ print (f"SEM HVG score (max): { np .max (valid_scores_hvg ):.4f} " )
469511
470- # Evaluate baseline GRN
471- if False :
472- print ("\n ======== Evaluate shuffled GRN ========" )
473- scores_baseline = evaluate_grn (X_controls , delta_X , is_train , is_reporter , A_baseline , signed = use_signs )
474-
475- # Keep only the genes for which both GRNs got a score
476- mask = ~ np .logical_or (np .isnan (scores ), np .isnan (scores_baseline ))
477- scores = scores [mask ]
478- scores_baseline = scores_baseline [mask ]
479-
480- rr_all = {}
481- # Perform rank test between actual scores and baseline
482- rr_all ['spearman' ] = float (np .mean (scores ))
483- rr_all ['spearman_shuffled' ] = float (np .mean (scores_baseline ))
484- if len (scores ) == 0 :
485- raise ValueError ("No valid scores to compare between inferred GRN and baseline GRN." )
486- elif np .all (scores - scores_baseline == 0 ):
487- # Identical performance (suspicious - likely an error)
488- raise ValueError ("Identical performance between inferred GRN and baseline GRN - likely an error." )
489- else :
490- res = wilcoxon (scores - scores_baseline , zero_method = 'wilcox' , alternative = 'greater' )
491- rr_all ['Wilcoxon pvalue' ] = float (res .pvalue )
492-
493- print (rr_all )
494-
495- eps = 1e-300 # very small number to avoid log(0)
496- pval_clipped = max (res .pvalue , eps )
497-
498- # Set to 0 if not significant (p >= 0.05)
499- if res .pvalue >= 0.05 :
500- score = 0.0
501- print (f"p-value: { res .pvalue :.6f} (not significant, p >= 0.05)" )
502- print (f"SEM score set to 0" )
503- else :
504- # Compute final score
505- score = - np .log10 (pval_clipped )
506- print (f"p-value: { res .pvalue :.6f} (significant)" )
507-
508- print (f"Final score: { score } " )
509- results ['sem_precision' ] = [float (np .log2 (np .mean (scores ) / (np .mean (scores_baseline ) + 1e-6 )))]
510- results ['sem_n' ] = [float (score )]
512+ results = {
513+ 'sem_grn' : [float (sem_grn_score )],
514+ 'sem_hvg' : [float (sem_hvg_score )],
515+ 'sem' : [float ((sem_grn_score + sem_hvg_score ) / 2 )]
516+ }
511517
512518 df_results = pd .DataFrame (results )
513519 return df_results
0 commit comments