Hyperparameter sweep for Pf2 using Weights & Biases
Optimizing rank and regularization parameter
"""
5- import os
5+
66import numpy as np
7- import pandas as pd
8- import anndata
97import wandb
10- from tensorly .cp_tensor import CPTensor
11- from tlviz .factor_tools import factor_match_score as fms
12- import matplotlib .pyplot as plt
13- import seaborn as sns
14-
158from factorization import pf2
169from imports import import_cytokine
17-
10+ from tensorly .cp_tensor import CPTensor
11+ from tlviz .factor_tools import factor_match_score as fms
1812
# Candidate Pf2 ranks (number of components) to sweep over: 1 through 30.
ranks = np.arange(1, 31)

# Define the sweep configuration
sweep_config = {
    "method": "grid",  # grid search for thorough exploration
    "metric": {
        "name": "fms",  # optimize for factor match score
        "goal": "maximize",  # we want to maximize factor stability
    },
    "parameters": {
        "rank": {
            # Hand over plain Python ints instead of a NumPy array so the
            # configuration is JSON/YAML-serializable when it is registered.
            "values": ranks.tolist(),  # Different component numbers to test
        },
        "regParam": {
            "values": [
                0.0,
                1e-6,
                1e-5,
                5e-5,
                1e-4,
            ],  # Different L1 regularization strengths
        },
    },
}
3636
37+
def resample(data):
    """Bootstrapping dataset

    Draws ``data.shape[0]`` indices uniformly at random, with replacement,
    and returns a copy of ``data`` indexed by them along its first axis.

    The indices come from NumPy's global random state, so call
    ``np.random.seed`` beforehand if a reproducible sample is needed.
    """
    n_obs = data.shape[0]
    indices = np.random.randint(0, n_obs, size=(n_obs,))
    # Copy so the bootstrap sample is independent of the input object.
    return data[indices].copy()
4142
43+
4244def calculateFMS (A , B ):
4345 """Calculates FMS between 2 factorizations"""
4446 A_factors = [A .uns ["Pf2_A" ], A .uns ["Pf2_B" ], A .varm ["Pf2_C" ]]
@@ -49,78 +51,82 @@ def calculateFMS(A, B):
4951
5052 return fms (A_CP , B_CP , consider_weights = False , skip_mode = 1 )
5153
54+
def calculate_sparsity(matrix, threshold=1e-6):
    """Calculate sparsity (proportion of near-zero elements)

    An element counts as near-zero when its absolute value is strictly
    below ``threshold``. The result is that count divided by
    ``matrix.size``, i.e. a fraction between 0 and 1.
    """
    total_elements = matrix.size
    near_zero_elements = np.sum(np.abs(matrix) < threshold)
    return near_zero_elements / total_elements
5760
61+
def train():
    """Main training function for wandb sweep

    Runs one sweep trial: factorizes the full dataset with the rank and
    regParam taken from the wandb config, logs R2X and the sparsity of the
    ``Pf2_C`` factor, then refits on bootstrap resamples and logs the factor
    match score (FMS) of each refit against the base model along with their
    mean under the key ``"fms"`` (the metric named in the sweep config).
    """
    # Initialize a new wandb run
    with wandb.init() as run:
        # Get parameters from wandb
        config = wandb.config

        # Load data (do this once per run to save time)
        X = import_cytokine()
        print(f"Running with rank={config.rank}, regParam={config.regParam}")

        # Set number of bootstrap samples
        n_bootstrap = 3

        # Run base factorization with current parameters
        # (with r2x=True the call is unpacked as a (model, r2x) pair)
        base_model, r2x = pf2(
            X,
            rank=config.rank,
            random_state=42,
            doEmbedding=False,
            regParam=config.regParam,
            r2x=True,
        )

        sparsity_C = calculate_sparsity(base_model.varm["Pf2_C"])

        # Log R2X and sparsity metrics
        wandb.log({"r2x": r2x, "sparsity_C": sparsity_C})

        # Calculate FMS across bootstrap samples
        fms_scores = []
        for i in range(n_bootstrap):
            # Create bootstrap sample
            bootstrap_data = resample(X)

            # Run factorization on bootstrap sample; the loop index is the
            # seed, so each refit uses a different random_state.
            bootstrap_model = pf2(
                bootstrap_data,
                rank=config.rank,
                random_state=i,
                doEmbedding=False,
                regParam=config.regParam,
            )

            # Calculate FMS between base model and bootstrap model
            fms_score = calculateFMS(base_model, bootstrap_model)
            fms_scores.append(fms_score)

            # Log individual bootstrap FMS
            wandb.log({f"fms_bootstrap_{i}": fms_score})

        # Calculate and log average FMS
        avg_fms = np.mean(fms_scores)
        wandb.log({"fms": avg_fms})

        print(
            f"Completed run: rank={config.rank}, regParam={config.regParam}, R2X={r2x:.4f}, FMS={avg_fms:.4f}"
        )
120+
117121
if __name__ == "__main__":
    # Initialize wandb
    wandb.login()

    # Create the sweep
    sweep_id = wandb.sweep(sweep_config, project="Pf2_parameter_optimization2")

    # Run the sweep
    wandb.agent(
        sweep_id, function=train, count=None
    )  # Set count if you want to limit runs
0 commit comments