diff --git a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb index 7a806b65d..4d3dfb35e 100644 --- a/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb +++ b/docs/docs/user-guide/examples/bionemo-esm2/finetune.ipynb @@ -867,7 +867,6 @@ "outputs": [], "source": [ "%%capture --no-display --no-stderr cell_output\n", - "\n", "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n", " --config-class ESM2FineTuneTokenConfig \\\n", " --data-path {data_path} \\\n", @@ -973,16 +972,16 @@ "output_type": "stream", "text": [ "Predicted Secondary Structures:\n", - "EEEECCCCCHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\n", - "CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\n", - "HHHHHCCCCCHHHHHHHHHHHHHHCCCHHHHHHHHHH\n", - "HHHHHHHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\n", - "CHHHHHHHHHHHHHHHCCCEEEEEECCCHHHHHHHHHCCCCCCCCCEEE\n", - "HHHHHHHHHHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\n", - "HHHHHCCCHHHHHCCCCCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\n", - "CCCCCHHHHHHHHHHHHHHCCCCCHHHHHHCC\n", - "HHHHHCHHHHHHHHHHHHCCCEECCCEEEECCEEEEECC\n", - "CCCCCCCCCCCCCCCCCCCCCCCCCCEEECCCCEEECHHHHHHHHHCCCCCCCCEECCCCCCC\n" + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n", + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE\n" ] } ], @@ -995,12 +994,415 @@ " print(label_tokenizer.ids_to_text(ss_ids.tolist()))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine-tuning with LoRA\n", + "Finte with LoRA is supported. In this regime, the encoder and the embedding layers are frozen, and LoRA weights are added to those layers. The classification and regression heads are not frozen. LoRA fine-tuning is supported for any of the classification types above. The outputted weights in the results directory only contain the LoRA weights and the classification and regression heads. For further inference and training, both the original model weights and fine-tuned weights are necessary. " + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO | pytorch_lightning.utilities.rank_zero]: Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs\n", + "[NeMo W 2025-04-25 04:50:48 nemo_logging:405] WandB is currently turned off.\n", + "[NeMo W 2025-04-25 04:50:48 nemo_logging:405] User-set tensorboard is currently turned off. Internally one may still be set by NeMo2.\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Experiments will be logged at /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev\n", + "[NeMo W 2025-04-25 04:50:48 nemo_logging:405] The Trainer already contains a ModelCheckpoint callback. This will be overwritten.\n", + "[NeMo W 2025-04-25 04:50:48 nemo_logging:405] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to 50. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Fixing mis-match between ddp-config & mcore-optimizer config\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has data parallel group : [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Ranks 0 has data parallel rank: 0\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has context parallel group: [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] All context parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Ranks 0 has context parallel rank: 0\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has model parallel group: [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] All model parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has embedding group: [0]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] All embedding group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Rank 0 has embedding rank: 0\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------\n", + "distributed_backend=nccl\n", + "All distributed processes registered. Starting with 1 processes\n", + "----------------------------------------------------------------------------------------------------\n", + "\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Setting up ModelTransform for stage: TrainerFn.FITTING\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Found model_transform attribute on pl_module\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Set model_transform to: .wrapper at 0x770fd71d2840>\n", + "[WARNING | /workspaces/bionemo-framework/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py]: Loading /home/ubuntu/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Padded vocab_size: 128, original vocab_size: 33, dummy tokens: 95.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "[NeMo W 2025-04-25 04:50:48 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n", + "[NeMo I 2025-04-25 04:50:48 num_microbatches_calculator:228] setting number of microbatches to constant 1\n", + "┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0 \u001b[0m\u001b[2m \u001b[0m│ valid_metric │ Multicl… │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m1 \u001b[0m\u001b[2m \u001b[0m│ module │ ESM2Fin… │ 7.6 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m2 \u001b[0m\u001b[2m \u001b[0m│ module.embedding │ ESM2Emb… │ 41.0 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m3 \u001b[0m\u001b[2m \u001b[0m│ module.embedding.word_embeddings │ VocabPa… │ 41.0 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m4 \u001b[0m\u001b[2m \u001b[0m│ module.embedding.embedding_dropout │ Dropout │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m5 \u001b[0m\u001b[2m \u001b[0m│ module.rotary_pos_emb │ RotaryE… │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m6 \u001b[0m\u001b[2m \u001b[0m│ module.encoder │ Transfo… │ 7.4 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m7 \u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers │ ModuleL… │ 7.4 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m8 \u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers.0 │ Transfo… │ 1.2 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m9 \u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers.1 │ Transfo… │ 1.2 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m10\u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers.2 │ Transfo… │ 1.2 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m11\u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers.3 │ Transfo… │ 1.2 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m12\u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers.4 │ Transfo… │ 1.2 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m13\u001b[0m\u001b[2m \u001b[0m│ module.encoder.layers.5 │ Transfo… │ 1.2 M │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m14\u001b[0m\u001b[2m \u001b[0m│ module.encoder.final_layernorm │ LayerNo… │ 640 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m15\u001b[0m\u001b[2m \u001b[0m│ module.lm_head │ BertLMH… │ 103 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m16\u001b[0m\u001b[2m \u001b[0m│ module.lm_head.dense │ Linear │ 102 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m17\u001b[0m\u001b[2m \u001b[0m│ module.lm_head.layer_norm │ FusedLa… │ 640 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m18\u001b[0m\u001b[2m \u001b[0m│ module.output_layer │ ColumnP… │ 128 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m19\u001b[0m\u001b[2m \u001b[0m│ module.classification_head │ Megatro… │ 72.4 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m20\u001b[0m\u001b[2m \u001b[0m│ module.classification_head.finetune_model │ Sequent… │ 71.7 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m21\u001b[0m\u001b[2m \u001b[0m│ module.classification_head.finetune_model.0 │ Conv2d │ 71.7 K │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m22\u001b[0m\u001b[2m \u001b[0m│ module.classification_head.finetune_model.1 │ ReLU │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m23\u001b[0m\u001b[2m \u001b[0m│ module.classification_head.finetune_model.2 │ Dropout │ 0 │ train │\n", + "│\u001b[2m \u001b[0m\u001b[2m24\u001b[0m\u001b[2m \u001b[0m│ module.classification_head.class_heads │ Conv2d │ 675 │ train │\n", + "└────┴─────────────────────────────────────────────┴──────────┴────────┴───────┘\n", + "\u001b[1mTrainable params\u001b[0m: 72.4 K \n", + "\u001b[1mNon-trainable params\u001b[0m: 7.5 M \n", + "\u001b[1mTotal params\u001b[0m: 7.6 M \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 30 \n", + "\u001b[1mModules in train mode\u001b[0m: 139 \n", + "\u001b[1mModules in eval mode\u001b[0m: 0 \n", + "[NeMo W 2025-04-25 04:50:48 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n", + " \n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.0.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.0.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.0.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.0.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.1.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.1.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.1.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.1.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.2.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.2.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.2.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.2.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.3.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.3.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.3.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.3.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.4.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.4.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.4.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.4.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.5.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.5.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.5.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Adding lora to: 0.module.encoder.layers.5.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] After applying model_transform:\n", + " \n", + " | Name | Type | Params | Mode\n", + "---------------------------------------------------------------\n", + "0 | valid_metric | MulticlassAccuracy | 0 | eval\n", + "1 | module | ESM2FineTuneTokenModel | 8.6 M | eval\n", + "---------------------------------------------------------------\n", + "1.1 M Trainable params\n", + "7.5 M Non-trainable params\n", + "8.6 M Total params\n", + "34.393 Total estimated model params size (MB)\n", + "122 Modules in train mode\n", + "137 Modules in eval mode\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Initializing model parallel\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 8598275\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] > number of trainable parameters: 1055427 (12.27% of total)\n", + "[NeMo I 2025-04-25 04:50:48 utils:302] Setting up DistributedDataParallel with config DistributedDataParallelConfig(grad_reduce_in_fp32=False, overlap_grad_reduce=False, overlap_param_gather=True, align_param_gather=False, use_distributed_optimizer=True, num_distributed_optimizer_instances=1, check_for_nan_in_grad=True, check_for_large_grads=False, bucket_size=None, average_in_collective=True, fp8_param_gather=False)\n", + "[NeMo I 2025-04-25 04:50:48 utils:323] Number of buckets for gradient all-reduce / reduce-scatter: 1\n", + " Params for bucket 1 (1055427 elements):\n", + " \tmodule.encoder.layers.5.self_attention.linear_proj.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.3.mlp.linear_fc1.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.3.mlp.linear_fc1.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.1.self_attention.linear_qkv.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.0.self_attention.linear_qkv.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.0.self_attention.linear_proj.adapter.linear_in.weight\n", + " \tmodule.classification_head.class_heads.bias\n", + " \tmodule.encoder.layers.5.mlp.linear_fc2.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.4.mlp.linear_fc1.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.3.self_attention.linear_proj.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.1.mlp.linear_fc1.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.0.mlp.linear_fc2.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.0.mlp.linear_fc1.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.0.self_attention.linear_proj.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.5.mlp.linear_fc1.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.5.self_attention.linear_qkv.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.2.self_attention.linear_qkv.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.1.mlp.linear_fc2.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.4.self_attention.linear_qkv.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.5.mlp.linear_fc2.adapter.linear_out.weight\n", + " \tmodule.classification_head.finetune_model.0.bias\n", + " \tmodule.encoder.layers.4.self_attention.linear_proj.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.3.mlp.linear_fc2.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.2.mlp.linear_fc2.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.2.self_attention.linear_proj.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.1.self_attention.linear_qkv.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.1.self_attention.linear_proj.adapter.linear_in.weight\n", + " \tmodule.classification_head.class_heads.weight\n", + " \tmodule.encoder.layers.4.mlp.linear_fc1.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.4.self_attention.linear_qkv.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.3.mlp.linear_fc2.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.2.mlp.linear_fc1.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.1.self_attention.linear_proj.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.0.mlp.linear_fc2.adapter.linear_in.weight\n", + " \tmodule.classification_head.finetune_model.0.weight\n", + " \tmodule.encoder.layers.5.mlp.linear_fc1.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.4.mlp.linear_fc2.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.3.self_attention.linear_proj.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.2.mlp.linear_fc1.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.1.mlp.linear_fc1.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.4.self_attention.linear_proj.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.3.self_attention.linear_qkv.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.2.mlp.linear_fc2.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.2.self_attention.linear_proj.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.1.mlp.linear_fc2.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.0.self_attention.linear_qkv.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.5.self_attention.linear_qkv.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.5.self_attention.linear_proj.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.4.mlp.linear_fc2.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.3.self_attention.linear_qkv.adapter.linear_out.weight\n", + " \tmodule.encoder.layers.2.self_attention.linear_qkv.adapter.linear_in.weight\n", + " \tmodule.encoder.layers.0.mlp.linear_fc1.adapter.linear_out.weight\n", + "[NeMo I 2025-04-25 04:50:48 nemo_logging:393] Setting up optimizers\n", + "[NeMo I 2025-04-25 04:50:48 utils:302] Setting up optimizer with config OptimizerConfig(optimizer='adam', lr=0.005, min_lr=None, decoupled_lr=None, decoupled_min_lr=None, weight_decay=0.01, fp16=False, bf16=True, params_dtype=torch.bfloat16, use_precision_aware_optimizer=False, main_grads_dtype=torch.float32, main_params_dtype=torch.float32, exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32, loss_scale=None, initial_loss_scale=4294967296, min_loss_scale=1.0, loss_scale_window=1000, hysteresis=2, adam_beta1=0.9, adam_beta2=0.98, adam_eps=1e-08, sgd_momentum=0.9, use_distributed_optimizer=True, overlap_param_gather_with_optimizer_step=False, optimizer_cpu_offload=False, optimizer_offload_fraction=0.0, use_torch_optimizer_for_cpu_offload=False, overlap_cpu_optimizer_d2h_h2d=False, pin_cpu_grads=True, pin_cpu_params=True, clip_grad=1.0, log_num_zeros_in_grad=False, barrier_with_L1_time=False, timers=None, config_logger_dir='')\n", + "Sanity checking Validation: iteration 1/2\n", + "Sanity checking Validation: iteration 2/2\n", + "[NeMo W 2025-04-25 04:50:51 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('global_batch_size', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.\n", + " \n", + "[NeMo W 2025-04-25 04:50:51 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.\n", + " \n", + "[NeMo W 2025-04-25 04:50:51 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n", + " \n", + "[NeMo W 2025-04-25 04:50:52 rerun_state_machine:1264] Implicit initialization of Rerun State Machine!\n", + "[NeMo W 2025-04-25 04:50:52 rerun_state_machine:239] RerunStateMachine initialized in mode RerunMode.DISABLED\n", + "Training epoch 0, iteration 0/49 | lr: 0.005 | global_batch_size: 2 | global_step: 0 | reduced_train_loss: 1.094\n", + "Training epoch 0, iteration 1/49 | lr: 0.005 | global_batch_size: 2 | global_step: 1 | reduced_train_loss: 3376 | consumed_samples: 4\n", + "Training epoch 0, iteration 2/49 | lr: 0.005 | global_batch_size: 2 | global_step: 2 | reduced_train_loss: 1216 | consumed_samples: 6\n", + "Training epoch 0, iteration 3/49 | lr: 0.005 | global_batch_size: 2 | global_step: 3 | reduced_train_loss: 8704 | consumed_samples: 8\n", + "Training epoch 0, iteration 4/49 | lr: 0.005 | global_batch_size: 2 | global_step: 4 | reduced_train_loss: 2064 | consumed_samples: 10\n", + "Training epoch 0, iteration 5/49 | lr: 0.005 | global_batch_size: 2 | global_step: 5 | reduced_train_loss: 7776 | consumed_samples: 12\n", + "Training epoch 0, iteration 6/49 | lr: 0.005 | global_batch_size: 2 | global_step: 6 | reduced_train_loss: 5952 | consumed_samples: 14\n", + "Training epoch 0, iteration 7/49 | lr: 0.005 | global_batch_size: 2 | global_step: 7 | reduced_train_loss: 8448 | consumed_samples: 16\n", + "Training epoch 0, iteration 8/49 | lr: 0.005 | global_batch_size: 2 | global_step: 8 | reduced_train_loss: 3152 | consumed_samples: 18\n", + "Training epoch 0, iteration 9/49 | lr: 0.005 | global_batch_size: 2 | global_step: 9 | reduced_train_loss: 5952 | consumed_samples: 20\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: Epoch 0, global step 9: 'val_loss' was not in top 2\n", + "[NeMo I 2025-04-25 04:50:53 nemo_logging:393] Using FullyParallelSaveStrategyWrapper(torch_dist, 1) dist-ckpt save strategy.\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 9 : Start time: 1745556653.726s : Save duration: 9.420s\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=9-consumed_samples=20.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Async finalization time took 0.001 s\n", + "Validation: iteration 1/2\n", + "Validation: iteration 2/2\n", + "Training epoch 0, iteration 10/49 | lr: 0.005 | global_batch_size: 2 | global_step: 10 | reduced_train_loss: 5152 | consumed_samples: 22 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Async finalization time took 0.000 s\n", + "Training epoch 0, iteration 11/49 | lr: 0.005 | global_batch_size: 2 | global_step: 11 | reduced_train_loss: 1104 | consumed_samples: 24 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Successfully saved checkpoint from iteration 9 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=9-consumed_samples=20.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Async checkpoint save for step 10 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=9-consumed_samples=20.0-last.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Async finalization time took 0.016 s\n", + "Training epoch 0, iteration 12/49 | lr: 0.005 | global_batch_size: 2 | global_step: 12 | reduced_train_loss: 15808 | consumed_samples: 26 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 13/49 | lr: 0.005 | global_batch_size: 2 | global_step: 13 | reduced_train_loss: 8832 | consumed_samples: 28 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 14/49 | lr: 0.005 | global_batch_size: 2 | global_step: 14 | reduced_train_loss: 6208 | consumed_samples: 30 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 15/49 | lr: 0.005 | global_batch_size: 2 | global_step: 15 | reduced_train_loss: 494 | consumed_samples: 32 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 16/49 | lr: 0.005 | global_batch_size: 2 | global_step: 16 | reduced_train_loss: 1256 | consumed_samples: 34 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 17/49 | lr: 0.005 | global_batch_size: 2 | global_step: 17 | reduced_train_loss: 6624 | consumed_samples: 36 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 18/49 | lr: 0.005 | global_batch_size: 2 | global_step: 18 | reduced_train_loss: 4320 | consumed_samples: 38 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "Training epoch 0, iteration 19/49 | lr: 0.005 | global_batch_size: 2 | global_step: 19 | reduced_train_loss: 2784 | consumed_samples: 40 | val_loss: 4.981e+03 | val_acc: 0.136\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: Epoch 0, global step 19: 'val_loss' reached 4981.33350 (best 4981.33350), saving model to '/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0.ckpt' as top 2\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 19 : Start time: 1745556663.872s : Save duration: 0.037s\n", + "[NeMo I 2025-04-25 04:51:03 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 19 : Start time: 1745556663.992s : Save duration: 0.038s\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Async finalization time took 0.001 s\n", + "Validation: iteration 1/2\n", + "Validation: iteration 2/2\n", + "Training epoch 0, iteration 20/49 | lr: 0.005 | global_batch_size: 2 | global_step: 20 | reduced_train_loss: 7968 | consumed_samples: 42 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Successfully saved checkpoint from iteration 19 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Async checkpoint save for step 20 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Async finalization time took 0.013 s\n", + "Training epoch 0, iteration 21/49 | lr: 0.005 | global_batch_size: 2 | global_step: 21 | reduced_train_loss: 4640 | consumed_samples: 44 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Successfully saved checkpoint from iteration 19 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Async checkpoint save for step 20 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=19-consumed_samples=40.0-last.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Async finalization time took 0.015 s\n", + "Training epoch 0, iteration 22/49 | lr: 0.005 | global_batch_size: 2 | global_step: 22 | reduced_train_loss: 4128 | consumed_samples: 46 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 23/49 | lr: 0.005 | global_batch_size: 2 | global_step: 23 | reduced_train_loss: 1048 | consumed_samples: 48 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 24/49 | lr: 0.005 | global_batch_size: 2 | global_step: 24 | reduced_train_loss: 510 | consumed_samples: 50 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 25/49 | lr: 0.005 | global_batch_size: 2 | global_step: 25 | reduced_train_loss: 13248 | consumed_samples: 52 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 26/49 | lr: 0.005 | global_batch_size: 2 | global_step: 26 | reduced_train_loss: 10432 | consumed_samples: 54 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 27/49 | lr: 0.005 | global_batch_size: 2 | global_step: 27 | reduced_train_loss: 1704 | consumed_samples: 56 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 28/49 | lr: 0.005 | global_batch_size: 2 | global_step: 28 | reduced_train_loss: 1760 | consumed_samples: 58 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "Training epoch 0, iteration 29/49 | lr: 0.005 | global_batch_size: 2 | global_step: 29 | reduced_train_loss: 162 | consumed_samples: 60 | val_loss: 7.211e+03 | val_acc: 0.2105\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: Epoch 0, global step 29: 'val_loss' reached 7210.66650 (best 4981.33350), saving model to '/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0.ckpt' as top 2\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 29 : Start time: 1745556664.742s : Save duration: 0.036s\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 29 : Start time: 1745556664.867s : Save duration: 0.039s\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:04 nemo_logging:393] Async finalization time took 0.000 s\n", + "Validation: iteration 1/2\n", + "Validation: iteration 2/2\n", + "Training epoch 0, iteration 30/49 | lr: 0.005 | global_batch_size: 2 | global_step: 30 | reduced_train_loss: 4.5 | consumed_samples: 62 | val_loss: 6.073 | val_acc: 0.3289\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Successfully saved checkpoint from iteration 29 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async checkpoint save for step 30 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async finalization time took 0.015 s\n", + "Training epoch 0, iteration 31/49 | lr: 0.005 | global_batch_size: 2 | global_step: 31 | reduced_train_loss: 2.781 | consumed_samples: 64 | val_loss: 6.073 | val_acc: 0.3289\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Successfully saved checkpoint from iteration 29 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async checkpoint save for step 30 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=29-consumed_samples=60.0-last.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async finalization time took 0.016 s\n", + "Training epoch 0, iteration 32/49 | lr: 0.005 | global_batch_size: 2 | global_step: 32 | reduced_train_loss: 3.531 | consumed_samples: 66 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 33/49 | lr: 0.005 | global_batch_size: 2 | global_step: 33 | reduced_train_loss: 3.344 | consumed_samples: 68 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 34/49 | lr: 0.005 | global_batch_size: 2 | global_step: 34 | reduced_train_loss: 2.672 | consumed_samples: 70 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 35/49 | lr: 0.005 | global_batch_size: 2 | global_step: 35 | reduced_train_loss: 2.016 | consumed_samples: 72 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 36/49 | lr: 0.005 | global_batch_size: 2 | global_step: 36 | reduced_train_loss: 1.508 | consumed_samples: 74 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 37/49 | lr: 0.005 | global_batch_size: 2 | global_step: 37 | reduced_train_loss: 1.18 | consumed_samples: 76 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 38/49 | lr: 0.005 | global_batch_size: 2 | global_step: 38 | reduced_train_loss: 1.047 | consumed_samples: 78 | val_loss: 6.073 | val_acc: 0.3289\n", + "Training epoch 0, iteration 39/49 | lr: 0.005 | global_batch_size: 2 | global_step: 39 | reduced_train_loss: 0.9531 | consumed_samples: 80 | val_loss: 6.073 | val_acc: 0.3289\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: Epoch 0, global step 39: 'val_loss' reached 6.07292 (best 6.07292), saving model to '/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0.ckpt' as top 2\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 39 : Start time: 1745556665.620s : Save duration: 0.037s\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 39 : Start time: 1745556665.741s : Save duration: 0.039s\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async finalization time took 0.001 s\n", + "Validation: iteration 1/2\n", + "Validation: iteration 2/2\n", + "Training epoch 0, iteration 40/49 | lr: 0.005 | global_batch_size: 2 | global_step: 40 | reduced_train_loss: 1 | consumed_samples: 82 | val_loss: 1.234 | val_acc: 0.3114\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Successfully saved checkpoint from iteration 39 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async checkpoint save for step 40 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:05 nemo_logging:393] Async finalization time took 0.017 s\n", + "Training epoch 0, iteration 41/49 | lr: 0.005 | global_batch_size: 2 | global_step: 41 | reduced_train_loss: 1.445 | consumed_samples: 84 | val_loss: 1.234 | val_acc: 0.3114\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async finalization time took 0.000 s\n", + "Training epoch 0, iteration 42/49 | lr: 0.005 | global_batch_size: 2 | global_step: 42 | reduced_train_loss: 1.273 | consumed_samples: 86 | val_loss: 1.234 | val_acc: 0.3114\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Successfully saved checkpoint from iteration 39 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async checkpoint save for step 40 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=39-consumed_samples=80.0-last.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async finalization time took 0.017 s\n", + "Training epoch 0, iteration 43/49 | lr: 0.005 | global_batch_size: 2 | global_step: 43 | reduced_train_loss: 2.266 | consumed_samples: 88 | val_loss: 1.234 | val_acc: 0.3114\n", + "Training epoch 0, iteration 44/49 | lr: 0.005 | global_batch_size: 2 | global_step: 44 | reduced_train_loss: 1.18 | consumed_samples: 90 | val_loss: 1.234 | val_acc: 0.3114\n", + "Training epoch 0, iteration 45/49 | lr: 0.005 | global_batch_size: 2 | global_step: 45 | reduced_train_loss: 1.391 | consumed_samples: 92 | val_loss: 1.234 | val_acc: 0.3114\n", + "Training epoch 0, iteration 46/49 | lr: 0.005 | global_batch_size: 2 | global_step: 46 | reduced_train_loss: 1.242 | consumed_samples: 94 | val_loss: 1.234 | val_acc: 0.3114\n", + "Training epoch 0, iteration 47/49 | lr: 0.005 | global_batch_size: 2 | global_step: 47 | reduced_train_loss: 1.578 | consumed_samples: 96 | val_loss: 1.234 | val_acc: 0.3114\n", + "Training epoch 0, iteration 48/49 | lr: 0.005 | global_batch_size: 2 | global_step: 48 | reduced_train_loss: 1.945 | consumed_samples: 98 | val_loss: 1.234 | val_acc: 0.3114\n", + "Training epoch 0, iteration 49/49 | lr: 0.005 | global_batch_size: 2 | global_step: 49 | reduced_train_loss: 1.07 | consumed_samples: 100 | val_loss: 1.234 | val_acc: 0.3114\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: Epoch 0, global step 49: 'val_loss' reached 1.23438 (best 1.23438), saving model to '/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0.ckpt' as top 2\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 49 : Start time: 1745556666.484s : Save duration: 0.037s\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 49 : Start time: 1745556666.605s : Save duration: 0.041s\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Scheduled async checkpoint save for /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async finalization time took 0.000 s\n", + "Validation: iteration 1/2\n", + "Validation: iteration 2/2\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Successfully saved checkpoint from iteration 49 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0.ckpt\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async finalization time took 0.017 s\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: `Trainer.fit` stopped: `max_steps=50` reached.\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Pending async checkpoint saves. Finalizing them synchronously now\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Successfully saved checkpoint from iteration 49 to /workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last.ckpt\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async checkpoint save for step 50 (/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last.ckpt) finalized successfully.\n", + "[NeMo I 2025-04-25 04:51:06 nemo_logging:393] Async finalization time took 0.120 s\n" + ] + } + ], + "source": [ + "# %%capture --no-display --no-stderr cell_output\n", + "\n", + "data_path = os.path.join(work_dir, \"token_classification_data.csv\")\n", + "\n", + "! finetune_esm2 \\\n", + " --restore-from-checkpoint-path {pretrain_checkpoint_path} \\\n", + " --train-data-path {data_path} \\\n", + " --valid-data-path {data_path} \\\n", + " --config-class ESM2FineTuneTokenConfig \\\n", + " --dataset-class InMemoryPerTokenValueDataset \\\n", + " --task-type \"classification\" \\\n", + " --cnn-dropout 0.25 \\\n", + " --cnn-hidden-size 32 \\\n", + " --cnn-num-classes 3 \\\n", + " --experiment-name \"lora-token-level-classification\" \\\n", + " --num-steps 50 \\\n", + " --num-gpus 1 \\\n", + " --val-check-interval 10 \\\n", + " --log-every-n-steps 10 \\\n", + " --encoder-frozen \\\n", + " --lr 5e-3 \\\n", + " --lr-multiplier 1e2 \\\n", + " --scale-lr-layer \"classification_head\" \\\n", + " --result-dir {work_dir} \\\n", + " --micro-batch-size 2 \\\n", + " --num-gpus 1 \\\n", + " --precision \"bf16-mixed\" \\\n", + " --lora-finetune" + ] + }, + { + "cell_type": "code", + "execution_count": 24, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "lora_checkpoint_path = f\"{work_dir}/lora-token-level-classification/dev/checkpoints/checkpoint-step=49-consumed_samples=100.0-last/weights\"\n", + "results_path = f\"{work_dir}/lora-token-level-classification/infer/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/bionemo2/esm2_finetune_tutorial/lora-token-level-classification/infer/\n" + ] + } + ], + "source": [ + "print(results_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture --no-display --no-stderr cell_output\n", + "data_path = os.path.join(work_dir, \"sequences.csv\")\n", + "\n", + "! infer_esm2 --checkpoint-path {pretrain_checkpoint_path} \\\n", + " --config-class ESM2FineTuneTokenConfig \\\n", + " --data-path {data_path} \\\n", + " --results-path {results_path} \\\n", + " --micro-batch-size 3 \\\n", + " --num-gpus 1 \\\n", + " --precision \"bf16-mixed\" \\\n", + " --include-embeddings \\\n", + " --include-hiddens \\\n", + " --include-input-ids \\\n", + " --lora-checkpoint-path {lora_checkpoint_path}" + ] } ], "metadata": { diff --git a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb index 3abb9d5b9..a4d1931e7 100644 --- a/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb +++ b/docs/docs/user-guide/examples/bionemo-esm2/inference.ipynb @@ -142,35 +142,6 @@ "execution_count": 4, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading data from 'nvidia/clara/esm2nv650m:2.0' to file '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz'.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{\n", - " \"download_end\": \"2025-01-14 22:01:24\",\n", - " \"download_start\": \"2025-01-14 22:01:05\",\n", - " \"download_time\": \"18s\",\n", - " \"files_downloaded\": 1,\n", - " \"local_path\": \"/home/ubuntu/.cache/bionemo/tmpfj1e52vw/esm2nv650m_v2.0\",\n", - " \"size_downloaded\": \"1.12 GB\",\n", - " \"status\": \"COMPLETED\"\n", - "}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Untarring contents of '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz' to '/home/ubuntu/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar'\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -267,14 +238,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "2025-01-14 22:01:45 - faiss.loader - INFO - Loading faiss with AVX512 support.\n", - "2025-01-14 22:01:45 - faiss.loader - INFO - Successfully loaded faiss with AVX512 support.\n", - "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", - " warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n", - " \n", - "[NeMo W 2025-01-14 22:01:46 nemo_logging:405] /usr/local/lib/python3.12/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.\n", - " cm = get_cmap(\"Set1\")\n", - " \n", "usage: infer_esm2 [-h] --checkpoint-path CHECKPOINT_PATH --data-path DATA_PATH\n", " --results-path RESULTS_PATH\n", " [--precision {fp16,bf16,fp32,bf16-mixed,fp32-mixed,16-mixed,fp16-mixed,16,32}]\n", @@ -285,6 +248,7 @@ " [--prediction-interval {epoch,batch}] [--include-hiddens]\n", " [--include-input-ids] [--include-embeddings]\n", " [--include-logits] [--config-class CONFIG_CLASS]\n", + " [--lora-checkpoint-path LORA_CHECKPOINT_PATH]\n", "\n", "Infer ESM2.\n", "\n", @@ -326,7 +290,9 @@ " script should also provide similar support for picking\n", " different data modules for fine-tuning with different\n", " data types. Choices: dict_keys(['ESM2Config',\n", - " 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])\n" + " 'ESM2FineTuneSeqConfig', 'ESM2FineTuneTokenConfig'])\n", + " --lora-checkpoint-path LORA_CHECKPOINT_PATH\n", + " Path to the lora states to restore from.\n" ] } ], @@ -507,6 +473,147 @@ "mask = torch.isin(input_ids, torch.tensor(extra_indices))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference with LoRA\n", + "Inference with LoRA is supported. This requires the original model weights along with the additional LoRA weights. " + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO | pytorch_lightning.utilities.rank_zero]: GPU available: True (cuda), used: True\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: TPU available: False, using: 0 TPU cores\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: HPU available: False, using: 0 HPUs\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Fixing mis-match between ddp-config & mcore-optimizer config\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has data parallel group : [0]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Ranks 0 has data parallel rank: 0\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has context parallel group: [0]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All context parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Ranks 0 has context parallel rank: 0\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has model parallel group: [0]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All model parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has tensor model parallel group: [0]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] All tensor model parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:47:42 nemo_logging:393] Rank 0 has tensor model parallel rank: 0\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has embedding group: [0]\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] All pipeline model parallel group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has pipeline model parallel rank 0\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] All embedding group ranks: [[0]]\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Rank 0 has embedding rank: 0\n", + "[INFO | pytorch_lightning.utilities.rank_zero]: ----------------------------------------------------------------------------------------------------\n", + "distributed_backend=nccl\n", + "All distributed processes registered. Starting with 1 processes\n", + "----------------------------------------------------------------------------------------------------\n", + "\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Setting up ModelTransform for stage: TrainerFn.PREDICTING\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Found model_transform attribute on pl_module\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Set model_transform to: .wrapper at 0x70af0b61c220>\n", + "[WARNING | /workspaces/bionemo-framework/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py]: Loading /home/ubuntu/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Padded vocab_size: 128, original vocab_size: 33, dummy tokens: 95.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "[NeMo W 2025-04-25 04:47:43 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 7542848\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.0.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.1.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.2.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.3.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.4.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.self_attention.linear_proj\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.self_attention.linear_qkv\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.mlp.linear_fc1\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Adding lora to: 0.module.module.module.encoder.layers.5.mlp.linear_fc2\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] After applying model_transform:\n", + " \n", + " | Name | Type | Params | Mode\n", + "---------------------------------------\n", + "0 | module | DDP | 8.5 M | eval\n", + "---------------------------------------\n", + "1.1 M Trainable params\n", + "7.4 M Non-trainable params\n", + "8.5 M Total params\n", + "34.104 Total estimated model params size (MB)\n", + "121 Modules in train mode\n", + "133 Modules in eval mode\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Loading adapters from /home/ubuntu/.cache/bionemo/ffce714478009c353f20e8c1ad8e49638128f0cc936ebe1d44c161ac89831dcb-finetuned_peft_weights.tar.gz.untar/weights\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Initializing model parallel\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 8525888\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] > number of trainable parameters: 1086528 (12.74% of total)\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Using dist-ckpt load strategy.\n", + "[NeMo I 2025-04-25 04:47:43 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1745556463.638s : Time spent in load_checkpoint: 0.029s\n", + "[NeMo W 2025-04-25 04:47:43 nemo_logging:405] MegatronOptimizerModule not found in trainer callbacks. finalize_model_grads is not properly set up for PEFT.\n" + ] + } + ], + "source": [ + "# Load the base model and PEFT weights\n", + "\n", + "checkpoint_path = load(\"esm2/8m:2.0\")\n", + "lora_checkpoint_path = load(\"esm2/esm2_lora_weights:1.1\") / \"weights\"\n", + "\n", + "# Perform inference with LoRA\n", + "! infer_esm2 --checkpoint-path {checkpoint_path} \\\n", + " --data-path {data_path} \\\n", + " --results-path {work_dir} \\\n", + " --micro-batch-size 3 \\\n", + " --num-gpus 1 \\\n", + " --include-hiddens \\\n", + " --include-embeddings \\\n", + " --include-logits \\\n", + " --include-input-ids \\\n", + " --lora-checkpoint-path {lora_checkpoint_path}" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "token_logits\ttorch.Size([1024, 10, 128])\n", + "hidden_states\ttorch.Size([10, 1024, 320])\n", + "input_ids\ttorch.Size([10, 1024])\n", + "embeddings\ttorch.Size([10, 320])\n" + ] + } + ], + "source": [ + "results = torch.load(f\"{work_dir}/predictions__rank_0.pt\")\n", + "\n", + "for key, val in results.items():\n", + " if val is not None:\n", + " print(f\"{key}\\t{val.shape}\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml index ddc5033b3..108432ed1 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resources/esm2.yaml @@ -84,3 +84,11 @@ sha256: 14ae3acfbf82218bc9e3e53d21a5b0594ba7c0369e169c9f1034e3fe4378d175 # pragma: allowlist secret owner: Farhad Ramezanghorbani description: Test data for ESM2 inference. + +- tag: esm2_lora_weights:1.1 + ngc: nvidia/clara/esm2_lora_weights:1.1 + ngc_registry: model + pbss: "s3://general-purpose/esm2/checkpoints/finetuned_peft_weights.tar.gz" + sha256: ffce714478009c353f20e8c1ad8e49638128f0cc936ebe1d44c161ac89831dcb # pragma: allowlist secret + owner: Polina Binder + description: Weights for a LoRA finetuned ESM2 model. diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/peft.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/peft.py index f5091f030..b576d1e17 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/peft.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/peft.py @@ -28,6 +28,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional + +import lightning.pytorch as pl from nemo.collections.llm import fn from nemo.collections.llm.fn.mixin import FNMixin from nemo.collections.llm.peft.lora import LoRA @@ -37,13 +40,66 @@ class ESM2LoRA(LoRA): """LoRA for the BioNeMo2 ESM Model.""" + def __init__( + self, + peft_ckpt_path: Optional[str] = None, + freeze_modules: List[str] = ["encoder", "embedding"], + *args, + **kwarg, + ): + """Initialize the LoRA Adapter. + + Args: + peft_ckpt_path: config for peft chekpoint. + freeze_modules: modules to freeze. + *args: args for the LoRA class. + **kwarg: kwargs for the LoRA class. + """ + super().__init__(*args, **kwarg) + self.freeze_modules = freeze_modules + self.peft_ckpt_path = peft_ckpt_path + + def setup(self, *args, **kwarg): + """Initialize the LoRA Adapter. Pass the peft_ckpt_path to the wrapped io. + + Args: + *args: args for the LoRA class. + **kwarg: kwargs for the LoRA class. + """ + super().setup(*args, **kwarg) + self.wrapped_io.adapter_ckpt_path = self.peft_ckpt_path + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Event hook. + + Args: + trainer: The trainer object. + pl_module: The LightningModule object. + """ + self._maybe_apply_transform(trainer) + + def adapter_key_filter(self, key: str) -> bool: + """Given a key in the state dict, return whether the key is an adapter (or base model). + + Args: + key: the key to filter + """ + if isinstance(key, tuple): + return key[1].requires_grad + if "_extra_state" in key: + return False + return ( + (not any(substring in key for substring in self.freeze_modules)) + or ".adapter." in key + or key.endswith(".adapters") + ) + def __call__(self, model: nn.Module) -> nn.Module: """This method is called when the object is called as a function. Args: model: The input model. - Returns: The modified model. """ fn.walk(model, self.selective_freeze) @@ -64,6 +120,6 @@ def selective_freeze(self, m: nn.Module, name=None, prefix=None): See Also: nemo.collections.llm.fn.mixin.FNMixin """ - if name in ["encoder", "embedding"]: + if name in self.freeze_modules: FNMixin.freeze(m) return m diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py index b9c82ed25..34a44b66f 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py @@ -169,7 +169,6 @@ def __init__( pre_process=self.pre_process, post_process=self.post_process, ) - # Output if post_process: # TODO: Make sure you are passing in the mpu_vocab_size properly diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py index c390afccd..ac13fe047 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py @@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Type, get_args from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary +from megatron.core.dist_checkpointing.validation import StrictHandling from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from nemo import lightning as nl @@ -35,6 +36,7 @@ InMemoryProteinDataset, InMemorySingleValueDataset, ) +from bionemo.esm2.model.finetune.peft import ESM2LoRA from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig from bionemo.llm.model.biobert.lightning import biobert_lightning_module @@ -118,6 +120,8 @@ def train_model( grad_reduce_in_fp32: bool = False, ckpt_async_save: bool = True, label_column: str = "labels", + lora_checkpoint_path: Optional[str] = None, + lora_finetune: bool = False, ) -> Tuple[Path, Callback | None, nl.Trainer]: """Train an ESM2 model on UR data. @@ -180,6 +184,8 @@ def train_model( grad_reduce_in_fp32 (bool): gradient reduction in fp32 ckpt_async_save (bool): whether to save ckpt async. Set to False for federated learning label_column (str): name of label column in CSV data file. Defaults to `labels`. + lora_checkpoint_path (Optional[str]): path to the lora checkpoint file. + lora_finetune (bool): whether to use lora fine-tuning. """ # Create the result directory if it does not exist. result_dir.mkdir(parents=True, exist_ok=True) @@ -194,22 +200,32 @@ def train_model( pipeline_model_parallel_size=pipeline_model_parallel_size, ) + # Convert lora_checkpoint_path to string if it's a Path object + if lora_checkpoint_path is not None: + lora_checkpoint_path = str(lora_checkpoint_path) + + # Initialize LoRA adapter first if needed + peft = None + if lora_finetune: + peft = ESM2LoRA(peft_ckpt_path=lora_checkpoint_path) + strategy = nl.MegatronStrategy( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, + find_unused_parameters=True, + gradient_as_bucket_view=True, + ckpt_include_optimizer=True, + ckpt_async_save=ckpt_async_save, + ckpt_parallel_load=True, + ckpt_load_strictness=StrictHandling.LOG_UNEXPECTED, ddp=DistributedDataParallelConfig( check_for_nan_in_grad=True, overlap_grad_reduce=overlap_grad_reduce, overlap_param_gather=overlap_param_gather, average_in_collective=average_in_collective, grad_reduce_in_fp32=grad_reduce_in_fp32, - use_distributed_optimizer=True, + use_distributed_optimizer=False, ), - find_unused_parameters=True, - gradient_as_bucket_view=True, - ckpt_include_optimizer=True, - ckpt_async_save=ckpt_async_save, - ckpt_parallel_load=True, ) # for wandb integration @@ -244,6 +260,8 @@ def train_model( start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True ) ) + if peft is not None: + callbacks.append(peft) trainer = nl.Trainer( devices=devices, @@ -342,7 +360,12 @@ def train_model( optimizer.scale_lr_cond = lambda name, param: scale_lr_layer in name optimizer.lr_mult = lr_multiplier - module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer) + if peft is not None: + module = biobert_lightning_module( + config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft + ) + else: + module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer) # Configure our custom Checkpointer checkpoint_callback = nl_callbacks.ModelCheckpoint( @@ -352,6 +375,8 @@ def train_model( every_n_train_steps=val_check_interval, always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe filename="checkpoint-{step}-{consumed_samples}", # Including step and consumed_samples in the checkpoint filename prevents duplicate filenames and bugs related to this. + save_weights_only=False, + save_optim_on_train_end=True, ) # Setup the logger and train the model @@ -362,7 +387,6 @@ def train_model( wandb_config=wandb_config, ckpt_callback=checkpoint_callback, ) - llm.train( model=module, data=data_module, @@ -373,6 +397,7 @@ def train_model( resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. ), ) + ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", "")) return ckpt_path, metric_tracker, trainer @@ -386,8 +411,10 @@ def finetune_esm2_entrypoint(): # to avoid padding for single value labels: if args.min_seq_length is not None and args.datset_class is InMemorySingleValueDataset: parser.error("Arguments --min-seq-length cannot be set when using InMemorySingleValueDataset.") + if args.lora_checkpoint_path and not args.lora_finetune: + parser.error("Arguments --lora=checkpoint-path cannot be set when not using lora-finetune.") - # 2. Call pretrain with args + # 2. Call training with args train_model( train_data_path=args.train_data_path, valid_data_path=args.valid_data_path, @@ -445,6 +472,8 @@ def finetune_esm2_entrypoint(): grad_reduce_in_fp32=args.grad_reduce_in_fp32, ckpt_async_save=not args.avoid_ckpt_async_save, label_column=args.label_column, + lora_checkpoint_path=args.lora_checkpoint_path, + lora_finetune=args.lora_finetune, ) @@ -703,6 +732,22 @@ def get_parser(): default=None, help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.", ) + + parser.add_argument( + "--lora-finetune", + action="store_true", + default=False, + help="Perform fine-tuning with LoRA.", + ) + + parser.add_argument( + "--lora-checkpoint-path", + type=str, + required=False, + default=None, + help="Path to the LoRA states to restore from.", + ) + parser.add_argument( "--nsys-profiling", action="store_true", @@ -760,6 +805,14 @@ def get_parser(): default=False, ) + parser.add_argument( + "--clip-grad", + type=float, + required=False, + default=1.0, + help="Gradient clipping based on global L2 norm. Default is 1.0", + ) + config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS def config_class_type(desc: str) -> Type[BioBertConfig]: @@ -797,6 +850,8 @@ def dataset_class_type(desc: str) -> Type[InMemoryProteinDataset]: default=InMemorySingleValueDataset, help=f"Dataset class name for finetuning. Choices: {config_class_options.keys()}", ) + parser.add_argument("--seed", type=int, default=43, help="Random seed.") + return parser diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py index 5dbb86aac..82016739a 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/infer_esm2.py @@ -16,7 +16,7 @@ import argparse import os from pathlib import Path -from typing import Dict, Sequence, Type, get_args +from typing import Dict, Optional, Sequence, Type, get_args from nemo import lightning as nl @@ -25,6 +25,7 @@ from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset +from bionemo.esm2.model.finetune.peft import ESM2LoRA from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig from bionemo.llm.model.biobert.lightning import biobert_lightning_module @@ -60,6 +61,7 @@ def infer_model( num_nodes: int = 1, prediction_interval: IntervalT = "epoch", config_class: Type[BioBertConfig] = ESM2Config, + lora_checkpoint_path: Optional[str] = None, ) -> None: """Runs inference on a BioNeMo ESM2 model using PyTorch Lightning. @@ -80,6 +82,7 @@ def infer_model( num_nodes (int, optional): Number of nodes to use for distributed inference. Defaults to 1. prediction_interval (IntervalT, optional): Intervals to write predict method output into disck for DDP inference. Defaults to epoch. config_class (Type[BioBertConfig]): The config class for configuring the model using checkpoint provided + lora_checkpoint_path (Optional[str]): path to the lora checkpoint file. """ # create the directory to save the inference results os.makedirs(results_path, exist_ok=True) @@ -98,19 +101,13 @@ def infer_model( pipeline_model_parallel_size=pipeline_model_parallel_size, ddp="megatron", find_unused_parameters=True, + ckpt_parallel_load=True, ) prediction_writer = PredictionWriter(output_dir=results_path, write_interval=prediction_interval) + callbacks = [prediction_writer] - trainer = nl.Trainer( - accelerator="gpu", - devices=devices, - strategy=strategy, - num_nodes=num_nodes, - callbacks=[prediction_writer], - plugins=nl.MegatronMixedPrecision(precision=precision), - ) - + # Setup data dataset = InMemoryProteinDataset.from_csv(data_path, ignore_labels=True) datamodule = ESM2FineTuneDataModule( predict_dataset=dataset, @@ -119,6 +116,7 @@ def infer_model( min_seq_length=min_seq_length, ) + # Setup model config = config_class( params_dtype=get_autocast_dtype(precision), pipeline_dtype=get_autocast_dtype(precision), @@ -130,15 +128,36 @@ def infer_model( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, initial_ckpt_path=str(checkpoint_path), - initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint. ) tokenizer = get_tokenizer() - module = biobert_lightning_module(config=config, tokenizer=tokenizer) - # datamodule is responsible for transforming dataloaders by adding MegatronDataSampler. Alternatively, to - # directly use dataloader in predict method, the data sampler should be included in MegatronStrategy - trainer.predict(module, datamodule=datamodule) # return_predictions=False failing due to a lightning bug + # Initialize LoRA adapter if needed + # Initialize base model with or without LoRA + + if lora_checkpoint_path: + peft = ESM2LoRA(peft_ckpt_path=lora_checkpoint_path) + callbacks.append(peft) + module = biobert_lightning_module(config=config, tokenizer=tokenizer, model_transform=peft) + module.configure_init_model_parallel = True + else: + module = biobert_lightning_module(config=config, tokenizer=tokenizer) + # In this case, the weights of the heads will be in the fine-tuned files and should be read + # from there as opposed to the base model checkpoint. + config_class.initial_ckpt_skip_keys_with_these_prefixes = [] + + trainer = nl.Trainer( + accelerator="gpu", + devices=devices, + strategy=strategy, + num_nodes=num_nodes, + callbacks=callbacks, + plugins=nl.MegatronMixedPrecision(precision=precision), + max_steps=100, + ) + + # Run prediction + trainer.predict(module, datamodule=datamodule) def infer_esm2_entrypoint(): @@ -162,6 +181,7 @@ def infer_esm2_entrypoint(): devices=args.num_gpus, num_nodes=args.num_nodes, config_class=args.config_class, + lora_checkpoint_path=args.lora_checkpoint_path, ) @@ -274,6 +294,14 @@ def config_class_type(desc: str) -> Type[BioBertConfig]: "and alternative loss. In the future this script should also provide similar support for picking different data " f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}", ) + parser.add_argument( + "--lora-checkpoint-path", + type=Path, + required=False, + default=None, + help="Path to the lora states to restore from.", + ) + return parser diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py index 7e7dfb1b2..4cfa1e179 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py @@ -14,12 +14,30 @@ # limitations under the License. +import pandas as pd import pytest +import torch from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.testing.data.esm2 import create_mock_parquet_train_val_inputs, create_mock_protein_dataset +@pytest.fixture +def data_to_csv(): + """Create a mock protein dataset.""" + + def _data_to_csv(data, path): + csv_file = path / "protein_dataset.csv" + # Create a DataFrame + df = pd.DataFrame(data, columns=["sequences", "labels"]) + + # Save the DataFrame to a CSV file + df.to_csv(csv_file, index=False) + return csv_file + + return _data_to_csv + + @pytest.fixture def tokenizer(): """Return the ESM2 tokenizer.""" @@ -109,3 +127,41 @@ def dummy_protein_sequences(dummy_data_per_token_classification_ft): """ data = [seq for seq, _ in dummy_data_per_token_classification_ft] return data + + +@pytest.fixture +def load_dcp(): + """Fixture to load distributed checkpoints. + + Returns: + Callable: A function that takes a checkpoint directory path and returns the loaded state dict. + """ + if not torch.cuda.is_available(): + pytest.skip("Distributed checkpoint loading requires CUDA") + + def _load_dcp(ckpt_dir): + from pathlib import Path + + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint import FileSystemReader + + if not isinstance(ckpt_dir, Path): + ckpt_dir = Path(ckpt_dir) + fs_reader = FileSystemReader(ckpt_dir) + metadata = fs_reader.read_metadata() + + # Create tensors directly on GPU + state_dict = { + k: torch.empty(tp.size, dtype=tp.properties.dtype, device="cuda") + for k, tp in metadata.state_dict_metadata.items() + if type(tp).__name__ == "TensorStorageMetadata" + and not any(keyword in k for keyword in {"head", "adapter", "optimizer", "output"}) + } + + dcp.load( + state_dict, + storage_reader=fs_reader, + ) + return state_dict + + return _load_dcp diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py index 488178d88..d48355bb7 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_finetune_esm2.py @@ -17,7 +17,6 @@ from pathlib import Path from unittest.mock import patch -import pandas as pd import pytest from nemo.lightning import io @@ -30,22 +29,16 @@ from bionemo.testing.callbacks import MetricTracker -def data_to_csv(data, tmp_path): - """Create a mock protein dataset.""" - csv_file = tmp_path / "protein_dataset.csv" - # Create a DataFrame - df = pd.DataFrame(data, columns=["sequences", "labels"]) - - # Save the DataFrame to a CSV file - df.to_csv(csv_file, index=False) - return csv_file - - +@pytest.mark.needs_gpu @pytest.mark.parametrize("encoder_frozen", [True, False]) +@pytest.mark.parametrize("with_peft", [True, False]) def test_esm2_finetune_token_classifier( tmp_path, dummy_data_per_token_classification_ft, encoder_frozen, + with_peft, + load_dcp, + data_to_csv, n_steps_train: int = 50, seed: int = 42, ): @@ -77,6 +70,7 @@ def test_esm2_finetune_token_classifier( dataset_class=InMemoryPerTokenValueDataset, config_class=ESM2FineTuneTokenConfig, metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), + lora_finetune=with_peft, ) weights_ckpt = simple_ft_checkpoint / "weights" @@ -86,20 +80,36 @@ def test_esm2_finetune_token_classifier( assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1] assert "val_acc" in trainer.logged_metrics # assert trainer.logged_metrics["val_acc"].item() <= 0.5 # TODO @farhad for a reasonable value - encoder_requires_grad = [ p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name ] - assert not all(encoder_requires_grad) == encoder_frozen, ( - f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}" - ) + if with_peft: + assert trainer.model.model_transform is not None + model = trainer.model[0].module.module.module + assert all(not p.requires_grad for p in model.embedding.parameters()) + assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name) + assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name) + assert all(p.requires_grad for p in model.classification_head.parameters()) + weight_param_dict = load_dcp(weights_ckpt) + for param in weight_param_dict.keys(): + assert any(keyword in param for keyword in {"head", "adapter", "optimizer", "output"}) + else: + assert not all(encoder_requires_grad) == encoder_frozen, ( + f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}" + ) + +@pytest.mark.needs_gpu @pytest.mark.parametrize("encoder_frozen", [True, False]) +@pytest.mark.parametrize("with_peft", [True, False]) def test_esm2_finetune_regressor( tmp_path, dummy_data_single_value_regression_ft, encoder_frozen, + with_peft, + load_dcp, + data_to_csv, n_steps_train: int = 50, seed: int = 42, ): @@ -131,6 +141,7 @@ def test_esm2_finetune_regressor( dataset_class=InMemorySingleValueDataset, config_class=ESM2FineTuneSeqConfig, metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), + lora_finetune=with_peft, ) weights_ckpt = simple_ft_checkpoint / "weights" @@ -141,19 +152,37 @@ def test_esm2_finetune_regressor( assert "val_mse" in trainer.logged_metrics # assert trainer.logged_metrics["val_mse"].item() <= 0.5 # TODO @farhadrgh for a reasonable value - encoder_requires_grad = [ - p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name - ] - assert not all(encoder_requires_grad) == encoder_frozen, ( - f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}" - ) + if with_peft: + assert trainer.model.model_transform is not None + model = trainer.model[0].module.module.module + assert all(not p.requires_grad for p in model.embedding.parameters()) + assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name) + assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name) + assert all(p.requires_grad for p in model.regression_head.parameters()) + + weight_param_dict = load_dcp(weights_ckpt) + for param in weight_param_dict.keys(): + assert any(keyword in param for keyword in {"head", "adapter", "optimizer", "output"}) + + else: + encoder_requires_grad = [ + p.requires_grad for name, p in trainer.model.named_parameters() if "regression_head" not in name + ] + assert not all(encoder_requires_grad) == encoder_frozen, ( + f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}" + ) +@pytest.mark.needs_gpu @pytest.mark.parametrize("encoder_frozen", [True, False]) +@pytest.mark.parametrize("with_peft", [True, False]) def test_esm2_finetune_classifier( tmp_path, dummy_data_single_value_classification_ft, encoder_frozen, + with_peft, + load_dcp, + data_to_csv, n_steps_train: int = 50, seed: int = 42, ): @@ -186,6 +215,7 @@ def test_esm2_finetune_classifier( dataset_class=InMemorySingleValueDataset, config_class=ESM2FineTuneSeqConfig, metric_tracker=MetricTracker(metrics_to_track_val=["loss"], metrics_to_track_train=["loss"]), + lora_finetune=with_peft, ) weights_ckpt = simple_ft_checkpoint / "weights" @@ -196,12 +226,26 @@ def test_esm2_finetune_classifier( assert "val_acc" in trainer.logged_metrics # assert trainer.logged_metrics["val_acc"].item() <= 0.5 # TODO @farhadrgh for a reasonable value - encoder_requires_grad = [ - p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name - ] - assert not all(encoder_requires_grad) == encoder_frozen, ( - f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}" - ) + if with_peft: + assert trainer.model.model_transform is not None + model = trainer.model[0].module.module.module + assert all(not p.requires_grad for p in model.embedding.parameters()) + assert all(not p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" not in name) + assert all(p.requires_grad for name, p in model.encoder.named_parameters() if "adapter" in name) + assert all(p.requires_grad for p in model.classification_head.parameters()) + + weight_param_dict = load_dcp(weights_ckpt) + for param in weight_param_dict.keys(): + assert any(keyword in param for keyword in {"head", "adapter", "optimizer", "output"}) + + else: + encoder_requires_grad = [ + p.requires_grad for name, p in trainer.model.named_parameters() if "classification_head" not in name + ] + + assert not all(encoder_requires_grad) == encoder_frozen, ( + f"Conflict in param requires_grad when encoder_frozen={encoder_frozen}" + ) @pytest.fixture @@ -392,3 +436,15 @@ def test_get_parser(): assert args.encoder_frozen is True assert args.lr_multiplier == 100 assert args.scale_lr_layer == "dummy_layer" + + +def r_data_to_csv(data, path): + import pandas as pd + + csv_file = path / "protein_dataset.csv" + # Create a DataFrame + df = pd.DataFrame(data, columns=["sequences", "labels"]) + + # Save the DataFrame to a CSV file + df.to_csv(csv_file, index=False) + return csv_file diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py index b3c349b8f..1a7974a20 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/scripts/test_infer_esm2.py @@ -24,6 +24,7 @@ from bionemo.core.utils.dtypes import get_autocast_dtype from bionemo.esm2.api import ESM2Config from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig from bionemo.esm2.scripts.infer_esm2 import infer_model from bionemo.llm.data import collate from bionemo.llm.lightning import batch_collator @@ -53,23 +54,32 @@ def padded_tokenized_sequences(dummy_protein_sequences): return collated_batch["text"] -@pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"]) +@pytest.mark.needs_gpu @pytest.mark.parametrize("prediction_interval", get_args(IntervalT)) +@pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"]) +@pytest.mark.parametrize("with_peft", [True, False]) +@pytest.mark.parametrize("config_class", [ESM2Config, ESM2FineTuneTokenConfig]) def test_infer_runs( tmpdir, dummy_protein_csv, dummy_protein_sequences, precision, prediction_interval, + with_peft, padded_tokenized_sequences, + config_class, ): + checkpoint_path = load("esm2/8m:2.0") data_path = dummy_protein_csv result_dir = tmpdir / "results" min_seq_len = 1024 # Minimum length of the output batch; tensors will be padded to this length. - + if with_peft: + lora_checkpoint_path = load("esm2/esm2_lora_weights:1.1") / "weights" + else: + lora_checkpoint_path = None infer_model( data_path=data_path, - checkpoint_path=load("esm2/8m:2.0"), + checkpoint_path=checkpoint_path, results_path=result_dir, min_seq_length=min_seq_len, prediction_interval=prediction_interval, @@ -79,7 +89,8 @@ def test_infer_runs( include_input_ids=True, include_logits=True, micro_batch_size=3, # dataset length (10) is not multiple of 3; this validates partial batch inference - config_class=ESM2Config, + config_class=config_class, + lora_checkpoint_path=lora_checkpoint_path, ) assert result_dir.exists(), "Could not find test results directory." @@ -108,3 +119,75 @@ def test_infer_runs( # for accurate mapping post-inference. if prediction_interval == "epoch": assert torch.equal(padded_tokenized_sequences, results["input_ids"]) + + +@pytest.mark.needs_gpu +@pytest.mark.parametrize("prediction_interval", get_args(IntervalT)) +@pytest.mark.parametrize("precision", ["fp32", "bf16-mixed"]) +def test_different_results_with_peft( + tmpdir, + dummy_protein_csv, + dummy_protein_sequences, + precision, + prediction_interval, + padded_tokenized_sequences, +): + checkpoint_path = load("esm2/8m:2.0") + data_path = dummy_protein_csv + result_dir_original = tmpdir / "results_original" + min_seq_len = 1024 # Minimum length of the output batch; tensors will be padded to this length. + lora_checkpoint_path = None + infer_model( + data_path=data_path, + checkpoint_path=checkpoint_path, + results_path=result_dir_original, + min_seq_length=min_seq_len, + prediction_interval=prediction_interval, + include_hiddens=True, + precision=precision, + include_embeddings=True, + include_input_ids=True, + include_logits=True, + micro_batch_size=3, # dataset length (10) is not multiple of 3; this validates partial batch inference + config_class=ESM2Config, + lora_checkpoint_path=lora_checkpoint_path, + ) + assert result_dir_original.exists(), "Could not find test results directory." + result_dir_peft = tmpdir / "results_peft" + lora_checkpoint_path = load("esm2/esm2_lora_weights:1.1") / "weights" + infer_model( + data_path=data_path, + checkpoint_path=checkpoint_path, + results_path=result_dir_peft, + min_seq_length=min_seq_len, + prediction_interval=prediction_interval, + include_hiddens=True, + precision=precision, + include_embeddings=True, + include_input_ids=True, + include_logits=True, + micro_batch_size=3, # dataset length (10) is not multiple of 3; this validates partial batch inference + config_class=ESM2Config, + lora_checkpoint_path=lora_checkpoint_path, + ) + + if prediction_interval == "epoch": + results_original = torch.load(f"{result_dir_original}/predictions__rank_0.pt") + results_peft = torch.load(f"{result_dir_peft}/predictions__rank_0.pt") + + elif prediction_interval == "batch": + results_original = batch_collator( + [ + torch.load(f, map_location="cpu") + for f in glob.glob(f"{result_dir_original}/predictions__rank_0__batch_*.pt") + ] + ) + results_peft = batch_collator( + [ + torch.load(f, map_location="cpu") + for f in glob.glob(f"{result_dir_peft}/predictions__rank_0__batch_*.pt") + ] + ) + assert (results_original["embeddings"] != results_peft["embeddings"]).any() + assert (results_original["hidden_states"] != results_peft["hidden_states"]).any() + assert (results_original["token_logits"] != results_peft["token_logits"]).any() diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py index 88196a7af..fb259262a 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py @@ -258,6 +258,7 @@ def __init__( data_step: DataStep, optimizer: MegatronOptimizerModule, model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None, + configure_init_model_parallel: bool = False, **model_construct_args, ) -> None: """Constructor. @@ -273,6 +274,7 @@ def __init__( model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s `configure_model` method. model_transform: Optional. The model transform function. + configure_init_model_parallel: Optional. Whether to initialize the model parallel at configuration time. **model_construct_args: Optional. Arguments necessary for the supplied model configuration's `configure_model` method, which will make an instance of the model. """ @@ -288,7 +290,7 @@ def __init__( self._data_step = data_step self._forward_step = forward_step self.model_transform = model_transform - + self.configure_init_model_parallel = configure_init_model_parallel # configure metrics self.train_metric = self.config.train_metric.get_instance() if self.config.train_metric else None self.valid_metric = self.config.valid_metric.get_instance() if self.config.valid_metric else None @@ -301,6 +303,8 @@ def configure_model(self) -> None: Raises: ValueError iff the internal config's configure_model method returns None. """ + if self.configure_init_model_parallel: + self.trainer.strategy._init_model_parallel = True if self.module is None: model: MegatronModelType = ( self.config.configure_model(**self.module_construct_args) @@ -308,7 +312,6 @@ def configure_model(self) -> None: else self.config.configure_model() ) self.module = model - if self.module is None: raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.") diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py index f14f44fc7..c6687bffd 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py @@ -554,7 +554,6 @@ def configure_model(self, tokenizer: AutoTokenizer) -> MegatronBioBertModelType: if self.initial_ckpt_path: self.load_settings_from_checkpoint(self.initial_ckpt_path) - model = self.model_cls( self, transformer_layer_spec=get_biobert_spec(