|
48 | 48 | parser.add_argument("--sweep_id", type=str, default=None) |
49 | 49 | parser.add_argument("--summary_file_path", default="results/pipeline/best_test_acc.csv", type=str) |
50 | 50 | parser.add_argument("--root_path", default=str(Path(__file__).resolve().parent), type=str) |
51 | | - parser.add_argument("--get_result", action="store_true",help="save imputation result") |
| 51 | + parser.add_argument("--get_result", action="store_true", help="save imputation result") |
52 | 52 | params = parser.parse_args() |
53 | 53 | print(vars(params)) |
54 | 54 | file_root_path = Path(params.root_path, params.dataset).resolve() |
@@ -126,15 +126,12 @@ def evaluate_pipeline(tune_mode=params.tune_mode, pipeline_planer=pipeline_plane |
126 | 126 | gc.collect() |
127 | 127 | torch.cuda.empty_cache() |
128 | 128 | if params.get_result: |
129 | | - result=model.predict(X,X_raw,g,None) |
| 129 | + result = model.predict(X, X_raw, g, None) |
130 | 130 | array = result.detach().cpu().numpy() |
131 | 131 | # Create DataFrame |
132 | | - df = pd.DataFrame( |
133 | | - data=array, |
134 | | - index=data.data.obs_names, |
135 | | - columns=data.data.var_names |
136 | | - ) |
| 132 | + df = pd.DataFrame(data=array, index=data.data.obs_names, columns=data.data.var_names) |
137 | 133 | df.to_csv(f"{params.dataset}/result.csv") |
| 134 | + |
138 | 135 | entity, project, sweep_id = pipeline_planer.wandb_sweep_agent( |
139 | 136 | evaluate_pipeline, sweep_id=params.sweep_id, count=params.count) #Score can be recorded for each epoch |
140 | 137 | save_summary_data(entity, project, sweep_id, summary_file_path=params.summary_file_path, root_path=file_root_path) |
|
0 commit comments