44from pathlib import Path
55
66import numpy as np
7+ import pandas as pd
78import torch
89import wandb
910
4041 parser .add_argument ("--sweep_id" , type = str , default = None )
4142 parser .add_argument ("--summary_file_path" , default = "results/pipeline/best_test_acc.csv" , type = str )
4243 parser .add_argument ("--root_path" , default = str (Path (__file__ ).resolve ().parent ), type = str )
43-
44+ parser . add_argument ( "--get_result" , action = "store_true" , help = "save imputation result" )
4445 params = parser .parse_args ()
4546 print (vars (params ))
4647 file_root_path = Path (params .root_path , params .dataset ).resolve ()
4748 logger .info (f"\n files is saved in { file_root_path } " )
4849 pipeline_planer = PipelinePlaner .from_config_file (f"{ file_root_path } /{ params .tune_mode } _tuning_config.yaml" )
4950 os .environ ["WANDB_AGENT_MAX_INITIAL_FAILURES" ] = "2000"
50-
51+ logger . info ( params . tune_mode )
5152 def evaluate_pipeline (tune_mode = params .tune_mode , pipeline_planer = pipeline_planer ):
5253 wandb .init (settings = wandb .Settings (start_method = 'thread' ))
5354 set_seed (params .seed )
5455
5556 data = ImputationDataset (data_dir = params .data_dir , dataset = params .dataset ,
5657 train_size = params .train_size ).load_data ()
5758 # Prepare preprocessing pipeline and apply it to data
58- kwargs = {tune_mode : dict (wandb .config )}
59+ wandb_config = wandb .config
60+ if "run_kwargs" in pipeline_planer .config :
61+ if any (d == dict (wandb .config ["run_kwargs" ]) for d in pipeline_planer .config .run_kwargs ):
62+ wandb_config = wandb_config ["run_kwargs" ]
63+ else :
64+ wandb .log ({"skip" : 1 })
65+ wandb .finish ()
66+ return
67+ kwargs = {tune_mode : dict (wandb_config )}
5968 preprocessing_pipeline = pipeline_planer .generate (** kwargs )
6069 print (f"Pipeline config:\n { preprocessing_pipeline .to_yaml ()} " )
6170 preprocessing_pipeline (data )
@@ -77,10 +86,20 @@ def evaluate_pipeline(tune_mode=params.tune_mode, pipeline_planer=pipeline_plane
7786 pcc = model .score (X , imputed_data , mask , "PCC" )
7887 mre = model .score (X , imputed_data , mask , metric = "MRE" )
7988 wandb .log ({"RMSE" : score , "PCC" : pcc , "MRE" : mre })
80-
89+ if params .get_result :
90+ result = model .predict (X ,None )
91+ array = result .detach ().cpu ().numpy ()
92+ df = pd .DataFrame (
93+ data = array ,
94+ index = data .data .obs_names ,
95+ columns = data .data .var_names
96+ )
97+ df .to_csv (f"{ params .dataset } /result.csv" )
8198 entity , project , sweep_id = pipeline_planer .wandb_sweep_agent (
8299 evaluate_pipeline , sweep_id = params .sweep_id , count = params .count ) #Score can be recorded for each epoch
83100 save_summary_data (entity , project , sweep_id , summary_file_path = params .summary_file_path , root_path = file_root_path )
101+ if params .get_result :
102+ sys .exit (0 )
84103 if params .tune_mode == "pipeline" or params .tune_mode == "pipeline_params" :
85104 get_step3_yaml (result_load_path = f"{ params .summary_file_path } " , step2_pipeline_planer = pipeline_planer ,
86105 conf_load_path = f"{ Path (params .root_path ).resolve ().parent } /step3_default_params.yaml" ,
0 commit comments