Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5543a7c
changes and files for peft
polinabinder1 Mar 12, 2025
ce5e138
trainer changes
polinabinder1 Mar 14, 2025
a9549e3
scripts for fine-tuning and inference with PEFT
polinabinder1 Mar 19, 2025
18999ba
test cases for esm2
polinabinder1 Mar 19, 2025
a3e4011
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Mar 19, 2025
fd2d7c3
correct pre-commit
polinabinder1 Mar 19, 2025
bb60ef2
reverse nemo changes
polinabinder1 Mar 19, 2025
7f294be
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Mar 31, 2025
68e62fc
seed as an argument
polinabinder1 Mar 31, 2025
903035b
Merge remote-tracking branch 'origin/main' into pbinder/auto_resume
polinabinder1 Apr 1, 2025
01b02ef
test file changes
polinabinder1 Apr 2, 2025
034eaf7
resumption test case
polinabinder1 Apr 3, 2025
eb9db7b
debugging inference
polinabinder1 Apr 4, 2025
7c2b2b4
experimenting with inference auto resume
polinabinder1 Apr 7, 2025
83a6a7f
inference working
polinabinder1 Apr 8, 2025
4de7922
not running distributed
polinabinder1 Apr 9, 2025
458d0b5
fixing test cases
polinabinder1 Apr 10, 2025
fc3a5a2
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 11, 2025
be2bc85
add + refactor test cases
polinabinder1 Apr 11, 2025
15273b5
proper file handling for test cases
polinabinder1 Apr 16, 2025
559e4fb
fixing some inference pipelines
polinabinder1 Apr 17, 2025
b46605a
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 17, 2025
7d5d3fa
Delete sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_fine…
polinabinder1 Apr 17, 2025
cf3d5a5
removing test files
polinabinder1 Apr 17, 2025
ceb9bff
addressing PR comments
polinabinder1 Apr 21, 2025
a20d6a6
fixing the imports
polinabinder1 Apr 22, 2025
593a2a0
Update conftest.py
polinabinder1 Apr 22, 2025
58c440e
adding correct NGC path
polinabinder1 Apr 22, 2025
15aa844
Update conftest.py
polinabinder1 Apr 23, 2025
48c91e9
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 23, 2025
b9db5a2
adding correct ngc location formatting
polinabinder1 Apr 23, 2025
68bda48
updating inference notebook
polinabinder1 Apr 24, 2025
0b82299
correct ngc info
polinabinder1 Apr 24, 2025
c6bc517
removing a test case that does not run well with megatron environment…
polinabinder1 Apr 24, 2025
6f9712a
adding correct notebooks
polinabinder1 Apr 25, 2025
8708714
adding notebook
polinabinder1 Apr 25, 2025
4fe7621
Merge branch 'main' into pbinder/auto_resume
polinabinder1 Apr 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
428 changes: 415 additions & 13 deletions docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb

Large diffs are not rendered by default.

