Skip to content

Commit f81b400

Browse files
Implement training evolution eval
1 parent 8f7dec0 commit f81b400

File tree

2 files changed

+48
-28
lines changed

2 files changed

+48
-28
lines changed

finetuning/generalists/evaluate_training_evolution.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,77 @@
11
import argparse
22
import os
3+
34
from glob import glob
5+
from subprocess import run
46

57
import pandas as pd
68
from util import evaluate_checkpoint_for_datasets, get_generalist_predictor
79

8-
# Cluster locations: where the generalist checkpoints live and where
# evaluation results are written.
CHECKPOINT_ROOT = "/scratch/projects/nim00007/sam/models/LM/generalist/v2"
EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/training-evolution"

# We evaluate these three datasets for the training evolution.
# These are chosen based on observations from preliminary experiments.
# - covid-if: out-of-domain dataset that shows the expected improvement (over vanilla).
# - deepbacs: in domain dataset where we see the biggest gap to the specialist.
# - lizard: out-of-domain that is furthest from the training data.
EVAL_DATASETS = ("covid-if", "deepbacs", "lizard")
1718

18-
def evaluate_training_evolution(model_type):
19-
checkpoints = sorted(glob(
20-
os.path.join(CHECKPOINT_ROOT, model_type, "*.pt")
21-
))
22-
assert len(checkpoints) > 0
2319

24-
epochs, results = [], []
25-
for checkpoint in checkpoints:
20+
def evaluate_checkpoint_slurm(model_type, job_id, checkpoints):
    """Evaluate one training checkpoint, selected by the slurm array task id.

    Loads the generalist predictor for ``checkpoints[job_id]`` and runs the
    default evaluation (AMG disabled) on all datasets in EVAL_DATASETS,
    writing into a per-epoch experiment directory.

    Returns the result table with an 'epoch' column prepended.
    """
    ckpt = checkpoints[job_id]

    predictor, state = get_generalist_predictor(
        ckpt, model_type, is_custom_model=True, return_state=True
    )
    # +1: report epochs starting at 1 (state presumably stores the
    # zero-based epoch index — confirm against the training code).
    epoch = state["epoch"] + 1

    print("Run evaluation for", model_type, "epoch", epoch)
    exp_root = os.path.join(EXPERIMENT_ROOT, f"{model_type}-epoch-{epoch}")
    res = evaluate_checkpoint_for_datasets(
        None, None, exp_root, EVAL_DATASETS,
        run_default_evaluation=True, run_amg=False,
        is_custom_model=True, predictor=predictor,
    )

    # Tag every row with the epoch so results from different checkpoints
    # can be concatenated later.
    res.insert(0, "epoch", [epoch] * res.shape[0])
    return res
4439

40+
def evaluate_training_evolution(model_type, checkpoints):
    """Evaluate all checkpoints sequentially and save the combined results.

    Runs the per-checkpoint evaluation for every checkpoint in order and
    writes the concatenated table to ``EXPERIMENT_ROOT/<model_type>.csv``.
    """
    all_results = [
        evaluate_checkpoint_slurm(model_type, idx, checkpoints)
        for idx in range(len(checkpoints))
    ]
    combined = pd.concat(all_results)
    save_path = os.path.join(EXPERIMENT_ROOT, f"{model_type}.csv")
    combined.to_csv(save_path, index=False)
4848

4949

50+
def submit_array_job(model_type, checkpoints):
    """Submit a slurm array job with one task per checkpoint.

    Each array task re-invokes this script (via the sbatch file) with
    SLURM_ARRAY_TASK_ID selecting the checkpoint to evaluate.
    """
    last = len(checkpoints) - 1
    run([
        "sbatch", "-a", f"0-{last}",
        "evaluate_training_evolution.sbatch", model_type,
    ])
54+
55+
5056
def main():
    """Dispatch entry point for the training-evolution evaluation.

    Modes:
    - ``-e/--evaluate``: evaluate all checkpoints sequentially in this
      process and write the combined CSV.
    - default, outside slurm (no SLURM_ARRAY_TASK_ID): submit a slurm
      array job covering all checkpoints.
    - default, inside a slurm array task: evaluate only the checkpoint
      selected by SLURM_ARRAY_TASK_ID.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model_type")
    parser.add_argument("-e", "--evaluate", action="store_true")
    args = parser.parse_args()

    checkpoints = sorted(glob(os.path.join(CHECKPOINT_ROOT, args.model_type, "epoch-*.pt")))
    # Fail loudly on a wrong model type / path. (A bare assert would be
    # silently stripped when running under `python -O`.)
    if not checkpoints:
        raise RuntimeError(
            f"No checkpoints found for model type '{args.model_type}' under {CHECKPOINT_ROOT}."
        )

    if args.evaluate:
        evaluate_training_evolution(args.model_type, checkpoints)
        return

    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
    if job_id is None:  # this is the main script that submits slurm jobs
        submit_array_job(args.model_type, checkpoints)
    else:  # we're in a slurm job
        evaluate_checkpoint_slurm(args.model_type, int(job_id), checkpoints)
5575

5676

5777
if __name__ == "__main__":

finetuning/generalists/evaluate_training_evolution.sbatch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#! /bin/bash
22
#SBATCH -c 4
33
#SBATCH --mem 96G
4-
#SBATCH -t 2880
4+
#SBATCH -t 240
55
#SBATCH -p grete:shared
66
#SBATCH -G A100:1
77
#SBATCH -A nim00007

0 commit comments

Comments
 (0)