Skip to content

Commit 192e537

Browse files
Pbinder/auto resume (#766)
PEFT checkpointing and inference for esm2. --------- Signed-off-by: Polina Binder <pbinder@nvidia.com> Signed-off-by: polinabinder1 <pbinder@nvidia.com>
1 parent 6568d2b commit 192e537

File tree

12 files changed

+965
-113
lines changed

12 files changed

+965
-113
lines changed

docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb

Lines changed: 415 additions & 13 deletions
Large diffs are not rendered by default.

docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb

Lines changed: 145 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -142,35 +142,6 @@
142142
"execution_count": 4,
143143
"metadata": {},
144144
"outputs": [
145-
{
146-
"name": "stderr",
147-
"output_type": "stream",
148-
"text": [
149-
"Downloading data from 'nvidia/clara/esm2nv650m:2.0' to file '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz'.\n"
150-
]
151-
},
152-
{
153-
"name": "stdout",
154-
"output_type": "stream",
155-
"text": [
156-
"{\n",
157-
" \"download_end\": \"2025-01-14 22:01:24\",\n",
158-
" \"download_start\": \"2025-01-14 22:01:05\",\n",
159-
" \"download_time\": \"18s\",\n",
160-
" \"files_downloaded\": 1,\n",
161-
" \"local_path\": \"/home/ubuntu/.cache/bionemo/tmpfj1e52vw/esm2nv650m_v2.0\",\n",
162-
" \"size_downloaded\": \"1.12 GB\",\n",
163-
" \"status\": \"COMPLETED\"\n",
164-
"}\n"
165-
]
166-
},
167-
{
168-
"name": "stderr",
169-
"output_type": "stream",
170-
"text": [
171-
"Untarring contents of '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz' to '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar'\n"
172-
]
173-
},
174145
{
175146
"name": "stdout",
176147
"output_type": "stream",
@@ -267,14 +238,6 @@
267238
"name": "stdout",
268239
"output_type": "stream",
269240
"text": [
270-
"2025-01-14 22:01:45 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
271-
"2025-01-14 22:01:45 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
272-
"[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
273-
" warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n",
274-
" \n",
275-
"[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
276-
" cm = get_cmap(\"Set1\")\n",
277-
" \n",
278241
"usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n",
279242
" --results-path RESULTS_PATH\n",
280243
" [--precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}]\n",
@@ -285,6 +248,7 @@
285248
" [--prediction-interval {epoch,batch}] [--include-hiddens]\n",
286249
" [--include-input-ids] [--include-embeddings]\n",
287250
" [--include-logits] [--config-class CONFIG_CLASS]\n",
251+
" [--lora-checkpoint-path LORA_CHECKPOINT_PATH]\n",
288252
"\n",
289253
"Infer ESM2.\n",
290254
"\n",
@@ -326,7 +290,9 @@
326290
" script should also provide similar support for picking\n",
327291
" different data modules for fine-tuning with different\n",
328292
" data types. Choices: dict_keys(['ESM2Config',\n",
329-
" 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])\n"
293+
" 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])\n",
294+
" --lora-checkpoint-path LORA_CHECKPOINT_PATH\n",
295+
" Path to the lora states to restore from.\n"
330296
]
331297
}
332298
],
@@ -507,6 +473,147 @@
507473
"mask = torch.isin(input_ids, torch.tensor(extra_indices))"
508474
]
509475
},
476+
{
477+
"cell_type": "markdown",
478+
"metadata": {},
479+
"source": [
480+
"## Inference with LoRA\n",
481+
"Inference with LoRA is supported. This requires the original model weights along with the additional LoRA weights. "
482+
]
483+
},
484+
{
485+
"cell_type": "code",
486+
"execution_count": 13,
487+
"metadata": {},
488+
"outputs": [
489+
{
490+
"name": "stdout",
491+
"output_type": "stream",
492+
"text": [
493+
"[INFO | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True\n",
494+
"[INFO | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores\n",
495+
"[INFO | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs\n",
496+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Fixing mis-match between ddp-config & mcore-optimizer config\n",
497+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has data parallel group : [0]\n",
498+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n",
499+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n",
500+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Ranks 0 has data parallel rank: 0\n",
501+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has context parallel group: [0]\n",
502+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All context parallel group ranks: [[0]]\n",
503+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Ranks 0 has context parallel rank: 0\n",
504+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has model parallel group: [0]\n",
505+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All model parallel group ranks: [[0]]\n",
506+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n",
507+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n",
508+
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n",
509+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n",
510+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has embedding group: [0]\n",
511+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n",
512+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n",
513+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] All embedding group ranks: [[0]]\n",
514+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has embedding rank: 0\n",
515+
"[INFO | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------\n",
516+
"distributed_backend=nccl\n",
517+
"All distributed processes registered. Starting with 1 processes\n",
518+
"----------------------------------------------------------------------------------------------------\n",
519+
"\n",
520+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Setting up ModelTransform for stage: TrainerFn.PREDICTING\n",
521+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Found model_transform attribute on pl_module\n",
522+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Set model_transform to: <function _call_counter.<locals>.wrapper at 0x70af0b61c220>\n",
523+
"[WARNING | /workspaces/bionemo-framework/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py]: Loading /home/ubuntu/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar\n",
524+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Padded vocab_size: 128, original vocab_size: 33, dummy tokens: 95.\n",
525+
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
526+
"[NeMo W 2025-04-25 04:47:43 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n",
527+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 7542848\n",
528+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.self_attention.linear_proj\n",
529+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.self_attention.linear_qkv\n",
530+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.mlp.linear_fc1\n",
531+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.mlp.linear_fc2\n",
532+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.self_attention.linear_proj\n",
533+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.self_attention.linear_qkv\n",
534+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.mlp.linear_fc1\n",
535+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.mlp.linear_fc2\n",
536+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.self_attention.linear_proj\n",
537+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.self_attention.linear_qkv\n",
538+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.mlp.linear_fc1\n",
539+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.mlp.linear_fc2\n",
540+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.self_attention.linear_proj\n",
541+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.self_attention.linear_qkv\n",
542+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.mlp.linear_fc1\n",
543+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.mlp.linear_fc2\n",
544+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.self_attention.linear_proj\n",
545+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.self_attention.linear_qkv\n",
546+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.mlp.linear_fc1\n",
547+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.mlp.linear_fc2\n",
548+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.self_attention.linear_proj\n",
549+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.self_attention.linear_qkv\n",
550+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.mlp.linear_fc1\n",
551+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.mlp.linear_fc2\n",
552+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] After applying model_transform:\n",
553+
" \n",
554+
" | Name | Type | Params | Mode\n",
555+
"---------------------------------------\n",
556+
"0 | module | DDP | 8.5 M | eval\n",
557+
"---------------------------------------\n",
558+
"1.1 M Trainable params\n",
559+
"7.4 M Non-trainable params\n",
560+
"8.5 M Total params\n",
561+
"34.104 Total estimated model params size (MB)\n",
562+
"121 Modules in train mode\n",
563+
"133 Modules in eval mode\n",
564+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Loading adapters from /home/ubuntu/.cache/bionemo/ffce714478009c353f20e8c1ad8e49638128f0cc936ebe1d44c161ac89831dcb-finetuned_peft_weights.tar.gz.untar/weights\n",
565+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Initializing model parallel\n",
566+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 8525888\n",
567+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of trainable parameters: 1086528 (12.74% of total)\n",
568+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Using <megatron.core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper object at 0x70af081e3f50> dist-ckpt load strategy.\n",
569+
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1745556463.638s : Time spent in load_checkpoint: 0.029s\n",
570+
"[NeMo W 2025-04-25 04:47:43 nemo_logging:405] MegatronOptimizerModule not found in trainer callbacks. finalize_model_grads is not properly set up for PEFT.\n"
571+
]
572+
}
573+
],
574+
"source": [
575+
"# Load the base model and PEFT weights\n",
576+
"\n",
577+
"checkpoint_path = load(\"esm2/8m:2.0\")\n",
578+
"lora_checkpoint_path = load(\"esm2/esm2_lora_weights:1.1\") / \"weights\"\n",
579+
"\n",
580+
"# Perform inference with LoRA\n",
581+
"! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
582+
" --data-path {data_path} \\\n",
583+
" --results-path {work_dir} \\\n",
584+
" --micro-batch-size 3 \\\n",
585+
" --num-gpus 1 \\\n",
586+
" --include-hiddens \\\n",
587+
" --include-embeddings \\\n",
588+
" --include-logits \\\n",
589+
" --include-input-ids \\\n",
590+
" --lora-checkpoint-path {lora_checkpoint_path}"
591+
]
592+
},
593+
{
594+
"cell_type": "code",
595+
"execution_count": 14,
596+
"metadata": {},
597+
"outputs": [
598+
{
599+
"name": "stdout",
600+
"output_type": "stream",
601+
"text": [
602+
"token_logits\ttorch.Size([1024, 10, 128])\n",
603+
"hidden_states\ttorch.Size([10, 1024, 320])\n",
604+
"input_ids\ttorch.Size([10, 1024])\n",
605+
"embeddings\ttorch.Size([10, 320])\n"
606+
]
607+
}
608+
],
609+
"source": [
610+
"results = torch.load(f\"{work_dir}/predictions__rank_0.pt\")\n",
611+
"\n",
612+
"for key, val in results.items():\n",
613+
" if val is not None:\n",
614+
" print(f\"{key}\\t{val.shape}\")"
615+
]
616+
},
510617
{
511618
"cell_type": "markdown",
512619
"metadata": {},

sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,11 @@
8484
sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
8585
owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
8686
description: Test data for ESM2 inference.
87+
88+
- tag: esm2_lora_weights:1.1
89+
ngc: nvidia/clara/esm2_lora_weights:1.1
90+
ngc_registry: model
91+
pbss: "s3://general-purpose/esm2/checkpoints/finetuned_peft_weights.tar.gz"
92+
sha256: ffce714478009c353f20e8c1ad8e49638128f0cc936ebe1d44c161ac89831dcb # pragma: allowlist secret
93+
owner: Polina Binder <pbinder@nvidia.com>
94+
description: Weights for a LoRA finetuned ESM2 model.

sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/peft.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
# See the License for the specific language governing permissions and
2929
# limitations under the License.
3030

31+
from typing import List, Optional
32+
33+
import lightning.pytorch as pl
3134
from nemo.collections.llm import fn
3235
from nemo.collections.llm.fn.mixin import FNMixin
3336
from nemo.collections.llm.peft.lora import LoRA
@@ -37,13 +40,66 @@
3740
class ESM2LoRA(LoRA):
3841
"""LoRA for the BioNeMo2 ESM Model."""
3942

43+
def __init__(
    self,
    peft_ckpt_path: Optional[str] = None,
    freeze_modules: Optional[List[str]] = None,
    *args,
    **kwarg,
):
    """Initialize the LoRA adapter.

    Args:
        peft_ckpt_path: optional path to a PEFT checkpoint. It is only stored here; `setup` later assigns it to
            `self.wrapped_io.adapter_ckpt_path`.
        freeze_modules: names of the modules to freeze. When omitted (or None) this defaults to
            ["encoder", "embedding"].
        *args: args for the LoRA class.
        **kwarg: kwargs for the LoRA class.
    """
    super().__init__(*args, **kwarg)
    # A list literal in the signature would be created once and shared by every instance that relies on the
    # default, so mutating one instance's `freeze_modules` would leak into the others. Build it per instance.
    self.freeze_modules = ["encoder", "embedding"] if freeze_modules is None else freeze_modules
    self.peft_ckpt_path = peft_ckpt_path
61+
62+
def setup(self, *args, **kwarg):
    """Run the parent setup, then pass the PEFT checkpoint path on to the wrapped io object.

    Args:
        *args: args forwarded unchanged to the parent class `setup`.
        **kwarg: kwargs forwarded unchanged to the parent class `setup`.
    """
    super().setup(*args, **kwarg)
    # `wrapped_io` is looked up only after the parent setup has run.
    # NOTE(review): presumably the parent setup is what creates it -- confirm against NeMo's PEFT callback.
    io_wrapper = self.wrapped_io
    io_wrapper.adapter_ckpt_path = self.peft_ckpt_path
71+
72+
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Hook run at the start of a predict epoch; delegates to `_maybe_apply_transform`.

    Args:
        trainer: The trainer object, forwarded to `_maybe_apply_transform`.
        pl_module: The LightningModule object (not used by this hook).
    """
    self._maybe_apply_transform(trainer)
80+
81+
def adapter_key_filter(self, key: str) -> bool:
    """Decide whether a state-dict key counts as an adapter key (True) or a base-model key (False).

    Args:
        key: the key to filter. A tuple is also accepted, in which case the answer is the
            `requires_grad` flag of its second element.
    """
    if isinstance(key, tuple):
        second = key[1]
        return second.requires_grad

    if "_extra_state" in key:
        return False

    # Keys that mention none of the frozen modules are always kept.
    mentions_frozen = any(module_name in key for module_name in self.freeze_modules)
    if not mentions_frozen:
        return True

    # Inside a frozen module only the adapter entries are kept.
    return ".adapter." in key or key.endswith(".adapters")
96+
4097
def __call__(self, model: nn.Module) -> nn.Module:
4198
"""This method is called when the object is called as a function.
4299
43100
Args:
44101
model: The input model.
45102
46-
Returns:
47103
The modified model.
48104
"""
49105
fn.walk(model, self.selective_freeze)
@@ -64,6 +120,6 @@ def selective_freeze(self, m: nn.Module, name=None, prefix=None):
64120
See Also:
65121
nemo.collections.llm.fn.mixin.FNMixin
66122
"""
67-
if name in ["encoder", "embedding"]:
123+
if name in self.freeze_modules:
68124
FNMixin.freeze(m)
69125
return m

sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ def __init__(
169169
pre_process=self.pre_process,
170170
post_process=self.post_process,
171171
)
172-
173172
# Output
174173
if post_process:
175174
# TODO: Make sure you are passing in the mpu_vocab_size properly

0 commit comments

Comments
 (0)