183 changes: 145 additions & 38 deletions docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,35 +142,6 @@
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading data from 'nvidia/clara/esm2nv650m:2.0' to file '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz'.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"download_end\": \"2025-01-14 22:01:24\",\n",
" \"download_start\": \"2025-01-14 22:01:05\",\n",
" \"download_time\": \"18s\",\n",
" \"files_downloaded\": 1,\n",
" \"local_path\": \"/home/ubuntu/.cache/bionemo/tmpfj1e52vw/esm2nv650m_v2.0\",\n",
" \"size_downloaded\": \"1.12 GB\",\n",
" \"status\": \"COMPLETED\"\n",
"}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Untarring contents of '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz' to '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand Down Expand Up @@ -267,14 +238,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2025-01-14 22:01:45 - faiss.loader - INFO - Loading faiss with AVX512 support.\n",
"2025-01-14 22:01:45 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n",
"[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
" warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n",
" \n",
"[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n",
" cm = get_cmap(\"Set1\")\n",
" \n",
"usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n",
" --results-path RESULTS_PATH\n",
" [--precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}]\n",
Expand All @@ -285,6 +248,7 @@
" [--prediction-interval {epoch,batch}] [--include-hiddens]\n",
" [--include-input-ids] [--include-embeddings]\n",
" [--include-logits] [--config-class CONFIG_CLASS]\n",
" [--lora-checkpoint-path LORA_CHECKPOINT_PATH]\n",
"\n",
"Infer ESM2.\n",
"\n",
Expand Down Expand Up @@ -326,7 +290,9 @@
" script should also provide similar support for picking\n",
" different data modules for fine-tuning with different\n",
" data types. Choices: dict_keys(['ESM2Config',\n",
" 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])\n"
" 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])\n",
" --lora-checkpoint-path LORA_CHECKPOINT_PATH\n",
" Path to the lora states to restore from.\n"
]
}
],
Expand Down Expand Up @@ -507,6 +473,147 @@
"mask = torch.isin(input_ids, torch.tensor(extra_indices))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference with LoRA\n",
"Inference with LoRA is supported. This requires the original model weights along with the additional LoRA weights. "
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True\n",
"[INFO | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores\n",
"[INFO | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Fixing mis-match between ddp-config & mcore-optimizer config\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has data parallel group : [0]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Ranks 0 has data parallel rank: 0\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has context parallel group: [0]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All context parallel group ranks: [[0]]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Ranks 0 has context parallel rank: 0\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has model parallel group: [0]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All model parallel group ranks: [[0]]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n",
"[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has embedding group: [0]\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] All embedding group ranks: [[0]]\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has embedding rank: 0\n",
"[INFO | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------\n",
"distributed_backend=nccl\n",
"All distributed processes registered. Starting with 1 processes\n",
"----------------------------------------------------------------------------------------------------\n",
"\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Setting up ModelTransform for stage: TrainerFn.PREDICTING\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Found model_transform attribute on pl_module\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Set model_transform to: <function _call_counter.<locals>.wrapper at 0x70af0b61c220>\n",
"[WARNING | /workspaces/bionemo-framework/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py]: Loading /home/ubuntu/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Padded vocab_size: 128, original vocab_size: 33, dummy tokens: 95.\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"[NeMo W 2025-04-25 04:47:43 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 7542848\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.self_attention.linear_proj\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.self_attention.linear_qkv\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.mlp.linear_fc1\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.mlp.linear_fc2\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.self_attention.linear_proj\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.self_attention.linear_qkv\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.mlp.linear_fc1\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.mlp.linear_fc2\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.self_attention.linear_proj\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.self_attention.linear_qkv\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.mlp.linear_fc1\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.mlp.linear_fc2\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.self_attention.linear_proj\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.self_attention.linear_qkv\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.mlp.linear_fc1\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.mlp.linear_fc2\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.self_attention.linear_proj\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.self_attention.linear_qkv\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.mlp.linear_fc1\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.mlp.linear_fc2\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.self_attention.linear_proj\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.self_attention.linear_qkv\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.mlp.linear_fc1\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.mlp.linear_fc2\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] After applying model_transform:\n",
" \n",
" | Name | Type | Params | Mode\n",
"---------------------------------------\n",
"0 | module | DDP | 8.5 M | eval\n",
"---------------------------------------\n",
"1.1 M Trainable params\n",
"7.4 M Non-trainable params\n",
"8.5 M Total params\n",
"34.104 Total estimated model params size (MB)\n",
"121 Modules in train mode\n",
"133 Modules in eval mode\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Loading adapters from /home/ubuntu/.cache/bionemo/ffce714478009c353f20e8c1ad8e49638128f0cc936ebe1d44c161ac89831dcb-finetuned_peft_weights.tar.gz.untar/weights\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Initializing model parallel\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 8525888\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of trainable parameters: 1086528 (12.74% of total)\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Using <megatron.core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper object at 0x70af081e3f50> dist-ckpt load strategy.\n",
"[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1745556463.638s : Time spent in load_checkpoint: 0.029s\n",
"[NeMo W 2025-04-25 04:47:43 nemo_logging:405] MegatronOptimizerModule not found in trainer callbacks. finalize_model_grads is not properly set up for PEFT.\n"
]
}
],
"source": [
"# Load the base model and PEFT weights\n",
"\n",
"checkpoint_path = load(\"esm2/8m:2.0\")\n",
"lora_checkpoint_path = load(\"esm2/esm2_lora_weights:1.1\") / \"weights\"\n",
"\n",
"# Perform inference with LoRA\n",
"! infer_esm2 --checkpoint-path {checkpoint_path} \\\n",
" --data-path {data_path} \\\n",
" --results-path {work_dir} \\\n",
" --micro-batch-size 3 \\\n",
" --num-gpus 1 \\\n",
" --include-hiddens \\\n",
" --include-embeddings \\\n",
" --include-logits \\\n",
" --include-input-ids \\\n",
" --lora-checkpoint-path {lora_checkpoint_path}"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"token_logits\ttorch.Size([1024, 10, 128])\n",
"hidden_states\ttorch.Size([10, 1024, 320])\n",
"input_ids\ttorch.Size([10, 1024])\n",
"embeddings\ttorch.Size([10, 320])\n"
]
}
],
"source": [
"results = torch.load(f\"{work_dir}/predictions__rank_0.pt\")\n",
"\n",
"for key, val in results.items():\n",
" if val is not None:\n",
" print(f\"{key}\\t{val.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,11 @@
sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret
owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
description: Test data for ESM2 inference.

- tag: esm2_lora_weights:1.1
ngc: nvidia/clara/esm2_lora_weights:1.1
ngc_registry: model
pbss: "s3://general-purpose/esm2/checkpoints/finetuned_peft_weights.tar.gz"
sha256: ffce714478009c353f20e8c1ad8e49638128f0cc936ebe1d44c161ac89831dcb # pragma: allowlist secret
owner: Polina Binder <pbinder@nvidia.com>
description: Weights for a LoRA finetuned ESM2 model.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import lightning.pytorch as pl
from nemo.collections.llm import fn
from nemo.collections.llm.fn.mixin import FNMixin
from nemo.collections.llm.peft.lora import LoRA
Expand All @@ -37,13 +40,66 @@
class ESM2LoRA(LoRA):
"""LoRA for the BioNeMo2 ESM Model."""

def __init__(
    self,
    peft_ckpt_path: Optional[str] = None,
    freeze_modules: Optional[List[str]] = None,
    *args,
    **kwargs,
):
    """Initialize the LoRA adapter.

    Args:
        peft_ckpt_path: path to a PEFT checkpoint to restore adapter weights from,
            or None to start from freshly initialized adapters.
        freeze_modules: names of submodules to freeze during fine-tuning;
            defaults to ["encoder", "embedding"].
        *args: positional arguments forwarded to the LoRA base class.
        **kwargs: keyword arguments forwarded to the LoRA base class.
    """
    super().__init__(*args, **kwargs)
    # Resolve the default here rather than in the signature: a mutable default
    # argument would be shared across every ESM2LoRA instance.
    self.freeze_modules = ["encoder", "embedding"] if freeze_modules is None else freeze_modules
    self.peft_ckpt_path = peft_ckpt_path

def setup(self, *args, **kwargs):
    """Run the base LoRA setup, then propagate the PEFT checkpoint path.

    After the standard setup, the wrapped checkpoint IO is pointed at
    ``self.peft_ckpt_path`` so adapter weights are restored from there.

    Args:
        *args: positional arguments forwarded to ``LoRA.setup``.
        **kwargs: keyword arguments forwarded to ``LoRA.setup``.
    """
    super().setup(*args, **kwargs)
    # Hand the adapter checkpoint location to the wrapped IO for restoration.
    self.wrapped_io.adapter_ckpt_path = self.peft_ckpt_path

def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Lightning hook fired at the start of a predict epoch.

    Applies the model transform (adapter injection) if it has not run yet —
    presumably because the predict path does not trigger the fit-time hook
    that normally applies it; confirm against the LoRA base class.

    Args:
        trainer: The trainer object.
        pl_module: The LightningModule object.
    """
    self._maybe_apply_transform(trainer)

def adapter_key_filter(self, key: str) -> bool:
    """Decide whether a state-dict key belongs to the adapter (vs. the base model).

    Args:
        key: the state-dict key (or a (name, parameter) tuple) to classify.

    Returns:
        True if the entry should be treated as adapter state.
    """
    # (name, parameter) tuples are classified by trainability.
    if isinstance(key, tuple):
        return key[1].requires_grad
    # Extra-state entries are never adapter weights.
    if "_extra_state" in key:
        return False
    # Explicit adapter entries always count as adapter state.
    if ".adapter." in key or key.endswith(".adapters"):
        return True
    # Otherwise, anything outside the frozen modules is adapter state.
    return not any(substring in key for substring in self.freeze_modules)

def __call__(self, model: nn.Module) -> nn.Module:
"""This method is called when the object is called as a function.

Args:
model: The input model.

Returns:
The modified model.
"""
fn.walk(model, self.selective_freeze)
Expand All @@ -64,6 +120,6 @@ def selective_freeze(self, m: nn.Module, name=None, prefix=None):
See Also:
nemo.collections.llm.fn.mixin.FNMixin
"""
if name in ["encoder", "embedding"]:
if name in self.freeze_modules:
FNMixin.freeze(m)
return m
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def __init__(
pre_process=self.pre_process,
post_process=self.post_process,
)

# Output
if post_process:
# TODO: Make sure you are passing in the mpu_vocab_size properly
Expand Down
Loading