Commit 1a1edf0

nvmvle and jwilber authored
Add ESM2 Finetuning Benchmark Configuration (#964)
### Description

This PR adds comprehensive benchmark configurations for ESM2 finetuning to support performance testing and validation. The changes introduce two new benchmark configurations (partial-conv and perf) along with enhanced finetuning capabilities, including checkpointing control, TensorBoard logging, and a TFLOPS measurement callback.

Key enhancements include:

- Added ESM2 finetuning YAML configurations for partial-conv and performance benchmarks
- Implemented checkpointing control with a `--disable-checkpointing` option for faster benchmark runs
- Added TensorBoard logging support for training-metrics visualization
- Introduced a TFLOPS callback option to measure and log computational performance
- Enhanced training control parameters, including max_steps, early stopping, and batch-size configuration

### Type of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels:

- [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests
- [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest
- [INCLUDE_SLOW_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_slow_tests) - Execute tests labelled as slow in pytest for extensive testing

> [!NOTE]
> By default, the notebook validation tests are skipped unless explicitly enabled.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a `pull-request/`-prefixed branch in the source repository (e.g. `pull-request/123`).
- If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This must be done for each new commit.

### Pre-submit Checklist

- [x] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [x] I have added/updated tests as needed
- [ ] All existing tests pass successfully

Signed-off-by: My Le <mvle@nvidia.com>

---------

Co-authored-by: Jared Wilber <jwilber@nvidia.com>
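The `--disable-checkpointing` option mentioned above follows the standard argparse flag-inversion pattern: the user-facing flag is negative, but the stored value is a positively named boolean that defaults to on. A minimal, self-contained sketch of that pattern (not the actual `finetune_esm2.py` parser):

```python
import argparse

# Flag-inversion pattern: passing --disable-checkpointing stores False into
# the positively named destination "create_checkpoint_callback"; omitting
# the flag leaves the default of True.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-checkpointing",
    action="store_false",  # the flag's presence stores False
    default=True,          # absence keeps checkpointing enabled
    dest="create_checkpoint_callback",
    help="Disable creating a ModelCheckpoint callback.",
)

default_args = parser.parse_args([])
disabled_args = parser.parse_args(["--disable-checkpointing"])
print(default_args.create_checkpoint_callback)   # True
print(disabled_args.create_checkpoint_callback)  # False
```

This keeps downstream code free of double negatives: callers test `if args.create_checkpoint_callback:` rather than `if not args.disable_checkpointing:`.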
1 parent 612ea21 commit 1a1edf0

File tree

4 files changed

+654
-154
lines changed
Lines changed: 76 additions & 0 deletions
```yaml
scope: partial-conv
time_limit: 14400
key_segments:
  # Modify keys to be renamed (str) or excluded (False) from run identifier. By default, all args under script_args are included.
  train_data_path: False
  valid_data_path: False
  data_base_path: False
  num_workers: False
  limit_val_batches: False
  limit_test_batches: False
  val_check_interval: False
  dataset_class: False
  task_type: False
  config_class: False
  experiment_name: False
  workspace: False
  restore_from_checkpoint_path: False
script_args:
  # All arguments referenced in the script string must be specified here.
  # Arguments not referenced in the script string must have the 'arg' field specified.
  # See jet/core/configs.py for the specification of the configuration class
  workspace: /workspace/bionemo2
  data_base_path: /data/FLIP
  restore_from_checkpoint_path: /data/esm2_650M_nemo2
  nodes: [1]
  gpus: 8
  model: esm2
  variant: finetune
  config_name: 650M
  precision: [bf16-mixed]
  num_workers: 8
  limit_val_batches: 100  # original 1000, 100 is enough for validation and produces good enough curves
  limit_test_batches: 100
  task: seq_classification
  train_data_path: scl/train/x000.csv
  valid_data_path: scl/val/x000.csv
  task_type: classification
  config_class: ESM2FineTuneSeqConfig
  dataset_class: InMemorySingleValueDataset
  max_steps: 30000
  stop_steps: 3000
  experiment_name: seq-level-classification
  batch_size: 64
  val_check_interval: 100
script: |-
  WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
    --train-data-path=${data_base_path}/${train_data_path} \
    --valid-data-path=${data_base_path}/${valid_data_path} \
    --restore-from-checkpoint-path=${restore_from_checkpoint_path} \
    --task-type=${task_type} \
    --config-class=${config_class} \
    --dataset-class=${dataset_class} \
    --num-steps=${max_steps} \
    --experiment-name=${experiment_name}_${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec \
    --lr=0.0005 \
    --result-dir=${tensorboard_dir} \
    --micro-batch-size=${batch_size} \
    --limit-val-batches=${limit_val_batches} \
    --limit-test-batches=${limit_test_batches} \
    --precision=${precision} \
    --label-column=scl_label \
    --num-gpus=${gpus} \
    --num-nodes=${nodes} \
    --accumulate-grad-batches=2 \
    --val-check-interval=${val_check_interval} \
    --num-dataset-workers=${num_workers} \
    --wandb-project=${wandb_project_name} \
    --wandb-group=${model}_${variant}_${config_name}_${task}_${target} \
    --create-tensorboard-logger \
    --encoder-frozen \
    --mlp-ft-dropout=0.25 \
    --mlp-hidden-size=256 \
    --mlp-target-size=10 \
    --disable-checkpointing \
    --early-stop-on-step=${stop_steps} \
    --create-tflops-callback;
```
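The `script` block is a shell template whose `${...}` placeholders are filled in from `script_args` by the benchmark harness (the actual mechanics live in jet/core/configs.py, which is not part of this diff). As a rough illustration only, assuming simple `${var}` expansion semantics:

```python
from string import Template

# Hypothetical stand-in for the harness's substitution step; the real
# implementation in jet/core/configs.py is not shown in this commit.
script_args = {"variant": "finetune", "model": "esm2", "batch_size": 64, "max_steps": 30000}

# A shortened version of the script template above.
command = Template("${variant}_${model} --micro-batch-size=${batch_size} --num-steps=${max_steps}")
rendered = command.substitute(script_args)
print(rendered)  # finetune_esm2 --micro-batch-size=64 --num-steps=30000
```

Note that `${variant}_${model}` expands to the entrypoint name (`finetune_esm2`), so the same template serves multiple model/variant combinations.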
Lines changed: 93 additions & 0 deletions
```yaml
scope: perf
time_limit: 3600
key_segments:
  # Modify keys to be renamed (str) or excluded (False) from run identifier. By default, all args under script_args are included.
  train_data_path: False
  valid_data_path: False
  data_base_path: False
  limit_val_batches: False
  limit_test_batches: False
  val_check_interval: False
  dataset_class: False
  task_type: False
  config_class: False
  num_workers: False
  experiment_name: False
  workspace: False
  restore_from_checkpoint_path: False
script_args:
  # All arguments referenced in the script string must be specified here.
  # Arguments not referenced in the script string must have the 'arg' field specified.
  # See jet/core/configs.py for the specification of the configuration class
  workspace: /workspace/bionemo2
  data_base_path: /data/FLIP
  restore_from_checkpoint_path: /data/esm2_650M_nemo2
  gpus: 8
  model: esm2
  variant: finetune
  config_name: 650M
  precision: [bf16-mixed]
  num_workers: 8
  limit_val_batches: 1
  limit_test_batches: 1
  task: seq_classification
  train_data_path: scl/train/x000.csv
  valid_data_path: scl/val/x000.csv
  task_type: classification
  config_class: ESM2FineTuneSeqConfig
  dataset_class: InMemorySingleValueDataset
  max_steps: 30000
  stop_steps: 300
  experiment_name: seq-level-classification
  val_check_interval: 100
products:
  - nodes: 1
    batch_size: 16
    pp: 1
    tp: 1
  - nodes: 1
    batch_size: 64
    pp: 1
    tp: 1
  - nodes: 2
    batch_size: 16
    pp: 1
    tp: 1
  - nodes: 2
    batch_size: 64
    pp: 1
    tp: 1
script: |-
  WANDB_API_KEY=$BIONEMO_WANDB_API_KEY ${variant}_${model} \
    --train-data-path=${data_base_path}/${train_data_path} \
    --valid-data-path=${data_base_path}/${valid_data_path} \
    --restore-from-checkpoint-path=${restore_from_checkpoint_path} \
    --task-type=${task_type} \
    --config-class=${config_class} \
    --dataset-class=${dataset_class} \
    --num-steps=${max_steps} \
    --experiment-name=${experiment_name}_${batch_size}bs_${nodes}node_${gpus}gpu_${max_steps}s_${precision}prec_tp${tp}_pp_${pp} \
    --lr=0.0005 \
    --result-dir=${tensorboard_dir} \
    --micro-batch-size=${batch_size} \
    --limit-val-batches=${limit_val_batches} \
    --limit-test-batches=${limit_test_batches} \
    --precision=${precision} \
    --label-column=scl_label \
    --num-gpus=${gpus} \
    --num-nodes=${nodes} \
    --accumulate-grad-batches=1 \
    --val-check-interval=${val_check_interval} \
    --num-dataset-workers=${num_workers} \
    --wandb-project=${wandb_project_name} \
    --wandb-group=${model}_${variant}_${config_name}_${task}_${target} \
    --create-tensorboard-logger \
    --encoder-frozen \
    --mlp-ft-dropout=0.25 \
    --mlp-hidden-size=256 \
    --mlp-target-size=10 \
    --disable-checkpointing \
    --pipeline-model-parallel-size=${pp} \
    --tensor-model-parallel-size=${tp} \
    --early-stop-on-step=${stop_steps} \
    --create-tflops-callback;
```
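Each `products` entry defines one run in the performance sweep. Under the usual Megatron accounting (an assumption here, not something this config states explicitly), the data-parallel size is `nodes * gpus / (tp * pp)` and the effective global batch size is the micro-batch size times the data-parallel size times gradient accumulation (1 in this config). A small sketch of that arithmetic for the four entries above:

```python
# Sketch of the effective global batch size per sweep entry, assuming the
# standard Megatron formula: global = micro * (nodes*gpus // (tp*pp)) * grad_accum.
products = [
    {"nodes": 1, "batch_size": 16, "pp": 1, "tp": 1},
    {"nodes": 1, "batch_size": 64, "pp": 1, "tp": 1},
    {"nodes": 2, "batch_size": 16, "pp": 1, "tp": 1},
    {"nodes": 2, "batch_size": 64, "pp": 1, "tp": 1},
]
GPUS_PER_NODE = 8  # matches `gpus: 8` in the config above
GRAD_ACCUM = 1     # matches --accumulate-grad-batches=1

def global_batch_size(entry: dict) -> int:
    data_parallel = entry["nodes"] * GPUS_PER_NODE // (entry["tp"] * entry["pp"])
    return entry["batch_size"] * data_parallel * GRAD_ACCUM

sizes = [global_batch_size(e) for e in products]
print(sizes)  # [128, 512, 256, 1024]
```

With `tp` and `pp` both 1, every GPU is a data-parallel replica, so doubling the node count doubles the global batch size at a fixed micro-batch size.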

sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py

Lines changed: 68 additions & 24 deletions
```diff
@@ -26,7 +26,9 @@
 from nemo.collections import llm
 from nemo.lightning import resume
 from nemo.lightning.pytorch import callbacks as nl_callbacks
+from nemo.lightning.pytorch.callbacks.flops_callback import FLOPsMeasurementCallback
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule
+from nemo.utils.exp_manager import TimingCallback

 from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
 from bionemo.esm2.data.tokenizer import get_tokenizer
@@ -127,7 +129,7 @@ def get_parser():

     # Checkpoint parameters
     parser.add_argument("--create-tensorboard-logger", action="store_true", help="Create tensorboard logger")
-    parser.add_argument("--restore-from-checkpoint-path", type=Path, default=None, help="Restore from checkpoint")
+    parser.add_argument("--restore-from-checkpoint-path", type=Path, required=True, help="Restore from checkpoint")
     parser.add_argument("--save-last-checkpoint", action="store_true", default=True, help="Save last checkpoint")
     parser.add_argument(
         "--metric-to-monitor-for-checkpoints", type=str, default="val_loss", help="Metric to monitor for checkpoints"
@@ -166,6 +168,28 @@ def get_parser():
     parser.add_argument("--lora-checkpoint-path", type=Path, default=None, help="LoRA checkpoint path")
     parser.add_argument("--lora-finetune", action="store_true", help="Use LoRA fine-tuning")

+    parser.add_argument(
+        "--disable-checkpointing",
+        action="store_false",
+        default=True,
+        dest="create_checkpoint_callback",
+        help="Disable creating a ModelCheckpoint callback.",
+    )
+
+    parser.add_argument(
+        "--early-stop-on-step",
+        type=int,
+        default=None,
+        help="Stop training on this step, if set. This may be useful for testing or debugging purposes.",
+    )
+
+    parser.add_argument(
+        "--create-tflops-callback",
+        action="store_true",
+        default=False,
+        help="Enable tflops calculation callback. Default is False.",
+    )
+
     return parser


@@ -233,7 +257,10 @@ def train_model(
     labels_mask_column: Optional[str] = None,
     lora_checkpoint_path: Optional[Path] = None,
     lora_finetune: bool = False,
-) -> Tuple[Path, Callback | None, nl.Trainer]:
+    create_checkpoint_callback: bool = True,
+    early_stop_on_step: Optional[int] = None,
+    create_tflops_callback: bool = False,
+) -> Tuple[Optional[Path], Callback | None, nl.Trainer]:
     config_class = SUPPORTED_CONFIGS[config_class]
     dataset_class = SUPPORTED_DATASETS[dataset_class]

@@ -298,6 +325,7 @@ def train_model(
         RichModelSummary(max_depth=4),
         LearningRateMonitor(),
         nl_callbacks.PreemptionCallback(),
+        TimingCallback(),
     ]
     if metric_tracker is not None:
         callbacks.append(metric_tracker)
@@ -411,24 +439,44 @@ def train_model(
         initialize_tensorboard_logger=create_tensorboard_logger,
         wandb_config=wandb_config,
     )
-    # Configure our custom Checkpointer
-    checkpoint_path = str(Path(nemo_logger.save_dir) / "checkpoints")
-    checkpoint_callback = nl_callbacks.ModelCheckpoint(
-        dirpath=checkpoint_path,
-        save_last=save_last_checkpoint,
-        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
-        save_top_k=save_top_k,
-        every_n_train_steps=val_check_interval,
-        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
-        filename="checkpoint-{step}-{consumed_samples}",  # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
-        save_weights_only=False,
-        save_optim_on_train_end=True,
-    )
-    callbacks.append(checkpoint_callback)
+
+    if create_checkpoint_callback:
+        # Configure our custom Checkpointer
+        checkpoint_path = str(Path(nemo_logger.save_dir) / "checkpoints")
+        checkpoint_callback = nl_callbacks.ModelCheckpoint(
+            dirpath=checkpoint_path,
+            save_last=save_last_checkpoint,
+            monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
+            save_top_k=save_top_k,
+            every_n_train_steps=val_check_interval,
+            always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+            filename="checkpoint-{step}-{consumed_samples}",  # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
+            save_weights_only=False,
+            save_optim_on_train_end=True,
+        )
+        callbacks.append(checkpoint_callback)
+        auto_resume = resume.AutoResume(
+            resume_from_directory=checkpoint_path,
+            resume_if_exists=resume_if_exists,  # Looks for the -last checkpoint to continue training.
+            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+            resume_past_end=False,
+        )
+    else:
+        auto_resume = None
+
+    if create_tflops_callback:
+        # Add callback that logs the tera-FLOPS per second per GPU during training.
+        data_module.global_batch_size = global_batch_size
+        flop_meas_callback = FLOPsMeasurementCallback(
+            config,
+            data_module,
+            "bert",
+        )
+        callbacks.append(flop_meas_callback)

     trainer = nl.Trainer(
         devices=num_gpus,
-        max_steps=num_steps,
+        max_steps=num_steps if early_stop_on_step is None else early_stop_on_step,
         max_epochs=max_epochs,
         accelerator="gpu",
         strategy=strategy,
@@ -445,21 +493,17 @@ def train_model(
             grad_reduce_in_fp32=grad_reduce_in_fp32,
             autocast_enabled=False,
         ),
-        enable_checkpointing=True,
+        enable_checkpointing=create_checkpoint_callback,
     )
     llm.train(
         model=module,
         data=data_module,
         trainer=trainer,
         log=nemo_logger,
-        resume=resume.AutoResume(
-            resume_from_directory=checkpoint_path,
-            resume_if_exists=resume_if_exists,  # Looks for the -last checkpoint to continue training.
-            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
-        ),
+        resume=auto_resume,
     )

-    ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
+    ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", "")) if create_checkpoint_callback else None
     return ckpt_path, metric_tracker, trainer
```