Skip to content

Commit 8e5f642

Browse files
committed
Merge remote-tracking branch 'origin/celltype_annotation_automl' into celltype_annotation_automl
2 parents d7f0475 + 55b8329 commit 8e5f642

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

examples/tuning/imputation_graphsci/main.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
parser.add_argument("--sweep_id", type=str, default=None)
4949
parser.add_argument("--summary_file_path", default="results/pipeline/best_test_acc.csv", type=str)
5050
parser.add_argument("--root_path", default=str(Path(__file__).resolve().parent), type=str)
51-
parser.add_argument("--get_result", action="store_true",help="save imputation result")
51+
parser.add_argument("--get_result", action="store_true", help="save imputation result")
5252
params = parser.parse_args()
5353
print(vars(params))
5454
file_root_path = Path(params.root_path, params.dataset).resolve()
@@ -126,15 +126,12 @@ def evaluate_pipeline(tune_mode=params.tune_mode, pipeline_planer=pipeline_plane
126126
gc.collect()
127127
torch.cuda.empty_cache()
128128
if params.get_result:
129-
result=model.predict(X,X_raw,g,None)
129+
result = model.predict(X, X_raw, g, None)
130130
array = result.detach().cpu().numpy()
131131
# Create DataFrame
132-
df = pd.DataFrame(
133-
data=array,
134-
index=data.data.obs_names,
135-
columns=data.data.var_names
136-
)
132+
df = pd.DataFrame(data=array, index=data.data.obs_names, columns=data.data.var_names)
137133
df.to_csv(f"{params.dataset}/result.csv")
134+
138135
entity, project, sweep_id = pipeline_planer.wandb_sweep_agent(
139136
evaluate_pipeline, sweep_id=params.sweep_id, count=params.count) #Score can be recorded for each epoch
140137
save_summary_data(entity, project, sweep_id, summary_file_path=params.summary_file_path, root_path=file_root_path)

examples/tuning/imputation_scgnn2/main.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@
176176
parser.add_argument("--sweep_id", type=str, default=None)
177177
parser.add_argument("--summary_file_path", default="results/pipeline/best_test_acc.csv", type=str)
178178
parser.add_argument("--root_path", default=str(Path(__file__).resolve().parent), type=str)
179-
parser.add_argument("--get_result", action="store_true",help="save imputation result")
179+
parser.add_argument("--get_result", action="store_true", help="save imputation result")
180180
args = parser.parse_args()
181181
logger.info(pformat(vars(args)))
182182
file_root_path = Path(args.root_path, args.dataset).resolve()
@@ -247,15 +247,12 @@ def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer)
247247
"test_MRE": test_mre
248248
})
249249
if args.get_result:
250-
result=model.predict()
250+
result = model.predict()
251251
array = result.detach().cpu().numpy()
252252
# Create DataFrame
253-
df = pd.DataFrame(
254-
data=array,
255-
index=data.data.obs_names,
256-
columns=data.data.var_names
257-
)
253+
df = pd.DataFrame(data=array, index=data.data.obs_names, columns=data.data.var_names)
258254
df.to_csv(f"{args.dataset}/result.csv")
255+
259256
entity, project, sweep_id = pipeline_planer.wandb_sweep_agent(
260257
evaluate_pipeline, sweep_id=args.sweep_id, count=args.count) #Score can be recorded for each epoch
261258
save_summary_data(entity, project, sweep_id, summary_file_path=args.summary_file_path, root_path=file_root_path)

0 commit comments

Comments
 (0)