Skip to content

Commit c6f60df

Browse files
committed
update analysis
1 parent d91a511 commit c6f60df

File tree

1 file changed

+23
-4
lines changed
  • examples/tuning/imputation_deepimpute

1 file changed

+23
-4
lines changed

examples/tuning/imputation_deepimpute/main.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import numpy as np
7+
import pandas as pd
78
import torch
89
import wandb
910

@@ -40,22 +41,30 @@
4041
parser.add_argument("--sweep_id", type=str, default=None)
4142
parser.add_argument("--summary_file_path", default="results/pipeline/best_test_acc.csv", type=str)
4243
parser.add_argument("--root_path", default=str(Path(__file__).resolve().parent), type=str)
43-
44+
parser.add_argument("--get_result", action="store_true",help="save imputation result")
4445
params = parser.parse_args()
4546
print(vars(params))
4647
file_root_path = Path(params.root_path, params.dataset).resolve()
4748
logger.info(f"\n files is saved in {file_root_path}")
4849
pipeline_planer = PipelinePlaner.from_config_file(f"{file_root_path}/{params.tune_mode}_tuning_config.yaml")
4950
os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"] = "2000"
50-
51+
logger.info(params.tune_mode)
5152
def evaluate_pipeline(tune_mode=params.tune_mode, pipeline_planer=pipeline_planer):
5253
wandb.init(settings=wandb.Settings(start_method='thread'))
5354
set_seed(params.seed)
5455

5556
data = ImputationDataset(data_dir=params.data_dir, dataset=params.dataset,
5657
train_size=params.train_size).load_data()
5758
# Prepare preprocessing pipeline and apply it to data
58-
kwargs = {tune_mode: dict(wandb.config)}
59+
wandb_config = wandb.config
60+
if "run_kwargs" in pipeline_planer.config:
61+
if any(d == dict(wandb.config["run_kwargs"]) for d in pipeline_planer.config.run_kwargs):
62+
wandb_config = wandb_config["run_kwargs"]
63+
else:
64+
wandb.log({"skip": 1})
65+
wandb.finish()
66+
return
67+
kwargs = {tune_mode: dict(wandb_config)}
5968
preprocessing_pipeline = pipeline_planer.generate(**kwargs)
6069
print(f"Pipeline config:\n{preprocessing_pipeline.to_yaml()}")
6170
preprocessing_pipeline(data)
@@ -77,10 +86,20 @@ def evaluate_pipeline(tune_mode=params.tune_mode, pipeline_planer=pipeline_plane
7786
pcc = model.score(X, imputed_data, mask, "PCC")
7887
mre = model.score(X, imputed_data, mask, metric="MRE")
7988
wandb.log({"RMSE": score, "PCC": pcc, "MRE": mre})
80-
89+
if params.get_result:
90+
result=model.predict(X,None)
91+
array = result.detach().cpu().numpy()
92+
df = pd.DataFrame(
93+
data=array,
94+
index=data.data.obs_names,
95+
columns=data.data.var_names
96+
)
97+
df.to_csv(f"{params.dataset}/result.csv")
8198
entity, project, sweep_id = pipeline_planer.wandb_sweep_agent(
8299
evaluate_pipeline, sweep_id=params.sweep_id, count=params.count) #Score can be recorded for each epoch
83100
save_summary_data(entity, project, sweep_id, summary_file_path=params.summary_file_path, root_path=file_root_path)
101+
if params.get_result:
102+
sys.exit(0)
84103
if params.tune_mode == "pipeline" or params.tune_mode == "pipeline_params":
85104
get_step3_yaml(result_load_path=f"{params.summary_file_path}", step2_pipeline_planer=pipeline_planer,
86105
conf_load_path=f"{Path(params.root_path).resolve().parent}/step3_default_params.yaml",

0 commit comments

Comments
 (0)