 import argparse
 import os
+
 from glob import glob
+from subprocess import run
 
 import pandas as pd
 from util import evaluate_checkpoint_for_datasets, get_generalist_predictor
 
-CHECKPOINT_ROOT = "/scratch-grete/projects/nim00007/sam/LM/generalist"
-EXPERIMENT_ROOT = "/scratch-grete/projects/nim00007/sam/experiments/generalists/lm"
+CHECKPOINT_ROOT = "/scratch/projects/nim00007/sam/models/LM/generalist/v2"
+EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/training-evolution"
 # We evaluate these three datasets for the training evolution.
 # These are chosen based on observations from preliminary experiments.
 # - covid-if: out-of-domain dataset that shows the expected improvement (over vanilla).
 # - deepbacs: in-domain dataset where we see the biggest gap to the specialist.
-# - plantseg-root: out-of-domain dataset that doesn't show an improvement.
-DATASETS = ("covid-if", "deepbacs", "plantseg-root")
-
+# - lizard: out-of-domain dataset that is furthest from the training data.
+EVAL_DATASETS = ("covid-if", "deepbacs", "lizard")
 
-def evaluate_training_evolution(model_type):
-    checkpoints = sorted(glob(
-        os.path.join(CHECKPOINT_ROOT, model_type, "*.pt")
-    ))
-    assert len(checkpoints) > 0
 
-    epochs, results = [], []
-    for checkpoint in checkpoints:
+def evaluate_checkpoint_slurm(model_type, job_id, checkpoints):
+    # Each slurm array task evaluates the single checkpoint selected by its job id.
+    checkpoint = checkpoints[job_id]
 
-        predictor, state = get_generalist_predictor(checkpoint, model_type, return_state=True)
-        epoch = state["epoch"] + 1
+    predictor, state = get_generalist_predictor(
+        checkpoint, model_type, is_custom_model=True, return_state=True
+    )
+    epoch = state["epoch"] + 1
 
-        if epoch in epochs:
-            continue
+    print("Run evaluation for", model_type, "epoch", epoch)
+    experiment_root = os.path.join(EXPERIMENT_ROOT, f"{model_type}-epoch-{epoch}")
+    result = evaluate_checkpoint_for_datasets(
+        None, None, experiment_root, EVAL_DATASETS,
+        run_default_evaluation=True, run_amg=False,
+        is_custom_model=True, predictor=predictor,
+    )
 
-        print("Run evaluation for", model_type, "epoch", epoch)
-        experiment_root = os.path.join(EXPERIMENT_ROOT, f"{model_type}-epoch-{epoch}")
-        result = evaluate_checkpoint_for_datasets(
-            None, None, experiment_root, DATASETS,
-            run_default_evaluation=True, run_amg=False,
-            predictor=predictor,
-        )
-        result.insert(0, "epoch", [epoch] * result.shape[0])
-        results.append(result)
+    result.insert(0, "epoch", [epoch] * result.shape[0])
+    return result
 
-        epochs.append(epoch)
 
+def evaluate_training_evolution(model_type, checkpoints):
+    # Run the evaluation for all checkpoints and combine the results into one table.
+    results = []
+    for i in range(len(checkpoints)):
+        result = evaluate_checkpoint_slurm(model_type, i, checkpoints)
+        results.append(result)
     results = pd.concat(results)
     save_path = os.path.join(EXPERIMENT_ROOT, f"{model_type}.csv")
     results.to_csv(save_path, index=False)
 
 
+def submit_array_job(model_type, checkpoints):
+    # Submit a slurm array job with one task per checkpoint.
+    n_checkpoints = len(checkpoints)
+    cmd = ["sbatch", "-a", f"0-{n_checkpoints-1}", "evaluate_training_evolution.sbatch", model_type]
+    run(cmd)
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("model_type")
+    parser.add_argument("-e", "--evaluate", action="store_true")
     args = parser.parse_args()
-    evaluate_training_evolution(args.model_type)
+
+    checkpoints = sorted(glob(os.path.join(CHECKPOINT_ROOT, args.model_type, "epoch-*.pt")))
+    assert len(checkpoints) > 0
+
+    if args.evaluate:
+        evaluate_training_evolution(args.model_type, checkpoints)
+        return
+
+    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)
+    if job_id is None:  # this is the main script that submits slurm jobs
+        submit_array_job(args.model_type, checkpoints)
+    else:  # we're in a slurm job
+        job_id = int(job_id)
+        evaluate_checkpoint_slurm(args.model_type, job_id, checkpoints)
 
 
 if __name__ == "__main__":
     main()
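Note: the slurm batch file evaluate_training_evolution.sbatch that submit_array_job invokes is not part of this diff. Below is a minimal sketch of what such a script could look like; the partition, resource, and environment settings are placeholders rather than values from this repository, and it assumes the Python file above is saved as evaluate_training_evolution.py.

#!/bin/bash
#SBATCH -p gpu          # GPU partition (placeholder, cluster specific)
#SBATCH -G 1            # one GPU per array task (placeholder)
#SBATCH -c 8            # CPUs per array task (placeholder)
#SBATCH --mem 32G       # memory per array task (placeholder)
#SBATCH -t 06:00:00     # time limit per checkpoint (placeholder)

# Activate the python environment used for the evaluation (placeholder name).
source activate sam

# $1 is the model type forwarded by submit_array_job; the Python script reads
# SLURM_ARRAY_TASK_ID itself to pick the checkpoint this array task evaluates.
python evaluate_training_evolution.py $1

After all array tasks have finished, running the script again with the -e flag (python evaluate_training_evolution.py <model_type> -e) collects the results for all checkpoints and writes them to {model_type}.csv in EXPERIMENT_ROOT.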