Merged
Changes from 7 commits
Commits
37 commits
5543a7c
changes and files for peft
polinabinder1 Mar 12, 2025
ce5e138
trainer changes
polinabinder1 Mar 14, 2025
a9549e3
scripts for fine-tuning and inference with PEFT
polinabinder1 Mar 19, 2025
18999ba
test cases for esm2
polinabinder1 Mar 19, 2025
a3e4011
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Mar 19, 2025
fd2d7c3
correct pre-commit
polinabinder1 Mar 19, 2025
bb60ef2
reverse nemo changes
polinabinder1 Mar 19, 2025
7f294be
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Mar 31, 2025
68e62fc
seed as an argument
polinabinder1 Mar 31, 2025
903035b
Merge remote-tracking branch 'origin/main' into pbinder/auto_resume
polinabinder1 Apr 1, 2025
01b02ef
test file changes
polinabinder1 Apr 2, 2025
034eaf7
resumption test case
polinabinder1 Apr 3, 2025
eb9db7b
debugging inference
polinabinder1 Apr 4, 2025
7c2b2b4
experimetning with inference auto resume
polinabinder1 Apr 7, 2025
83a6a7f
inference working
polinabinder1 Apr 8, 2025
4de7922
not running distributed
polinabinder1 Apr 9, 2025
458d0b5
fixing test cases
polinabinder1 Apr 10, 2025
fc3a5a2
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 11, 2025
be2bc85
add + refactor test cases
polinabinder1 Apr 11, 2025
15273b5
proper file handling for test cases
polinabinder1 Apr 16, 2025
559e4fb
fixing some inference pipelines
polinabinder1 Apr 17, 2025
b46605a
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 17, 2025
7d5d3fa
Delete sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_fine…
polinabinder1 Apr 17, 2025
cf3d5a5
removing test files
polinabinder1 Apr 17, 2025
ceb9bff
addressing PR comments
polinabinder1 Apr 21, 2025
a20d6a6
fixing the imports
polinabinder1 Apr 22, 2025
593a2a0
Update conftest.py
polinabinder1 Apr 22, 2025
58c440e
adding correct NGC path
polinabinder1 Apr 22, 2025
15aa844
Update conftest.py
polinabinder1 Apr 23, 2025
48c91e9
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 23, 2025
b9db5a2
adding correct ngc location formatting
polinabinder1 Apr 23, 2025
68bda48
updating inference notebook
polinabinder1 Apr 24, 2025
0b82299
correct ngc info
polinabinder1 Apr 24, 2025
c6bc517
removing a test case that does not run well with megatron environemen…
polinabinder1 Apr 24, 2025
6f9712a
adding correct notebooks
polinabinder1 Apr 25, 2025
8708714
adding notebook
polinabinder1 Apr 25, 2025
4fe7621
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 25, 2025
@@ -28,6 +28,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nemo.collections.llm import fn
from nemo.collections.llm.fn.mixin import FNMixin
from nemo.collections.llm.peft.lora import LoRA
@@ -37,13 +39,33 @@
class ESM2LoRA(LoRA):
"""LoRA for the BioNeMo2 ESM Model."""

def __init__(self, peft_ckpt_path: Optional[str] = None, *args, **kwarg):
"""Initialize the LoRA Adapter.

Args:
peft_ckpt_path: path to the PEFT checkpoint.
*args: args for the LoRA class.
**kwarg: kwargs for the LoRA class.
"""
super().__init__(*args, **kwarg)
self.peft_ckpt_path = peft_ckpt_path

def setup(self, *args, **kwarg):
"""Initialize the LoRA Adapter. Pass the peft_ckpt_path to the wrapped io.

Args:
*args: args for the LoRA class.
**kwarg: kwargs for the LoRA class.
"""
super().setup(*args, **kwarg)
self.wrapped_io.peft_ckpt_path = self.peft_ckpt_path

def __call__(self, model: nn.Module) -> nn.Module:
"""This method is called when the object is called as a function.

Args:
model: The input model.

Returns:
The modified model.
"""
fn.walk(model, self.selective_freeze)
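For orientation, a minimal usage sketch of this adapter, mirroring the finetune_esm2.py changes below; the checkpoint path is a placeholder, not a real artifact:

from bionemo.esm2.model.finetune.peft import ESM2LoRA

# Build the adapter; peft_ckpt_path is optional and only needed when resuming
# previously saved LoRA weights.
peft = ESM2LoRA(peft_ckpt_path="/results/lora/checkpoint-50-100")

# The same object is used twice in the fine-tuning script: registered as a trainer
# callback (so setup() can forward peft_ckpt_path to the wrapped io) and passed as
# model_transform so __call__ can walk the model and apply selective_freeze.
callbacks = [peft]
# module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft)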
179 changes: 109 additions & 70 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py
@@ -19,6 +19,7 @@
from typing import Dict, List, Optional, Sequence, Tuple, Type, get_args

from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
@@ -35,13 +36,15 @@
InMemoryProteinDataset,
InMemorySingleValueDataset,
)
from bionemo.esm2.model.finetune.peft import ESM2LoRA
from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BioBertConfig
from bionemo.llm.model.config import TorchmetricsConfig
from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger
from bionemo.testing import megatron_parallel_state_utils


__all__: Sequence[str] = ("finetune_esm2_entrypoint", "get_parser", "train_model")
@@ -118,6 +121,8 @@ def train_model(
grad_reduce_in_fp32: bool = False,
ckpt_async_save: bool = True,
label_column: str = "labels",
lora_checkpoint_path: Optional[str] = None,
lora_finetune: bool = False,
) -> Tuple[Path, Callback | None, nl.Trainer]:
"""Train an ESM2 model on UR data.

@@ -180,6 +185,8 @@
grad_reduce_in_fp32 (bool): gradient reduction in fp32
ckpt_async_save (bool): whether to save ckpt async. Set to False for federated learning
label_column (str): name of label column in CSV data file. Defaults to `labels`.
lora_checkpoint_path (Optional[str]): path to the LoRA checkpoint file.
lora_finetune (bool): whether to use LoRA fine-tuning.
"""
# Create the result directory if it does not exist.
result_dir.mkdir(parents=True, exist_ok=True)
@@ -197,19 +204,20 @@
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
find_unused_parameters=True,
gradient_as_bucket_view=True,
ckpt_include_optimizer=True,
ckpt_async_save=ckpt_async_save,
ckpt_parallel_load=True,
ckpt_load_strictness=StrictHandling.LOG_UNEXPECTED,
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
overlap_grad_reduce=overlap_grad_reduce,
overlap_param_gather=overlap_param_gather,
average_in_collective=average_in_collective,
grad_reduce_in_fp32=grad_reduce_in_fp32,
use_distributed_optimizer=True,
use_distributed_optimizer=False,
),
find_unused_parameters=True,
gradient_as_bucket_view=True,
ckpt_include_optimizer=True,
ckpt_async_save=ckpt_async_save,
ckpt_parallel_load=True,
)

# for wandb integration
@@ -244,6 +252,9 @@ def train_model(
start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
)
)
if lora_finetune:
peft = ESM2LoRA(peft_ckpt_path=lora_checkpoint_path)
callbacks.append(peft)

trainer = nl.Trainer(
devices=devices,
@@ -263,7 +274,6 @@
autocast_enabled=False,
),
)

tokenizer = get_tokenizer()

# Initialize the data module.
@@ -342,7 +352,12 @@ def train_model(
optimizer.scale_lr_cond = lambda name, param: scale_lr_layer in name
optimizer.lr_mult = lr_multiplier

module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer)
if lora_finetune:
module = biobert_lightning_module(
config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft
)
else:
module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer)

# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
@@ -352,6 +367,8 @@
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="checkpoint-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this.
save_weights_only=False,
save_optim_on_train_end=True,
)

# Setup the logger and train the model
@@ -374,6 +391,7 @@
),
)
ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))

return ckpt_path, metric_tracker, trainer


@@ -386,66 +404,71 @@ def finetune_esm2_entrypoint():
# to avoid padding for single value labels:
if args.min_seq_length is not None and args.dataset_class is InMemorySingleValueDataset:
parser.error("Argument --min-seq-length cannot be set when using InMemorySingleValueDataset.")

# 2. Call pretrain with args
train_model(
train_data_path=args.train_data_path,
valid_data_path=args.valid_data_path,
num_nodes=args.num_nodes,
devices=args.num_gpus,
min_seq_length=args.min_seq_length,
max_seq_length=args.max_seq_length,
result_dir=args.result_dir,
wandb_entity=args.wandb_entity,
wandb_project=args.wandb_project,
wandb_tags=args.wandb_tags,
wandb_group=args.wandb_group,
wandb_id=args.wandb_id,
wandb_anonymous=args.wandb_anonymous,
wandb_log_model=args.wandb_log_model,
wandb_offline=args.wandb_offline,
num_steps=args.num_steps,
limit_val_batches=args.limit_val_batches,
val_check_interval=args.val_check_interval,
log_every_n_steps=args.log_every_n_steps,
num_dataset_workers=args.num_dataset_workers,
lr=args.lr,
micro_batch_size=args.micro_batch_size,
pipeline_model_parallel_size=args.pipeline_model_parallel_size,
tensor_model_parallel_size=args.tensor_model_parallel_size,
accumulate_grad_batches=args.accumulate_grad_batches,
precision=args.precision,
task_type=args.task_type,
encoder_frozen=args.encoder_frozen,
scale_lr_layer=args.scale_lr_layer,
lr_multiplier=args.lr_multiplier,
# single value classification / regression mlp
mlp_ft_dropout=args.mlp_ft_dropout,
mlp_hidden_size=args.mlp_hidden_size,
mlp_target_size=args.mlp_target_size,
# token-level classification cnn
cnn_dropout=args.cnn_dropout,
cnn_hidden_size=args.cnn_hidden_size,
cnn_num_classes=args.cnn_num_classes,
experiment_name=args.experiment_name,
resume_if_exists=args.resume_if_exists,
restore_from_checkpoint_path=args.restore_from_checkpoint_path,
save_last_checkpoint=args.save_last_checkpoint,
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
save_top_k=args.save_top_k,
nsys_profiling=args.nsys_profiling,
nsys_start_step=args.nsys_start_step,
nsys_end_step=args.nsys_end_step,
nsys_ranks=args.nsys_ranks,
dataset_class=args.dataset_class,
config_class=args.config_class,
overlap_grad_reduce=args.overlap_grad_reduce,
overlap_param_gather=not args.no_overlap_param_gather,
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
ckpt_async_save=not args.avoid_ckpt_async_save,
label_column=args.label_column,
)
if args.lora_checkpoint_path and not args.lora_finetune:
parser.error("Arguments --lora=checkpoint-path cannot be set when not using lora-finetune.")

with megatron_parallel_state_utils.distributed_model_parallel_state(43):
# 2. Call train_model with args
train_model(
train_data_path=args.train_data_path,
valid_data_path=args.valid_data_path,
num_nodes=args.num_nodes,
devices=args.num_gpus,
min_seq_length=args.min_seq_length,
max_seq_length=args.max_seq_length,
result_dir=args.result_dir,
wandb_entity=args.wandb_entity,
wandb_project=args.wandb_project,
wandb_tags=args.wandb_tags,
wandb_group=args.wandb_group,
wandb_id=args.wandb_id,
wandb_anonymous=args.wandb_anonymous,
wandb_log_model=args.wandb_log_model,
wandb_offline=args.wandb_offline,
num_steps=args.num_steps,
limit_val_batches=args.limit_val_batches,
val_check_interval=args.val_check_interval,
log_every_n_steps=args.log_every_n_steps,
num_dataset_workers=args.num_dataset_workers,
lr=args.lr,
micro_batch_size=args.micro_batch_size,
pipeline_model_parallel_size=args.pipeline_model_parallel_size,
tensor_model_parallel_size=args.tensor_model_parallel_size,
accumulate_grad_batches=args.accumulate_grad_batches,
precision=args.precision,
task_type=args.task_type,
encoder_frozen=args.encoder_frozen,
scale_lr_layer=args.scale_lr_layer,
lr_multiplier=args.lr_multiplier,
# single value classification / regression mlp
mlp_ft_dropout=args.mlp_ft_dropout,
mlp_hidden_size=args.mlp_hidden_size,
mlp_target_size=args.mlp_target_size,
# token-level classification cnn
cnn_dropout=args.cnn_dropout,
cnn_hidden_size=args.cnn_hidden_size,
cnn_num_classes=args.cnn_num_classes,
experiment_name=args.experiment_name,
resume_if_exists=args.resume_if_exists,
restore_from_checkpoint_path=args.restore_from_checkpoint_path,
save_last_checkpoint=args.save_last_checkpoint,
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
save_top_k=args.save_top_k,
nsys_profiling=args.nsys_profiling,
nsys_start_step=args.nsys_start_step,
nsys_end_step=args.nsys_end_step,
nsys_ranks=args.nsys_ranks,
dataset_class=args.dataset_class,
config_class=args.config_class,
overlap_grad_reduce=args.overlap_grad_reduce,
overlap_param_gather=not args.no_overlap_param_gather,
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
ckpt_async_save=not args.avoid_ckpt_async_save,
label_column=args.label_column,
lora_checkpoint_path=args.lora_checkpoint_path,
lora_finetune=args.lora_finetune,
)


def get_parser():
@@ -604,7 +627,7 @@ def get_parser():
"--num-steps",
type=int,
required=False,
default=500000,
default=5,
help="Number of steps to use for training. Default is 500000.",
)
parser.add_argument(
Expand All @@ -618,7 +641,7 @@ def get_parser():
"--val-check-interval",
type=int,
required=False,
default=10000,
default=5,
help="Number of steps between validation. Default is 10000.",
)
parser.add_argument(
@@ -703,6 +726,22 @@ def get_parser():
default=None,
help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
)

parser.add_argument(
"--lora-finetune",
action="store_true",
default=True,
help="Perform fine-tuning with LoRA.",
)

parser.add_argument(
"--lora-checkpoint-path",
type=Path,
required=False,
default=None,
help="Path to the lora states to restore from.",
)

parser.add_argument(
"--nsys-profiling",
action="store_true",
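For reference, an illustrative direct call of train_model with the new LoRA options, following the updated entrypoint; the data paths, result directory, and step count are placeholders, and the remaining parameters are assumed to keep their defaults:

from pathlib import Path

from bionemo.esm2.scripts.finetune_esm2 import train_model
from bionemo.testing import megatron_parallel_state_utils

# Mirror the entrypoint: initialize megatron parallel state (seed 43, as in the diff) before training.
with megatron_parallel_state_utils.distributed_model_parallel_state(43):
    ckpt_path, metric_tracker, trainer = train_model(
        train_data_path=Path("train.csv"),   # placeholder CSV with sequences and labels
        valid_data_path=Path("valid.csv"),   # placeholder
        result_dir=Path("./results"),        # placeholder
        num_steps=50,
        lora_finetune=True,                  # enables the ESM2LoRA callback and model_transform
        lora_checkpoint_path=None,           # or a path to previously saved LoRA states
    )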