diff --git a/bionemo-recipes/recipes/evo2_megatron/examples/.gitignore b/bionemo-recipes/recipes/evo2_megatron/examples/.gitignore index 055a26eb49..132f4bd7a4 100644 --- a/bionemo-recipes/recipes/evo2_megatron/examples/.gitignore +++ b/bionemo-recipes/recipes/evo2_megatron/examples/.gitignore @@ -8,6 +8,7 @@ # directories created during these notebook runs. nemo2_evo2_1b_8k/ +evo2_1b_bf16_mbridge/ preprocessed_data/ pretraining_demo/ brca1_fasta_files/ diff --git a/bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb b/bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb index 4b65525ce0..b3f5a5cf72 100644 --- a/bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb +++ b/bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "jupyter": { "source_hidden": true @@ -61,6 +61,7 @@ " !rm -rf preprocessed_data\n", " !rm -rf preatraining_demo\n", " !rm -rf pretraining_demo\n", + " !rm -rf evo2_1b_bf16_mbridge\n", " !rm -rf training_data_config.yaml\n", " !rm -rf preprocess_config.yaml\n", " !rm -f chr20.fa.gz\n", @@ -114,10 +115,40 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:65: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n", + "/workspaces/bionemo-framework/bionemo-recipes/recipes/evo2_megatron/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden.\n", + " Overriding a previously registered kernel for the same operator and the same dispatch key\n", + " operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor\n", + " registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922\n", + " dispatch key: ADInplaceOrView\n", + " previous kernel: no debug info\n", + " new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)\n", + " self.m.impl(\n", + "/usr/bin/ld: cannot find -laio: No such file or directory\n", + "collect2: error: ld returned 1 exit status\n", + "/usr/bin/ld: cannot find -laio: No such file or directory\n", + "collect2: error: ld returned 1 exit status\n" + ] + } + ], "source": [ + "from bionemo.evo2.data.dataset_tokenizer import (\n", + " DEFAULT_HF_TOKENIZER_MODEL_PATH, # use the 512 size for historical reasons\n", + ")\n", + "\n", + "\n", "full_fasta_path = os.path.abspath(concat_path)\n", "output_dir = os.path.abspath(\"preprocessed_data\")\n", + "\n", + "\n", "output_yaml = f\"\"\"\n", "- datapaths: [\"{full_fasta_path}\"]\n", " output_dir: \"{output_dir}\"\n", @@ -133,10 +164,7 @@ " transcribe: \"back_transcribe\"\n", " force_uppercase: false\n", " indexed_dataset_dtype: \"uint8\"\n", - " tokenizer_type: \"Byte-Level\"\n", - " vocab_file: null\n", - " vocab_size: null\n", - " merges_file: null\n", + " hf_tokenizer_model_path: {DEFAULT_HF_TOKENIZER_MODEL_PATH}\n", " pretrained_tokenizer_model: null\n", " special_tokens: null\n", " fast_hf_tokenizer: true\n", @@ -174,15 +202,15 @@ "output_type": "stream", "text": [ "total 309M\n", - "drwxr-xr-x 3 ubuntu ubuntu 4.0K Mar 10 22:17 chr20_21_22_uint8_distinct_byte-level_test\n", - "-rw-r--r-- 1 ubuntu ubuntu 90M Mar 10 23:07 chr20_21_22_uint8_distinct_byte-level_test.bin\n", - "-rw-r--r-- 1 ubuntu ubuntu 82 Mar 10 23:07 chr20_21_22_uint8_distinct_byte-level_test.idx\n", - "drwxr-xr-x 3 ubuntu ubuntu 4.0K Mar 10 22:17 chr20_21_22_uint8_distinct_byte-level_train\n", - "-rw-r--r-- 1 ubuntu ubuntu 123M Mar 10 23:06 chr20_21_22_uint8_distinct_byte-level_train.bin\n", - "-rw-r--r-- 1 ubuntu ubuntu 82 Mar 10 23:07 chr20_21_22_uint8_distinct_byte-level_train.idx\n", - "drwxr-xr-x 3 ubuntu ubuntu 4.0K Mar 10 22:17 chr20_21_22_uint8_distinct_byte-level_val\n", - "-rw-r--r-- 1 ubuntu ubuntu 97M Mar 10 23:06 chr20_21_22_uint8_distinct_byte-level_val.bin\n", - "-rw-r--r-- 1 ubuntu ubuntu 82 Mar 10 23:07 chr20_21_22_uint8_distinct_byte-level_val.idx\n" + "drwxr-xr-x 3 ubuntu ubuntu 4.0K Jan 14 18:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_test\n", + "-rw-r--r-- 1 ubuntu ubuntu 90M Jan 15 00:46 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_test.bin\n", + "-rw-r--r-- 1 ubuntu ubuntu 82 Jan 15 00:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_test.idx\n", + "drwxr-xr-x 3 ubuntu ubuntu 4.0K Jan 14 18:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_train\n", + "-rw-r--r-- 1 ubuntu ubuntu 123M Jan 15 00:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_train.bin\n", + "-rw-r--r-- 1 ubuntu ubuntu 82 Jan 15 00:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_train.idx\n", + "drwxr-xr-x 3 ubuntu ubuntu 4.0K Jan 14 18:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_val\n", + "-rw-r--r-- 1 ubuntu ubuntu 97M Jan 15 00:44 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_val.bin\n", + "-rw-r--r-- 1 ubuntu ubuntu 82 Jan 15 00:48 chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256_val.idx\n" ] } ], @@ -197,10 +225,14 @@ "source": [ "### [Optional] specify or convert initial checkpoint\n", "The main difference between pre-training and fine-tuning is whether or not you decide to start training the model with\n", - "weights from a prior training run. For this tutorial we want to tune a `1b` checkpoint from hugging face that is known\n", - "(at the time of this writing) to be sensitive to GPU architecture so that it will work with your architecture. We have a\n", - "script that will download and convert a savanna format evo2 checkpoint from hugging face, and output that into a NeMo2\n", - "format checkpoint directory that can be used as the starting point for a fine-tuning run." + "weights from a prior training run. For this tutorial we want to tune a `1b` checkpoint that is known\n", + "(at the time of this writing) to be sensitive to GPU architecture so that it will work with your architecture.\n", + "\n", + "We use `bionemo.core.data.load` to download pre-trained NeMo2 checkpoints from NGC. These checkpoints can be used\n", + "directly with the `train_evo2` command for fine-tuning.\n", + "\n", + "**Note**: The `train_evo2` command produces MBridge format checkpoints that can be used directly with\n", + "`infer_evo2` or `predict_evo2` for inference - no conversion step is needed." ] }, { @@ -209,11 +241,33 @@ "metadata": {}, "outputs": [], "source": [ - "%%capture\n", - "if not os.path.exists(\"nemo2_evo2_1b_8k\"):\n", - " !evo2_convert_to_nemo2 \\\n", - " --model-path hf://arcinstitute/savanna_evo2_1b_base \\\n", - " --model-size 1b --output-dir nemo2_evo2_1b_8k" + "from pathlib import Path\n", + "\n", + "from bionemo.core.data.load import load\n", + "from bionemo.evo2.data.dataset_tokenizer import (\n", + " DEFAULT_HF_TOKENIZER_MODEL_PATH_512, # use the 512 size for historical reasons\n", + ")\n", + "\n", + "\n", + "# Download the 1b BF16 checkpoint from NGC\n", + "# Available checkpoints: evo2/1b-8k-bf16:1.0, evo2/1b-8k:1.0, evo2/7b-8k:1.0, evo2/7b-1m:1.0\n", + "mbridge_ckpt_path = Path(\"evo2_1b_bf16_mbridge\")\n", + "\n", + "if not mbridge_ckpt_path.exists():\n", + " nemo2_ckpt_path = load(\"evo2/1b-8k-bf16:1.0\")\n", + " mixed_precision_recipe = \"bf16_mixed\" # also try bf16_with_fp8_current_scaling_mixed\n", + " convert_ckpt_cmd = f\"\"\"evo2_convert_nemo2_to_mbridge \\\n", + " --nemo2-ckpt-dir {nemo2_ckpt_path} \\\n", + " --mbridge-ckpt-dir {mbridge_ckpt_path} \\\n", + " --model-size 1b \\\n", + " --mixed-precision-recipe {mixed_precision_recipe} \\\n", + " --seq-length 8192 \\\n", + " --tokenizer-path {DEFAULT_HF_TOKENIZER_MODEL_PATH_512} \\\n", + " \"\"\".rstrip()\n", + " print(f\"Running command: {convert_ckpt_cmd}\")\n", + "\n", + " result = run_subprocess_safely(convert_ckpt_cmd)\n", + " print(f\"Downloaded checkpoint to: {nemo2_ckpt_path} and converted to mbridge format at {mbridge_ckpt_path}\")" ] }, { @@ -227,14 +281,16 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "\n", - "output_pfx = str(Path(os.path.abspath(\"preprocessed_data\")) / \"chr20_21_22_uint8_distinct_byte-level\")\n", + "output_pfx = str(\n", + " Path(os.path.abspath(\"preprocessed_data\")) / \"chr20_21_22_uint8_distinct_nucleotide_fast_tokenizer_256\"\n", + ")\n", "output_yaml = f\"\"\"\n", "- dataset_prefix: {output_pfx}_train\n", " dataset_split: train\n", @@ -260,11 +316,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "%%capture\n", + "import torch\n", + "\n", + "\n", "MAX_STEPS: int = 10 if FAST_CI_MODE else 100\n", "val_check_interval = min(int(MAX_STEPS // 2), 50)\n", "warmup_steps = min(MAX_STEPS, 100)\n", @@ -278,30 +337,32 @@ "else:\n", " # By default do 5 layers of activation checkpointing\n", " model_subset_option = \"--activation-checkpoint-recompute-num-layers 5\"\n", - "train_cmd = f\"\"\"train_evo2 \\\n", + "num_gpus = torch.cuda.device_count()\n", + "# The 1b model is configured in a way that you can not use TP, but you can use CP.\n", + "train_cmd = f\"\"\"torchrun --nproc_per_node={num_gpus} --no-python train_evo2 \\\n", " -d training_data_config.yaml \\\n", " --dataset-dir ./preprocessed_data \\\n", " --result-dir pretraining_demo \\\n", " --experiment-name evo2 \\\n", + " --context-parallel-size {num_gpus} \\\n", " --model-size 1b \\\n", - " --devices 1 \\\n", - " --num-nodes 1 \\\n", " --seq-length 8192 \\\n", - " --micro-batch-size 2 \\\n", + " --micro-batch-size 1 \\\n", + " --global-batch-size 8 \\\n", + " --eval-iters 5 \\\n", + " --decay-steps 100000 --warmup-steps 10 \\\n", + " --hf-tokenizer-model-path {DEFAULT_HF_TOKENIZER_MODEL_PATH_512} \\\n", + " --eval-interval {val_check_interval} \\\n", " --lr 0.000015 \\\n", " --min-lr 0.0000149 \\\n", " --warmup-steps {warmup_steps} \\\n", - " --grad-acc-batches 4 \\\n", " --max-steps {MAX_STEPS} \\\n", - " --ckpt-dir nemo2_evo2_1b_8k \\\n", + " --finetune-ckpt-dir {mbridge_ckpt_path} \\\n", " --clip-grad 250 \\\n", " --wd 0.001 \\\n", " --attention-dropout 0.01 \\\n", " --hidden-dropout 0.01 \\\n", - " --val-check-interval {val_check_interval} \\\n", - " {model_subset_option} \\\n", - " --create-tensorboard-logger \\\n", - " --ckpt-async-save\"\"\"\n", + " {model_subset_option}\"\"\"\n", "\n", "print(f\"Running command: {train_cmd}\")\n", "\n", @@ -310,11 +371,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "assert result[\"returncode\"] == 0, result" + "if result[\"returncode\"] != 0:\n", + " print(\"================== STDOUT ==========================\")\n", + " print(result[\"stdout\"])\n", + " print(\"================== STDERR ==========================\")\n", + " print(result[\"stderr\"])\n", + " raise AssertionError(\"Training failed. See stdout and stderr above for more details.\")" ] }, { @@ -326,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 17, "metadata": { "jupyter": { "source_hidden": true @@ -422,12 +488,404 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n", + "/tmp/ipykernel_3593112/2519380624.py:41: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n", + " df[tag] = df[\"step\"].map(step_to_value)\n" + ] + }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -438,25 +896,34 @@ ], "source": [ "# Get the TensorBoard event file for the training run\n", - "log_dirs = !find pretraining_demo/evo2/dev -name \"events.out.tfevents*\"\n", + "log_dirs = !find pretraining_demo/evo2/tb_logs -name \"events.out.tfevents*\" | sort | tail -1\n", "tf_event_file = log_dirs[0]\n", "\n", "# Extract data from your event file\n", "df = tensorboard_to_dataframe(tf_event_file)\n", "# You can uncomment and modify this to plot multiple metrics once you see what's available\n", - "plot_multiple_training_metrics(df, [\"reduced_train_loss\", \"lr\", \"grad_norm\", \"val_loss\"])" + "plot_multiple_training_metrics(df, [\"lm loss\", \"learning-rate\", \"grad-norm\", \"lm loss validation\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now you have a checkpoint that you can try out in place of the converted evo2 checkpoint in the BRCA-1 tutorial \n", - "(the path is displayed in the next code cell). To test your checkpoint, please supply the following path to the saved \n", - "checkpoint produced by this notebook as the `--ckpt-dir {checkpoint_path}`\n", - "argument to the `predict_evo2` command in the zero shot BRCA tutorial. For the 1b checkpoint you should see AUC above\n", - "0.73 if you successfully fine-tuned the checkpoint for your hardware, or to check that your hardware works with the \n", - "converted checkpoint from hugging face as is.\n", + "Now you have a fine-tuned checkpoint that you can use for inference. The `train_evo2` command produces\n", + "MBridge format checkpoints that work directly with the inference scripts.\n", + "\n", + "**Option 1: Use `predict_evo2` for log-probability scoring**\n", + "```bash\n", + "predict_evo2 --ckpt-dir --input-fasta sequences.fa --output-dir results/\n", + "```\n", + "\n", + "**Option 2: Use `infer_evo2` for text generation**\n", + "```bash\n", + "infer_evo2 --ckpt-dir --prompt \"ATCGATCG\" --max-new-tokens 100\n", + "```\n", + "\n", + "The checkpoint directory path is displayed in the next cell. You can also use this checkpoint for further\n", + "fine-tuning by passing it to `--ckpt-dir` in another `train_evo2` run.\n", "\n", "In our experience running this notebook for up to an hour on a single GPU is not sufficient to recover BF16 accuracy. We\n", "have more details about what did work in the Next Steps section below." @@ -464,24 +931,40 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'pretraining_demo/default--val_loss=0.8664-epoch=0-consumed_samples=800.0-last'" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Checkpoint directory: /workspaces/bionemo-framework/bionemo-recipes/recipes/evo2_megatron/examples/pretraining_demo/evo2/checkpoints\n", + "\n", + "Available checkpoints: ['iter_0000050', 'iter_0000100']\n", + "\n", + "You can now run inference with:\n", + " infer_evo2 --ckpt-dir pretraining_demo/evo2/checkpoints --prompt 'ATCGATCG' --max-new-tokens 100\n", + " predict_evo2 --ckpt-dir pretraining_demo/evo2/checkpoints --input-fasta --output-dir \n" + ] } ], "source": [ - "final_ckpt_paths = !ls -d pretraining_demo/evo2/checkpoints/*-last\n", - "final_ckpt_path = final_ckpt_paths[-1]\n", - "final_ckpt_path" + "from pathlib import Path\n", + "\n", + "\n", + "# The checkpoint directory contains all saved iterations\n", + "# The inference scripts automatically find the latest iteration\n", + "ckpt_dir = Path(\"pretraining_demo/evo2/checkpoints\")\n", + "print(f\"Checkpoint directory: {ckpt_dir.absolute()}\")\n", + "\n", + "# List available checkpoints\n", + "if ckpt_dir.exists():\n", + " checkpoints = list(ckpt_dir.glob(\"iter_*\"))\n", + " print(f\"\\nAvailable checkpoints: {[c.name for c in sorted(checkpoints)]}\")\n", + "\n", + "print(\"\\nYou can now run inference with:\")\n", + "print(f\" infer_evo2 --ckpt-dir {ckpt_dir} --prompt 'ATCGATCG' --max-new-tokens 100\")\n", + "print(f\" predict_evo2 --ckpt-dir {ckpt_dir} --input-fasta --output-dir \")" ] }, { @@ -499,7 +982,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 20, "metadata": { "jupyter": { "source_hidden": true @@ -665,7 +1148,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/bionemo-recipes/recipes/evo2_megatron/examples/zeroshot_brca1.ipynb b/bionemo-recipes/recipes/evo2_megatron/examples/zeroshot_brca1.ipynb index 0075a0a8ef..979b8a2589 100644 --- a/bionemo-recipes/recipes/evo2_megatron/examples/zeroshot_brca1.ipynb +++ b/bionemo-recipes/recipes/evo2_megatron/examples/zeroshot_brca1.ipynb @@ -1,1245 +1,1273 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Zero-shot prediction of BRCA1 variant effects with Evo 2\n", - "Deploy this tutorial on brev.dev: \n", - "[![ Click here to deploy.](https://brev-assets.s3.us-west-1.amazonaws.com/nv-lb-dark.svg)](https://console.brev.dev/launchable/deploy?launchableID=env-2uGqxNLgVdl752F2qcTFwHHn4Rj)\n", - "\n", - "*Note - this notebook is a reproduction of The Arc Institute’s same-titled notebook [here](https://github.com/ArcInstitute/evo2/blob/main/notebooks/brca1/brca1_zero_shot_vep.ipynb), using the BioNeMo 2 implementation of Evo2.*\n", - "\n", - "Evo2 is a foundation AI model trained on 9.3 trillion DNA base pairs, predicting variant effects without prior tast-specific training. \n", - "\n", - "Without being explicitly trained on BRCA1 variants, we show Evo 2's ability to generalize across all life forms.\n", - "\n", - "The human *BRCA1* gene encodes for a protein that repairs damaged DNA ([Moynahan et al., 1999](https://www.cell.com/molecular-cell/fulltext/S1097-2765%2800%2980202-6)). Certain variants of this gene have been associated with an increased risk of breast and ovarian cancers ([Miki et al., 1994](https://www.science.org/doi/10.1126/science.7545954?url_ver=Z39.88-2003&rfr_id=ori:rid:crossref.org&rfr_dat=cr_pub%20%200pubmed)). Using Evo 2, we can predict whether a particular single nucleotide variant (SNV) of the *BRCA1* gene is likely to be harmful to the protein's function, and thus potentially increase the risk of cancer for the patient with the genetic variant." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "!pip install biopython openpyxl\n", - "import os\n", - "\n", - "\n", - "# Runs a subset of the model layers to test that the notebook runs in CI, but the output will be incorrect.\n", - "FAST_CI_MODE: bool = bool(int(os.environ.get(\"FAST_CI_MODE\", \"0\")))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import glob\n", - "import gzip\n", - "import json\n", - "import math\n", - "import os\n", - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import torch\n", - "from Bio import SeqIO\n", - "from sklearn.metrics import auc, roc_auc_score, roc_curve\n", - "\n", - "from bionemo.core.utils.subprocess_utils import run_subprocess_safely" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We start by loading a dataset from [Findlay et al. (2018)](https://www.nature.com/articles/s41586-018-0461-z), which contains experimentally measured function scores of 3,893 *BRCA1* SNVs. These function scores reflect the extent by which the genetic variant has disrupted the protein's function, with lower scores indicating greater disruption. In this dataset, the SNVs are classified into three categories based on their function scores: `LOF` (loss-of-function), `INT` (intermediate), and `FUNC` (functional). We start by reading in this dataset.\n", - "\n", - "To keep the notebook streamlined, we've abstracted much of the preprocessing logic into accompanying scripts located in `brca1_utils`. This notebook can also be viewed [here](https://docs.nvidia.com/bionemo-framework/latest/main/examples/bionemo-evo2/zeroshot_brca1/)." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "jupyter": { - "source_hidden": true + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Zero-shot prediction of BRCA1 variant effects with Evo 2\n", + "Deploy this tutorial on brev.dev: \n", + "[![ Click here to deploy.](https://brev-assets.s3.us-west-1.amazonaws.com/nv-lb-dark.svg)](https://console.brev.dev/launchable/deploy?launchableID=env-2uGqxNLgVdl752F2qcTFwHHn4Rj)\n", + "\n", + "*Note - this notebook is a reproduction of The Arc Institute’s same-titled notebook [here](https://github.com/ArcInstitute/evo2/blob/main/notebooks/brca1/brca1_zero_shot_vep.ipynb), using the BioNeMo 2 implementation of Evo2.*\n", + "\n", + "Evo2 is a foundation AI model trained on 9.3 trillion DNA base pairs, predicting variant effects without prior tast-specific training. \n", + "\n", + "Without being explicitly trained on BRCA1 variants, we show Evo 2's ability to generalize across all life forms.\n", + "\n", + "The human *BRCA1* gene encodes for a protein that repairs damaged DNA ([Moynahan et al., 1999](https://www.cell.com/molecular-cell/fulltext/S1097-2765%2800%2980202-6)). Certain variants of this gene have been associated with an increased risk of breast and ovarian cancers ([Miki et al., 1994](https://www.science.org/doi/10.1126/science.7545954?url_ver=Z39.88-2003&rfr_id=ori:rid:crossref.org&rfr_dat=cr_pub%20%200pubmed)). Using Evo 2, we can predict whether a particular single nucleotide variant (SNV) of the *BRCA1* gene is likely to be harmful to the protein's function, and thus potentially increase the risk of cancer for the patient with the genetic variant." + ] }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "def download_data(data_dir=\"brca1\", commit_hash=\"3819474bee6c24938016614411f1fa025e542bbe\"):\n", - " \"\"\"Download required data files if they don't exist locally.\n", - "\n", - " Parameters:\n", - " -----------\n", - " data_dir : str\n", - " Directory to store downloaded files\n", - " commit_hash : str\n", - " GitHub commit hash for data version\n", - " \"\"\"\n", - " if not os.path.exists(data_dir):\n", - " os.makedirs(data_dir)\n", - "\n", - " excel_path = os.path.join(data_dir, \"41586_2018_461_MOESM3_ESM.xlsx\")\n", - " genome_path = os.path.join(data_dir, \"GRCh37.p13_chr17.fna.gz\")\n", - "\n", - " if not os.path.exists(excel_path):\n", - " os.system(\n", - " f\"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/41586_2018_461_MOESM3_ESM.xlsx -O {excel_path}\"\n", - " )\n", - "\n", - " if not os.path.exists(genome_path):\n", - " os.system(\n", - " f\"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/GRCh37.p13_chr17.fna.gz -O {genome_path}\"\n", - " )\n", - "\n", - " return excel_path, genome_path\n", - "\n", - "\n", - "def load_genome_sequence(genome_path):\n", - " \"\"\"Load genome sequence from FASTA file.\n", - "\n", - " Parameters:\n", - " -----------\n", - " genome_path : str\n", - " Path to the genome FASTA file\n", - "\n", - " Returns:\n", - " --------\n", - " str\n", - " Genome sequence string\n", - " \"\"\"\n", - " with gzip.open(genome_path, \"rt\") as handle:\n", - " for record in SeqIO.parse(handle, \"fasta\"):\n", - " return str(record.seq)\n", - "\n", - " raise ValueError(\"Failed to parse genome sequence\")\n", - "\n", - "\n", - "def load_brca1_data(excel_path):\n", - " \"\"\"Load and preprocess BRCA1 data from Excel file.\n", - "\n", - " Parameters:\n", - " -----------\n", - " excel_path : str\n", - " Path to the Excel file\n", - "\n", - " Returns:\n", - " --------\n", - " pandas.DataFrame\n", - " Processed BRCA1 dataframe\n", - " \"\"\"\n", - " # Load the dataframe\n", - " brca1_df = pd.read_excel(excel_path, header=2)\n", - "\n", - " # Select and rename columns\n", - " brca1_df = brca1_df[\n", - " [\n", - " \"chromosome\",\n", - " \"position (hg19)\",\n", - " \"reference\",\n", - " \"alt\",\n", - " \"function.score.mean\",\n", - " \"func.class\",\n", - " ]\n", - " ]\n", - "\n", - " brca1_df.rename(\n", - " columns={\n", - " \"chromosome\": \"chrom\",\n", - " \"position (hg19)\": \"pos\",\n", - " \"reference\": \"ref\",\n", - " \"alt\": \"alt\",\n", - " \"function.score.mean\": \"score\",\n", - " \"func.class\": \"class\",\n", - " },\n", - " inplace=True,\n", - " )\n", - "\n", - " # Convert to two-class system\n", - " brca1_df[\"class\"] = brca1_df[\"class\"].replace([\"FUNC\", \"INT\"], \"FUNC/INT\")\n", - "\n", - " return brca1_df" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "# Configuration parameters\n", - "DATA_DIR = \"brca1\"\n", - "SAMPLE_CONFIG = {\"sample_frac\": 0.05, \"balanced\": True, \"disable\": False, \"random_state\": 42}\n", - "\n", - "# 1. Download the necessary data files if not present\n", - "excel_path, genome_path = download_data(DATA_DIR)\n", - "seq_chr17 = load_genome_sequence(genome_path)\n", - "\n", - "# 2. Load and preprocess BRCA1 data\n", - "brca1_df = load_brca1_data(excel_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We then group the `FUNC` and `INT` classes of SNVs together into a single category (`FUNC/INT`).\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We build a function to parse the reference and variant sequences of a 8,192-bp window around the genomic position of each SNV, using the reference sequence of human chromosome 17 where *BRCA1* is located.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To make things run faster, we'll just look at a balanced sample of our data. If you want to run on the full dataset, set `disable_sample=True`" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "jupyter": { - "source_hidden": true + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install biopython openpyxl\n", + "import os\n", + "\n", + "\n", + "# Runs a subset of the model layers to test that the notebook runs in CI, but the output will be incorrect.\n", + "FAST_CI_MODE: bool = bool(int(os.environ.get(\"FAST_CI_MODE\", \"0\")))" + ] }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "def sample_data(df, sample_frac=1.0, balanced=True, disable=False, random_state=42):\n", - " \"\"\"Sample dataframe, optionally with balanced classes.\n", - "\n", - " Parameters:\n", - " -----------\n", - " df : pandas.DataFrame\n", - " Input dataframe\n", - " sample_frac : float\n", - " Fraction of data to sample\n", - " balanced : bool\n", - " Whether to balance classes\n", - " disable : bool\n", - " Whether to disable sampling\n", - " random_state : int\n", - " Random seed for reproducibility\n", - "\n", - " Returns:\n", - " --------\n", - " pandas.DataFrame\n", - " Sampled dataframe\n", - " \"\"\"\n", - " if disable:\n", - " return df\n", - "\n", - " if balanced:\n", - " # Get the number of rows in the dataframe\n", - " num_rows_minor_class = math.ceil(len(df[df[\"class\"] == \"LOF\"]) * sample_frac)\n", - " return (\n", - " pd.concat(\n", - " [\n", - " df[df[\"class\"] == \"LOF\"].sample(n=num_rows_minor_class, random_state=random_state),\n", - " df[df[\"class\"] == \"FUNC/INT\"].sample(n=num_rows_minor_class, random_state=random_state),\n", - " ]\n", - " )\n", - " .sample(frac=1.0, random_state=random_state)\n", - " .reset_index(drop=True)\n", - " )\n", - " else:\n", - " # Calculate the number of rows to sample\n", - " return df.sample(frac=sample_frac, random_state=random_state).reset_index(drop=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chromposrefaltscoreclass
01741199726TC0.159762FUNC/INT
11741209074TA-2.065569LOF
21741256913AC-0.847753FUNC/INT
31741219631TA-2.053739LOF
41741215965GA-1.671525LOF
\n", - "
" + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:65: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n" + ] + } ], - "text/plain": [ - " chrom pos ref alt score class\n", - "0 17 41199726 T C 0.159762 FUNC/INT\n", - "1 17 41209074 T A -2.065569 LOF\n", - "2 17 41256913 A C -0.847753 FUNC/INT\n", - "3 17 41219631 T A -2.053739 LOF\n", - "4 17 41215965 G A -1.671525 LOF" + "source": [ + "import glob\n", + "import gzip\n", + "import json\n", + "import math\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import torch\n", + "from Bio import SeqIO\n", + "from sklearn.metrics import auc, roc_auc_score, roc_curve\n", + "\n", + "from bionemo.core.utils.subprocess_utils import run_subprocess_safely" ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "OUTPUT_DIR = \"brca1_fasta_files\"\n", - "\n", - "brca1_df = sample_data(\n", - " brca1_df,\n", - " sample_frac=SAMPLE_CONFIG[\"sample_frac\"],\n", - " balanced=SAMPLE_CONFIG[\"balanced\"],\n", - " disable=SAMPLE_CONFIG[\"disable\"],\n", - " random_state=SAMPLE_CONFIG[\"random_state\"],\n", - ")\n", - "\n", - "brca1_df.head(5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we'll write these to local `.fasta` files so we can use them for prediction below." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "jupyter": { - "source_hidden": true }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "def parse_sequences(pos, ref, alt, seq_chr17, window_size=8192):\n", - " \"\"\"Parse reference and variant sequences from the reference genome sequence.\n", - "\n", - " Parameters:\n", - " -----------\n", - " pos : int\n", - " Position (1-indexed)\n", - " ref : str\n", - " Reference base\n", - " alt : str\n", - " Alternate base\n", - " seq_chr17 : str\n", - " Full chromosome 17 sequence\n", - " window_size : int\n", - " Size of the sequence window to extract\n", - "\n", - " Returns:\n", - " --------\n", - " tuple\n", - " (reference_sequence, variant_sequence)\n", - " \"\"\"\n", - " p = pos - 1 # Convert to 0-indexed position\n", - " full_seq = seq_chr17\n", - "\n", - " ref_seq_start = max(0, p - window_size // 2)\n", - " ref_seq_end = min(len(full_seq), p + window_size // 2)\n", - " ref_seq = seq_chr17[ref_seq_start:ref_seq_end]\n", - " snv_pos_in_ref = min(window_size // 2, p)\n", - " var_seq = ref_seq[:snv_pos_in_ref] + alt + ref_seq[snv_pos_in_ref + 1 :]\n", - "\n", - " # Sanity checks\n", - " assert len(var_seq) == len(ref_seq)\n", - " assert ref_seq[snv_pos_in_ref] == ref\n", - " assert var_seq[snv_pos_in_ref] == alt\n", - "\n", - " return ref_seq, var_seq\n", - "\n", - "\n", - "def generate_fasta_files(df, seq_chr17, output_dir=\"brca1_fasta_files\", window_size=8192):\n", - " \"\"\"Generate FASTA files for reference and variant sequences.\n", - "\n", - " Parameters:\n", - " -----------\n", - " df : pandas.DataFrame\n", - " Dataframe with variant information\n", - " seq_chr17 : str\n", - " Chromosome 17 sequence\n", - " output_dir : str\n", - " Output directory for FASTA files\n", - " window_size : int\n", - " Size of sequence window\n", - "\n", - " Returns:\n", - " --------\n", - " pandas.DataFrame\n", - " Dataframe with added columns for FASTA names\n", - " \"\"\"\n", - " # Create output directory\n", - " output_dir = Path(output_dir)\n", - " output_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - " # Paths for output files\n", - " ref_fasta_path = output_dir / \"brca1_reference_sequences.fasta\"\n", - " var_fasta_path = output_dir / \"brca1_variant_sequences.fasta\"\n", - "\n", - " # Track unique sequences\n", - " ref_sequences = set()\n", - " var_sequences = set()\n", - " ref_seq_to_name = {}\n", - "\n", - " # Store unique sequences with metadata for writing\n", - " ref_entries = []\n", - " var_entries = []\n", - " ref_names = []\n", - " var_names = []\n", - "\n", - " # Collect unique reference and variant sequences\n", - " for idx, row in df.iterrows():\n", - " ref_seq, var_seq = parse_sequences(row[\"pos\"], row[\"ref\"], row[\"alt\"], seq_chr17, window_size)\n", - "\n", - " # Add to sets to ensure uniqueness\n", - " if ref_seq not in ref_sequences:\n", - " ref_sequences.add(ref_seq)\n", - " ref_name = f\"BRCA1_ref_pos_{row['pos']}_{row['ref']}_class_{row['class']}\"\n", - "\n", - " ref_entries.append(f\">{ref_name}\\n{ref_seq}\\n\")\n", - " ref_names.append(ref_name)\n", - " ref_seq_to_name[ref_seq] = ref_name\n", - " else:\n", - " ref_name = ref_seq_to_name[ref_seq]\n", - " ref_names.append(ref_name)\n", - "\n", - " if var_seq not in var_sequences:\n", - " var_sequences.add(var_seq)\n", - " var_name = f\"BRCA1_var_pos_{row['pos']}_{row['ref']}to{row['alt']}_class_{row['class']}\"\n", - "\n", - " var_entries.append(f\">{var_name}\\n{var_seq}\\n\")\n", - " var_names.append(var_name)\n", - " else:\n", - " assert False, \"Duplicate variant sequence\"\n", - "\n", - " # Write unique sequences to FASTA files\n", - " with open(ref_fasta_path, \"w\") as f:\n", - " f.writelines(ref_entries)\n", - "\n", - " with open(var_fasta_path, \"w\") as f:\n", - " f.writelines(var_entries)\n", - "\n", - " # Add FASTA names to dataframe\n", - " df_with_names = df.copy()\n", - " df_with_names[\"ref_fasta_name\"] = ref_names\n", - " df_with_names[\"var_fasta_name\"] = var_names\n", - "\n", - " print(f\"Total unique reference sequences: {len(ref_sequences)}\")\n", - " print(f\"Total unique variant sequences: {len(var_sequences)}\")\n", - "\n", - " return df_with_names" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Total unique reference sequences: 79\n", - "Total unique variant sequences: 84\n" - ] - } - ], - "source": [ - "brca1_df = generate_fasta_files(brca1_df, seq_chr17, output_dir=OUTPUT_DIR)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load Evo 2 Checkpoints" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "Then, we load Evo 2 1B model, loading the Evo 2 weights from hugging face.\n", - "\n", - "*Note - for better performance, load the 7b model by setting `MODEL_SIZE=\"7b\"` which also works well GPUs that do not support FP8.*\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "MODEL_SIZE = \"1b\" # also try 7b if you have a GPU with more than 32GB of memory\n", - "\n", - "# Define checkpoint path\n", - "if MODEL_SIZE == \"1b\":\n", - " from bionemo.core.data.load import load\n", - "\n", - " # This line will download the checkpoint from NGC to your $HOME/.cache/bionemo directory and return the path.\n", - " # To do the same from the command line, use `CHECKPOINT_PATH=$(download_bionemo_data evo2/1b-8k-bf16:1.0)`\n", - " checkpoint_path = load(\"evo2/1b-8k-bf16:1.0\")\n", - "else:\n", - " checkpoint_path = Path(f\"nemo2_evo2_{MODEL_SIZE}_8k\")\n", - "\n", - " # Check if the directory does not exist or is empty\n", - " if not checkpoint_path.exists() or not any(checkpoint_path.iterdir()):\n", - " !evo2_convert_to_nemo2 --model-path hf://arcinstitute/savanna_evo2_{MODEL_SIZE}_base --model-size {MODEL_SIZE} --output-dir nemo2_evo2_{MODEL_SIZE}_8k\n", - " else:\n", - " print(\"Checkpoint directory is not empty. Skipping command.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Score Sequences" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we score the likelihoods of the reference and variant sequences of each SNV.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "jupyter": { - "source_hidden": true + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start by loading a dataset from [Findlay et al. (2018)](https://www.nature.com/articles/s41586-018-0461-z), which contains experimentally measured function scores of 3,893 *BRCA1* SNVs. These function scores reflect the extent by which the genetic variant has disrupted the protein's function, with lower scores indicating greater disruption. In this dataset, the SNVs are classified into three categories based on their function scores: `LOF` (loss-of-function), `INT` (intermediate), and `FUNC` (functional). We start by reading in this dataset.\n", + "\n", + "To keep the notebook streamlined, we've abstracted much of the preprocessing logic into accompanying scripts located in `brca1_utils`. This notebook can also be viewed [here](https://docs.nvidia.com/bionemo-framework/latest/main/examples/bionemo-evo2/zeroshot_brca1/)." + ] }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "def check_fp8_support():\n", - " \"\"\"Check if FP8 is supported on the current GPU.\n", - "\n", - " FP8 requires compute capability 8.9+ (Ada Lovelace/Hopper architecture or newer).\n", - " \"\"\"\n", - " if not torch.cuda.is_available():\n", - " return False, \"CUDA not available\"\n", - "\n", - " device_props = torch.cuda.get_device_properties(0)\n", - " compute_capability = f\"{device_props.major}.{device_props.minor}\"\n", - " device_name = device_props.name\n", - "\n", - " # FP8 is supported on compute capability 8.9+ (Ada Lovelace/Hopper architecture)\n", - " is_supported = (device_props.major > 8) or (device_props.major == 8 and device_props.minor >= 9)\n", - "\n", - " return is_supported, f\"Device: {device_name}, Compute Capability: {compute_capability}\"" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "FP8 Support: False\n", - "Device: NVIDIA RTX A6000, Compute Capability: 8.6\n" - ] - } - ], - "source": [ - "# Define output directories for prediction results\n", - "output_dir = Path(\"brca1_fasta_files\")\n", - "output_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "# Save reference and variant sequences to FASTA\n", - "ref_fasta_path = output_dir / \"brca1_reference_sequences.fasta\"\n", - "var_fasta_path = output_dir / \"brca1_variant_sequences.fasta\"\n", - "\n", - "predict_ref_dir = output_dir / \"reference_predictions\"\n", - "predict_var_dir = output_dir / \"variant_predictions\"\n", - "predict_ref_dir.mkdir(parents=True, exist_ok=True)\n", - "predict_var_dir.mkdir(parents=True, exist_ok=True)\n", - "\n", - "fp8_supported, gpu_info = check_fp8_support()\n", - "print(f\"FP8 Support: {fp8_supported}\")\n", - "print(gpu_info)\n", - "\n", - "# Note: If FP8 is not supported, you may want to disable it in the model config\n", - "# The Evo2 config has 'use_fp8_input_projections: True' by default\n", - "\n", - "if FAST_CI_MODE:\n", - " model_subset_option = \"--num-layers 4 --hybrid-override-pattern SDH*\"\n", - "else:\n", - " model_subset_option = \"\"\n", - "\n", - "fp8_option = \"--fp8\" if fp8_supported else \"\"\n", - "\n", - "# Update predict commands to run on the full dataset\n", - "predict_ref_command = (\n", - " f\"predict_evo2 --fasta {ref_fasta_path} --ckpt-dir {checkpoint_path} \"\n", - " f\"--output-dir {predict_ref_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 {model_subset_option} \"\n", - " f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n", - ")\n", - "\n", - "predict_var_command = (\n", - " f\"predict_evo2 --fasta {var_fasta_path} --ckpt-dir {checkpoint_path} \"\n", - " f\"--output-dir {predict_var_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 {model_subset_option} \"\n", - " f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Score reference sequences:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "print(f\"Running command: {predict_ref_command}\")\n", - "\n", - "result = run_subprocess_safely(predict_ref_command)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert result[\"returncode\"] == 0, result" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Score variant sequences:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "print(f\"Running command: {predict_var_command}\")\n", - "\n", - "result = run_subprocess_safely(predict_var_command)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert result[\"returncode\"] == 0, result" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We calculate the change in likelihoods for each variant relative to the likelihood of their respective wild-type sequence.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we load the prediction files and sequence id maps:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Find and load prediction files\n", - "ref_pred_files = glob.glob(os.path.join(predict_ref_dir, \"predictions__rank_*.pt\"))\n", - "var_pred_files = glob.glob(os.path.join(predict_var_dir, \"predictions__rank_*.pt\"))\n", - "\n", - "# Load sequence ID maps (maps sequence ID -> prediction index)\n", - "with open(os.path.join(predict_ref_dir, \"seq_idx_map.json\"), \"r\") as f:\n", - " ref_seq_idx_map = json.load(f)\n", - "with open(os.path.join(predict_var_dir, \"seq_idx_map.json\"), \"r\") as f:\n", - " var_seq_idx_map = json.load(f)\n", - "\n", - "# Load predictions\n", - "ref_preds = torch.load(ref_pred_files[0])\n", - "var_preds = torch.load(var_pred_files[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, calculate the delta score:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 3, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def download_data(data_dir=\"brca1\", commit_hash=\"3819474bee6c24938016614411f1fa025e542bbe\"):\n", + " \"\"\"Download required data files if they don't exist locally.\n", + "\n", + " Parameters:\n", + " -----------\n", + " data_dir : str\n", + " Directory to store downloaded files\n", + " commit_hash : str\n", + " GitHub commit hash for data version\n", + " \"\"\"\n", + " if not os.path.exists(data_dir):\n", + " os.makedirs(data_dir)\n", + "\n", + " excel_path = os.path.join(data_dir, \"41586_2018_461_MOESM3_ESM.xlsx\")\n", + " genome_path = os.path.join(data_dir, \"GRCh37.p13_chr17.fna.gz\")\n", + "\n", + " if not os.path.exists(excel_path):\n", + " os.system(\n", + " f\"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/41586_2018_461_MOESM3_ESM.xlsx -O {excel_path}\"\n", + " )\n", + "\n", + " if not os.path.exists(genome_path):\n", + " os.system(\n", + " f\"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/GRCh37.p13_chr17.fna.gz -O {genome_path}\"\n", + " )\n", + "\n", + " return excel_path, genome_path\n", + "\n", + "\n", + "def load_genome_sequence(genome_path):\n", + " \"\"\"Load genome sequence from FASTA file.\n", + "\n", + " Parameters:\n", + " -----------\n", + " genome_path : str\n", + " Path to the genome FASTA file\n", + "\n", + " Returns:\n", + " --------\n", + " str\n", + " Genome sequence string\n", + " \"\"\"\n", + " with gzip.open(genome_path, \"rt\") as handle:\n", + " for record in SeqIO.parse(handle, \"fasta\"):\n", + " return str(record.seq)\n", + "\n", + " raise ValueError(\"Failed to parse genome sequence\")\n", + "\n", + "\n", + "def load_brca1_data(excel_path):\n", + " \"\"\"Load and preprocess BRCA1 data from Excel file.\n", + "\n", + " Parameters:\n", + " -----------\n", + " excel_path : str\n", + " Path to the Excel file\n", + "\n", + " Returns:\n", + " --------\n", + " pandas.DataFrame\n", + " Processed BRCA1 dataframe\n", + " \"\"\"\n", + " # Load the dataframe\n", + " brca1_df = pd.read_excel(excel_path, header=2)\n", + "\n", + " # Select and rename columns\n", + " brca1_df = brca1_df[\n", + " [\n", + " \"chromosome\",\n", + " \"position (hg19)\",\n", + " \"reference\",\n", + " \"alt\",\n", + " \"function.score.mean\",\n", + " \"func.class\",\n", + " ]\n", + " ]\n", + "\n", + " brca1_df.rename(\n", + " columns={\n", + " \"chromosome\": \"chrom\",\n", + " \"position (hg19)\": \"pos\",\n", + " \"reference\": \"ref\",\n", + " \"alt\": \"alt\",\n", + " \"function.score.mean\": \"score\",\n", + " \"func.class\": \"class\",\n", + " },\n", + " inplace=True,\n", + " )\n", + "\n", + " # Convert to two-class system\n", + " brca1_df[\"class\"] = brca1_df[\"class\"].replace([\"FUNC\", \"INT\"], \"FUNC/INT\")\n", + "\n", + " return brca1_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "# Configuration parameters\n", + "DATA_DIR = \"brca1\"\n", + "SAMPLE_CONFIG = {\"sample_frac\": 0.1, \"balanced\": True, \"disable\": False, \"random_state\": 42}\n", + "\n", + "# 1. Download the necessary data files if not present\n", + "excel_path, genome_path = download_data(DATA_DIR)\n", + "seq_chr17 = load_genome_sequence(genome_path)\n", + "\n", + "# 2. Load and preprocess BRCA1 data\n", + "brca1_df = load_brca1_data(excel_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then group the `FUNC` and `INT` classes of SNVs together into a single category (`FUNC/INT`).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We build a function to parse the reference and variant sequences of a 8,192-bp window around the genomic position of each SNV, using the reference sequence of human chromosome 17 where *BRCA1* is located.\n", + "\n" + ] + }, { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chromposrefaltscoreclassref_fasta_namevar_fasta_nameref_log_probsvar_log_probsevo2_delta_score
01741199726TC0.159762FUNC/INTBRCA1_ref_pos_41199726_T_class_FUNC/INTBRCA1_var_pos_41199726_TtoC_class_FUNC/INT-0.952952-0.953219-0.000267
11741209074TA-2.065569LOFBRCA1_ref_pos_41209074_T_class_LOFBRCA1_var_pos_41209074_TtoA_class_LOF-0.750379-0.750438-0.000059
21741256913AC-0.847753FUNC/INTBRCA1_ref_pos_41256913_A_class_FUNC/INTBRCA1_var_pos_41256913_AtoC_class_FUNC/INT-0.798110-0.799099-0.000989
31741219631TA-2.053739LOFBRCA1_ref_pos_41219631_T_class_LOFBRCA1_var_pos_41219631_TtoA_class_LOF-1.032214-1.032696-0.000482
41741215965GA-1.671525LOFBRCA1_ref_pos_41215965_G_class_LOFBRCA1_var_pos_41215965_GtoA_class_LOF-0.860933-0.861262-0.000329
\n", - "
" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make things run faster, we'll just look at a balanced sample of our data. If you want to run on the full dataset, set `disable_sample=True`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def sample_data(df, sample_frac=1.0, balanced=True, disable=False, random_state=42):\n", + " \"\"\"Sample dataframe, optionally with balanced classes.\n", + "\n", + " Parameters:\n", + " -----------\n", + " df : pandas.DataFrame\n", + " Input dataframe\n", + " sample_frac : float\n", + " Fraction of data to sample\n", + " balanced : bool\n", + " Whether to balance classes\n", + " disable : bool\n", + " Whether to disable sampling\n", + " random_state : int\n", + " Random seed for reproducibility\n", + "\n", + " Returns:\n", + " --------\n", + " pandas.DataFrame\n", + " Sampled dataframe\n", + " \"\"\"\n", + " if disable:\n", + " return df\n", + "\n", + " if balanced:\n", + " # Get the number of rows in the dataframe\n", + " num_rows_minor_class = math.ceil(len(df[df[\"class\"] == \"LOF\"]) * sample_frac)\n", + " return (\n", + " pd.concat(\n", + " [\n", + " df[df[\"class\"] == \"LOF\"].sample(n=num_rows_minor_class, random_state=random_state),\n", + " df[df[\"class\"] == \"FUNC/INT\"].sample(n=num_rows_minor_class, random_state=random_state),\n", + " ]\n", + " )\n", + " .sample(frac=1.0, random_state=random_state)\n", + " .reset_index(drop=True)\n", + " )\n", + " else:\n", + " # Calculate the number of rows to sample\n", + " return df.sample(frac=sample_frac, random_state=random_state).reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chromposrefaltscoreclass
01741276097AG0.326953FUNC/INT
11741201130AG0.056569FUNC/INT
21741215938TA-2.017579LOF
31741215932AC-1.706222LOF
41741219685GT0.037593FUNC/INT
\n", + "
" + ], + "text/plain": [ + " chrom pos ref alt score class\n", + "0 17 41276097 A G 0.326953 FUNC/INT\n", + "1 17 41201130 A G 0.056569 FUNC/INT\n", + "2 17 41215938 T A -2.017579 LOF\n", + "3 17 41215932 A C -1.706222 LOF\n", + "4 17 41219685 G T 0.037593 FUNC/INT" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } ], - "text/plain": [ - " chrom pos ref alt score class \\\n", - "0 17 41199726 T C 0.159762 FUNC/INT \n", - "1 17 41209074 T A -2.065569 LOF \n", - "2 17 41256913 A C -0.847753 FUNC/INT \n", - "3 17 41219631 T A -2.053739 LOF \n", - "4 17 41215965 G A -1.671525 LOF \n", - "\n", - " ref_fasta_name \\\n", - "0 BRCA1_ref_pos_41199726_T_class_FUNC/INT \n", - "1 BRCA1_ref_pos_41209074_T_class_LOF \n", - "2 BRCA1_ref_pos_41256913_A_class_FUNC/INT \n", - "3 BRCA1_ref_pos_41219631_T_class_LOF \n", - "4 BRCA1_ref_pos_41215965_G_class_LOF \n", - "\n", - " var_fasta_name ref_log_probs var_log_probs \\\n", - "0 BRCA1_var_pos_41199726_TtoC_class_FUNC/INT -0.952952 -0.953219 \n", - "1 BRCA1_var_pos_41209074_TtoA_class_LOF -0.750379 -0.750438 \n", - "2 BRCA1_var_pos_41256913_AtoC_class_FUNC/INT -0.798110 -0.799099 \n", - "3 BRCA1_var_pos_41219631_TtoA_class_LOF -1.032214 -1.032696 \n", - "4 BRCA1_var_pos_41215965_GtoA_class_LOF -0.860933 -0.861262 \n", - "\n", - " evo2_delta_score \n", - "0 -0.000267 \n", - "1 -0.000059 \n", - "2 -0.000989 \n", - "3 -0.000482 \n", - "4 -0.000329 " + "source": [ + "OUTPUT_DIR = \"brca1_fasta_files\"\n", + "\n", + "brca1_df = sample_data(\n", + " brca1_df,\n", + " sample_frac=SAMPLE_CONFIG[\"sample_frac\"],\n", + " balanced=SAMPLE_CONFIG[\"balanced\"],\n", + " disable=SAMPLE_CONFIG[\"disable\"],\n", + " random_state=SAMPLE_CONFIG[\"random_state\"],\n", + ")\n", + "\n", + "brca1_df.head(5)" ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# next, calculate change in likelihoods\n", - "ref_log_probs = []\n", - "var_log_probs = []\n", - "for _, row in brca1_df.iterrows():\n", - " ref_name = row[\"ref_fasta_name\"]\n", - " var_name = row[\"var_fasta_name\"]\n", - " ref_log_probs.append(ref_preds[\"log_probs_seqs\"][ref_seq_idx_map[ref_name]].item())\n", - " var_log_probs.append(var_preds[\"log_probs_seqs\"][var_seq_idx_map[var_name]].item())\n", - "brca1_df[\"ref_log_probs\"] = ref_log_probs\n", - "brca1_df[\"var_log_probs\"] = var_log_probs\n", - "# ideally probability of a broken variant is lower than a good one. So a bad var - good ref is negative.\n", - "brca1_df[\"evo2_delta_score\"] = brca1_df[\"var_log_probs\"] - brca1_df[\"ref_log_probs\"]\n", - "brca1_df.head()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This delta likelihood should be predictive of how disruptive the SNV is to the protein's function: the lower the delta, the more likely that the SNV is disruptive. We can show this by comparing the distributions of delta likelihoods for the two classes of SNVs (functional/intermediate vs loss-of-function)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "def plot_strip_with_means(df, x_col=\"evo2_delta_score\", class_col=\"class\"):\n", - " \"\"\"Creates a strip plot with jittered points and median indicators for each class using Seaborn.\n", - "\n", - " Parameters:\n", - " - df (pd.DataFrame): The input DataFrame containing data.\n", - " - x_col (str): The column name representing the x-axis values (e.g., evo2_delta_score).\n", - " - class_col (str): The column name representing the class labels.\n", - "\n", - " Returns:\n", - " - matplotlib Figure: Strip plot with median indicators.\n", - " \"\"\"\n", - " # NVIDIA theme colors\n", - " NVIDIA_GREEN = \"#76B900\" # noqa: N806\n", - " BACKGROUND_COLOR = \"#F8F8F8\" # noqa: N806\n", - " GRID_COLOR = \"#DDDDDD\" # noqa: N806\n", - " FONT_COLOR = \"#333333\" # noqa: N806\n", - "\n", - " # Determine order of classes (if not already specified)\n", - " unique_classes = sorted(df[class_col].unique())\n", - "\n", - " # Set up the plot with NVIDIA theme\n", - " plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)\n", - " plt.style.use(\"default\") # Reset to default to avoid any pre-existing style\n", - "\n", - " # Create strip plot\n", - " p = sns.stripplot(\n", - " data=df,\n", - " x=x_col,\n", - " y=class_col,\n", - " hue=class_col,\n", - " order=unique_classes,\n", - " palette=[NVIDIA_GREEN, \"red\"],\n", - " size=6,\n", - " jitter=0.3,\n", - " alpha=0.6,\n", - " )\n", - "\n", - " # Add median indicators using boxplot\n", - " sns.boxplot(\n", - " showmeans=True,\n", - " meanline=True,\n", - " meanprops={\"visible\": False},\n", - " medianprops={\"color\": \"black\", \"ls\": \"-\", \"lw\": 2},\n", - " whiskerprops={\"visible\": False},\n", - " zorder=10,\n", - " x=x_col,\n", - " y=class_col,\n", - " data=df,\n", - " order=unique_classes,\n", - " showfliers=False,\n", - " showbox=False,\n", - " showcaps=False,\n", - " ax=p,\n", - " )\n", - "\n", - " # Customize plot appearance\n", - " plt.title(\n", - " \"Distribution of Delta Likelihoods Scores\\nComparing Evo 2 likelihood scores for different BRCA1 SNV classes\",\n", - " color=FONT_COLOR,\n", - " fontsize=12,\n", - " loc=\"left\",\n", - " )\n", - " plt.xlabel(\"Delta Likelihood Score, Evo 2\", color=FONT_COLOR)\n", - " plt.ylabel(\"BRCA1 SNV Class\", color=FONT_COLOR)\n", - "\n", - " # Customize grid and tick colors\n", - " plt.grid(color=GRID_COLOR, axis=\"x\", linestyle=\"--\", linewidth=0.5)\n", - " plt.tick_params(colors=FONT_COLOR)\n", - "\n", - " # Set background color\n", - " plt.gca().set_facecolor(BACKGROUND_COLOR)\n", - " plt.gcf().set_facecolor(BACKGROUND_COLOR)\n", - "\n", - " plt.tight_layout()\n", - "\n", - " # return plt.gcf()" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ { - "data": { - "image/png": "", - "text/plain": [ - "
" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll write these to local `.fasta` files so we can use them for prediction below." ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plot_strip_with_means(brca1_df, x_col=\"evo2_delta_score\", class_col=\"class\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method. Note that the results are nearly random unless you are on one of the following configurations:\n", - "* `--fp8` on an fp8 enabled GPU with either the 1b or 7b models. The 40b likely works as well.\n", - "* the 7b model uniquely seems to work well without `--fp8` so if you are on an older device, the 7b model should produce\n", - " robust results. Change the `MODEL_SIZE` earlier in this tutorial and rerun for good results in that case.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Zero-shot prediction AUROC: 0.77\n" - ] - } - ], - "source": [ - "# Calculate AUROC of zero-shot predictions\n", - "# class 1 is LOF which is the bad thing. That means we expect this to be more negative.\n", - "y_true = brca1_df[\"class\"] == \"LOF\"\n", - "auroc = roc_auc_score(y_true, -brca1_df[\"evo2_delta_score\"])\n", - "print(f\"Zero-shot prediction AUROC: {auroc:.2}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true + "cell_type": "code", + "execution_count": 7, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def parse_sequences(pos, ref, alt, seq_chr17, window_size=8192):\n", + " \"\"\"Parse reference and variant sequences from the reference genome sequence.\n", + "\n", + " Parameters:\n", + " -----------\n", + " pos : int\n", + " Position (1-indexed)\n", + " ref : str\n", + " Reference base\n", + " alt : str\n", + " Alternate base\n", + " seq_chr17 : str\n", + " Full chromosome 17 sequence\n", + " window_size : int\n", + " Size of the sequence window to extract\n", + "\n", + " Returns:\n", + " --------\n", + " tuple\n", + " (reference_sequence, variant_sequence)\n", + " \"\"\"\n", + " p = pos - 1 # Convert to 0-indexed position\n", + " full_seq = seq_chr17\n", + "\n", + " ref_seq_start = max(0, p - window_size // 2)\n", + " ref_seq_end = min(len(full_seq), p + window_size // 2)\n", + " ref_seq = seq_chr17[ref_seq_start:ref_seq_end]\n", + " snv_pos_in_ref = min(window_size // 2, p)\n", + " var_seq = ref_seq[:snv_pos_in_ref] + alt + ref_seq[snv_pos_in_ref + 1 :]\n", + "\n", + " # Sanity checks\n", + " assert len(var_seq) == len(ref_seq)\n", + " assert ref_seq[snv_pos_in_ref] == ref\n", + " assert var_seq[snv_pos_in_ref] == alt\n", + "\n", + " return ref_seq, var_seq\n", + "\n", + "\n", + "def generate_fasta_files(df, seq_chr17, output_dir=\"brca1_fasta_files\", window_size=8192):\n", + " \"\"\"Generate FASTA files for reference and variant sequences.\n", + "\n", + " Parameters:\n", + " -----------\n", + " df : pandas.DataFrame\n", + " Dataframe with variant information\n", + " seq_chr17 : str\n", + " Chromosome 17 sequence\n", + " output_dir : str\n", + " Output directory for FASTA files\n", + " window_size : int\n", + " Size of sequence window\n", + "\n", + " Returns:\n", + " --------\n", + " pandas.DataFrame\n", + " Dataframe with added columns for FASTA names\n", + " \"\"\"\n", + " # Create output directory\n", + " output_dir = Path(output_dir)\n", + " output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # Paths for output files\n", + " ref_fasta_path = output_dir / \"brca1_reference_sequences.fasta\"\n", + " var_fasta_path = output_dir / \"brca1_variant_sequences.fasta\"\n", + "\n", + " # Track unique sequences\n", + " ref_sequences = set()\n", + " var_sequences = set()\n", + " ref_seq_to_name = {}\n", + "\n", + " # Store unique sequences with metadata for writing\n", + " ref_entries = []\n", + " var_entries = []\n", + " ref_names = []\n", + " var_names = []\n", + "\n", + " # Collect unique reference and variant sequences\n", + " for idx, row in df.iterrows():\n", + " ref_seq, var_seq = parse_sequences(row[\"pos\"], row[\"ref\"], row[\"alt\"], seq_chr17, window_size)\n", + "\n", + " # Add to sets to ensure uniqueness\n", + " if ref_seq not in ref_sequences:\n", + " ref_sequences.add(ref_seq)\n", + " ref_name = f\"BRCA1_ref_pos_{row['pos']}_{row['ref']}_class_{row['class']}\"\n", + "\n", + " ref_entries.append(f\">{ref_name}\\n{ref_seq}\\n\")\n", + " ref_names.append(ref_name)\n", + " ref_seq_to_name[ref_seq] = ref_name\n", + " else:\n", + " ref_name = ref_seq_to_name[ref_seq]\n", + " ref_names.append(ref_name)\n", + "\n", + " if var_seq not in var_sequences:\n", + " var_sequences.add(var_seq)\n", + " var_name = f\"BRCA1_var_pos_{row['pos']}_{row['ref']}to{row['alt']}_class_{row['class']}\"\n", + "\n", + " var_entries.append(f\">{var_name}\\n{var_seq}\\n\")\n", + " var_names.append(var_name)\n", + " else:\n", + " assert False, \"Duplicate variant sequence\"\n", + "\n", + " # Write unique sequences to FASTA files\n", + " with open(ref_fasta_path, \"w\") as f:\n", + " f.writelines(ref_entries)\n", + "\n", + " with open(var_fasta_path, \"w\") as f:\n", + " f.writelines(var_entries)\n", + "\n", + " # Add FASTA names to dataframe\n", + " df_with_names = df.copy()\n", + " df_with_names[\"ref_fasta_name\"] = ref_names\n", + " df_with_names[\"var_fasta_name\"] = var_names\n", + "\n", + " print(f\"Total unique reference sequences: {len(ref_sequences)}\")\n", + " print(f\"Total unique variant sequences: {len(var_sequences)}\")\n", + "\n", + " return df_with_names" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total unique reference sequences: 156\n", + "Total unique variant sequences: 166\n" + ] + } + ], + "source": [ + "brca1_df = generate_fasta_files(brca1_df, seq_chr17, output_dir=OUTPUT_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Evo 2 Checkpoints" + ] }, - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "def plot_roc_curve(df):\n", - " \"\"\"Plots an ROC curve using Seaborn with a light NVIDIA-themed design.\n", - "\n", - " The function assumes:\n", - " - `class` column as the true labels (binary, 'LOF' = 1, else 0).\n", - " - `evo2_delta_score` as the prediction score.\n", - "\n", - " Parameters:\n", - " - df (pd.DataFrame): DataFrame containing `class` and `evo2_delta_score`.\n", - "\n", - " Returns:\n", - " - matplotlib Figure: ROC Curve Visualization.\n", - " \"\"\"\n", - " # NVIDIA theme colors\n", - " NVIDIA_GREEN = \"#76B900\" # noqa: N806\n", - " BACKGROUND_COLOR = \"#F8F8F8\" # noqa: N806\n", - " GRID_COLOR = \"#DDDDDD\" # noqa: N806\n", - " FONT_COLOR = \"#333333\" # noqa: N806\n", - "\n", - " # Validate required columns\n", - " if \"class\" not in df.columns or \"evo2_delta_score\" not in df.columns:\n", - " raise ValueError(\"DataFrame must contain 'class' and 'evo2_delta_score' columns.\")\n", - "\n", - " # Convert 'class' to binary labels: Assume 'LOF' = 1, anything else = 0\n", - " y_true = (df[\"class\"] == \"LOF\").astype(int)\n", - "\n", - " # Compute ROC curve\n", - " fpr, tpr, _ = roc_curve(y_true, -df[\"evo2_delta_score\"]) # Negative to align with previous logic\n", - " roc_auc = auc(fpr, tpr)\n", - "\n", - " # Set up the plot with NVIDIA theme\n", - " plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)\n", - " plt.style.use(\"default\") # Reset to default to avoid any pre-existing style\n", - "\n", - " # Plot ROC curve\n", - " plt.plot(fpr, tpr, color=NVIDIA_GREEN, lw=3, label=f\"ROC curve (AUROC = {roc_auc:.2f})\")\n", - "\n", - " # Plot diagonal reference line for random guessing\n", - " plt.plot([0, 1], [0, 1], color=\"gray\", lw=2, linestyle=\"--\")\n", - "\n", - " # Customize plot appearance\n", - " plt.xlim([0.0, 1.0])\n", - " plt.ylim([0.0, 1.05])\n", - " plt.xlabel(\"False Positive Rate\", color=FONT_COLOR, fontsize=12)\n", - " plt.ylabel(\"True Positive Rate\", color=FONT_COLOR, fontsize=12)\n", - " plt.title(\n", - " \"Zeroshot ROC Curve\\nEvaluating the discriminative performance of Evo 2 predictions\",\n", - " color=FONT_COLOR,\n", - " fontsize=16,\n", - " loc=\"left\",\n", - " )\n", - "\n", - " # Customize grid and tick colors\n", - " plt.grid(color=GRID_COLOR, linestyle=\"--\", linewidth=0.5)\n", - " plt.tick_params(colors=FONT_COLOR)\n", - "\n", - " # Set background color\n", - " plt.gca().set_facecolor(BACKGROUND_COLOR)\n", - "\n", - " # Add legend\n", - " plt.legend(loc=\"lower right\", frameon=True, facecolor=BACKGROUND_COLOR, edgecolor=GRID_COLOR)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ { - "data": { - "image/png": "", - "text/plain": [ - "
" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Then, we load Evo 2 1B model, loading the Evo 2 weights from hugging face.\n", + "\n", + "*Note - for better performance, load the 7b model by setting `MODEL_SIZE=\"7b\"` which also works well GPUs that do not support FP8.*\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "MODEL_SIZE = \"1b\" # also try 7b if you have a GPU with more than 32GB of memory\n", + "\n", + "# Define checkpoint path\n", + "if MODEL_SIZE == \"1b\":\n", + " from pathlib import Path\n", + "\n", + " from bionemo.core.data.load import load\n", + "\n", + " # This line will download the checkpoint from NGC to your $HOME/.cache/bionemo directory and return the path.\n", + " # To do the same from the command line, use `CHECKPOINT_PATH=$(download_bionemo_data evo2/1b-8k-bf16:1.0)`\n", + "\n", + " # Download the 1b BF16 checkpoint from NGC\n", + " # Available checkpoints: evo2/1b-8k-bf16:1.0, evo2/1b-8k:1.0, evo2/7b-8k:1.0, evo2/7b-1m:1.0\n", + " nemo2_ckpt_path = load(\"evo2/1b-8k-bf16:1.0\")\n", + " checkpoint_path = Path(\"evo2_1b_bf16_mbridge\")\n", + " from bionemo.evo2.data.dataset_tokenizer import (\n", + " DEFAULT_HF_TOKENIZER_MODEL_PATH_512, # use the 512 size for historical reasons\n", + " )\n", + "\n", + " mixed_precision_recipe = \"bf16_mixed\" # also try bf16_with_fp8_current_scaling_mixed\n", + " convert_ckpt_cmd = f\"\"\"evo2_convert_nemo2_to_mbridge \\\n", + " --nemo2-ckpt-dir {nemo2_ckpt_path} \\\n", + " --mbridge-ckpt-dir {checkpoint_path} \\\n", + " --model-size 1b \\\n", + " --mixed-precision-recipe {mixed_precision_recipe} \\\n", + " --seq-length 8192 \\\n", + " --tokenizer-path {DEFAULT_HF_TOKENIZER_MODEL_PATH_512} \\\n", + " \"\"\".rstrip()\n", + " print(f\"Running command: {convert_ckpt_cmd}\")\n", + "\n", + " result = run_subprocess_safely(convert_ckpt_cmd)\n", + " print(f\"Downloaded checkpoint to: {nemo2_ckpt_path} and converted to mbridge format at {checkpoint_path}\")\n", + "else:\n", + " assert False, \"Implement conversion for other model sizes. Should be similar to the 1b case above.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Score Sequences" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we score the likelihoods of the reference and variant sequences of each SNV.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def check_fp8_support():\n", + " \"\"\"Check if FP8 is supported on the current GPU.\n", + "\n", + " FP8 requires compute capability 8.9+ (Ada Lovelace/Hopper architecture or newer).\n", + " \"\"\"\n", + " if not torch.cuda.is_available():\n", + " return False, \"CUDA not available\"\n", + "\n", + " device_props = torch.cuda.get_device_properties(0)\n", + " compute_capability = f\"{device_props.major}.{device_props.minor}\"\n", + " device_name = device_props.name\n", + "\n", + " # FP8 is supported on compute capability 8.9+ (Ada Lovelace/Hopper architecture)\n", + " is_supported = (device_props.major > 8) or (device_props.major == 8 and device_props.minor >= 9)\n", + "\n", + " return is_supported, f\"Device: {device_name}, Compute Capability: {compute_capability}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FP8 Support: True\n", + "Device: NVIDIA GeForce RTX 5090, Compute Capability: 12.0\n" + ] + } + ], + "source": [ + "# Define output directories for prediction results\n", + "output_dir = Path(\"brca1_fasta_files\")\n", + "output_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "# Save reference and variant sequences to FASTA\n", + "ref_fasta_path = output_dir / \"brca1_reference_sequences.fasta\"\n", + "var_fasta_path = output_dir / \"brca1_variant_sequences.fasta\"\n", + "\n", + "predict_ref_dir = output_dir / \"reference_predictions\"\n", + "predict_var_dir = output_dir / \"variant_predictions\"\n", + "predict_ref_dir.mkdir(parents=True, exist_ok=True)\n", + "predict_var_dir.mkdir(parents=True, exist_ok=True)\n", + "\n", + "fp8_supported, gpu_info = check_fp8_support()\n", + "print(f\"FP8 Support: {fp8_supported}\")\n", + "print(gpu_info)\n", + "\n", + "# Note: If FP8 is not supported, you may want to disable it in the model config\n", + "# The Evo2 config has 'use_fp8_input_projections: True' by default\n", + "\n", + "if FAST_CI_MODE:\n", + " model_subset_option = \"--num-layers 4 --hybrid-override-pattern SDH*\"\n", + "else:\n", + " model_subset_option = \"\"\n", + "\n", + "# Disable FP8 for now, it is underperforming at the moment in prediction. 0.7 vs 0.76 and similar run times.\n", + "fp8_option = \"--mixed-precision-recipe bf16_with_fp8_current_scaling_mixed\" if fp8_supported else \"\"\n", + "\n", + "# Update predict commands to run on the full dataset\n", + "predict_ref_command = (\n", + " f\"predict_evo2 --fasta {ref_fasta_path} --ckpt-dir {checkpoint_path} \"\n", + " f\"--output-dir {predict_ref_dir} --tensor-parallel-size 1 {model_subset_option} \"\n", + " f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n", + ")\n", + "\n", + "predict_var_command = (\n", + " f\"predict_evo2 --fasta {var_fasta_path} --ckpt-dir {checkpoint_path} \"\n", + " f\"--output-dir {predict_var_dir} --tensor-parallel-size 1 {model_subset_option} \"\n", + " f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Score reference sequences:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "print(f\"Running command: {predict_ref_command}\")\n", + "\n", + "result = run_subprocess_safely(predict_ref_command)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "assert result[\"returncode\"] == 0, result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Score variant sequences:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "print(f\"Running command: {predict_var_command}\")\n", + "\n", + "result = run_subprocess_safely(predict_var_command)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "assert result[\"returncode\"] == 0, result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We calculate the change in likelihoods for each variant relative to the likelihood of their respective wild-type sequence.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we load the prediction files and sequence id maps:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Find and load prediction files\n", + "# File naming convention (epoch mode): predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt\n", + "ref_pred_files = glob.glob(os.path.join(predict_ref_dir, \"predictions__rank_*.pt\"))\n", + "var_pred_files = glob.glob(os.path.join(predict_var_dir, \"predictions__rank_*.pt\"))\n", + "\n", + "# Load sequence ID maps (maps sequence ID -> prediction index)\n", + "with open(os.path.join(predict_ref_dir, \"seq_idx_map.json\"), \"r\") as f:\n", + " ref_seq_idx_map = json.load(f)\n", + "with open(os.path.join(predict_var_dir, \"seq_idx_map.json\"), \"r\") as f:\n", + " var_seq_idx_map = json.load(f)\n", + "\n", + "# Load predictions\n", + "ref_preds = torch.load(ref_pred_files[0])\n", + "var_preds = torch.load(var_pred_files[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, calculate the delta score:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chromposrefaltscoreclassref_fasta_namevar_fasta_nameref_log_probsvar_log_probsevo2_delta_score
01741276097AG0.326953FUNC/INTBRCA1_ref_pos_41276097_A_class_FUNC/INTBRCA1_var_pos_41276097_AtoG_class_FUNC/INT-0.935537-0.935760-0.000223
11741201130AG0.056569FUNC/INTBRCA1_ref_pos_41201130_A_class_FUNC/INTBRCA1_var_pos_41201130_AtoG_class_FUNC/INT-0.929413-0.929956-0.000543
21741215938TA-2.017579LOFBRCA1_ref_pos_41215938_T_class_LOFBRCA1_var_pos_41215938_TtoA_class_LOF-0.864209-0.866731-0.002521
31741215932AC-1.706222LOFBRCA1_ref_pos_41215932_A_class_LOFBRCA1_var_pos_41215932_AtoC_class_LOF-0.864879-0.866319-0.001440
41741219685GT0.037593FUNC/INTBRCA1_ref_pos_41219685_G_class_FUNC/INTBRCA1_var_pos_41219685_GtoT_class_FUNC/INT-1.027526-1.027546-0.000020
\n", + "
" + ], + "text/plain": [ + " chrom pos ref alt score class \\\n", + "0 17 41276097 A G 0.326953 FUNC/INT \n", + "1 17 41201130 A G 0.056569 FUNC/INT \n", + "2 17 41215938 T A -2.017579 LOF \n", + "3 17 41215932 A C -1.706222 LOF \n", + "4 17 41219685 G T 0.037593 FUNC/INT \n", + "\n", + " ref_fasta_name \\\n", + "0 BRCA1_ref_pos_41276097_A_class_FUNC/INT \n", + "1 BRCA1_ref_pos_41201130_A_class_FUNC/INT \n", + "2 BRCA1_ref_pos_41215938_T_class_LOF \n", + "3 BRCA1_ref_pos_41215932_A_class_LOF \n", + "4 BRCA1_ref_pos_41219685_G_class_FUNC/INT \n", + "\n", + " var_fasta_name ref_log_probs var_log_probs \\\n", + "0 BRCA1_var_pos_41276097_AtoG_class_FUNC/INT -0.935537 -0.935760 \n", + "1 BRCA1_var_pos_41201130_AtoG_class_FUNC/INT -0.929413 -0.929956 \n", + "2 BRCA1_var_pos_41215938_TtoA_class_LOF -0.864209 -0.866731 \n", + "3 BRCA1_var_pos_41215932_AtoC_class_LOF -0.864879 -0.866319 \n", + "4 BRCA1_var_pos_41219685_GtoT_class_FUNC/INT -1.027526 -1.027546 \n", + "\n", + " evo2_delta_score \n", + "0 -0.000223 \n", + "1 -0.000543 \n", + "2 -0.002521 \n", + "3 -0.001440 \n", + "4 -0.000020 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# next, calculate change in likelihoods\n", + "ref_log_probs = []\n", + "var_log_probs = []\n", + "for _, row in brca1_df.iterrows():\n", + " ref_name = row[\"ref_fasta_name\"]\n", + " var_name = row[\"var_fasta_name\"]\n", + " ref_log_probs.append(ref_preds[\"log_probs_seqs\"][ref_seq_idx_map[ref_name]].item())\n", + " var_log_probs.append(var_preds[\"log_probs_seqs\"][var_seq_idx_map[var_name]].item())\n", + "brca1_df[\"ref_log_probs\"] = ref_log_probs\n", + "brca1_df[\"var_log_probs\"] = var_log_probs\n", + "# ideally probability of a broken variant is lower than a good one. So a bad var - good ref is negative.\n", + "brca1_df[\"evo2_delta_score\"] = brca1_df[\"var_log_probs\"] - brca1_df[\"ref_log_probs\"]\n", + "brca1_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This delta likelihood should be predictive of how disruptive the SNV is to the protein's function: the lower the delta, the more likely that the SNV is disruptive. We can show this by comparing the distributions of delta likelihoods for the two classes of SNVs (functional/intermediate vs loss-of-function)." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def plot_strip_with_means(df, x_col=\"evo2_delta_score\", class_col=\"class\"):\n", + " \"\"\"Creates a strip plot with jittered points and median indicators for each class using Seaborn.\n", + "\n", + " Parameters:\n", + " - df (pd.DataFrame): The input DataFrame containing data.\n", + " - x_col (str): The column name representing the x-axis values (e.g., evo2_delta_score).\n", + " - class_col (str): The column name representing the class labels.\n", + "\n", + " Returns:\n", + " - matplotlib Figure: Strip plot with median indicators.\n", + " \"\"\"\n", + " # NVIDIA theme colors\n", + " NVIDIA_GREEN = \"#76B900\" # noqa: N806\n", + " BACKGROUND_COLOR = \"#F8F8F8\" # noqa: N806\n", + " GRID_COLOR = \"#DDDDDD\" # noqa: N806\n", + " FONT_COLOR = \"#333333\" # noqa: N806\n", + "\n", + " # Determine order of classes (if not already specified)\n", + " unique_classes = sorted(df[class_col].unique())\n", + "\n", + " # Set up the plot with NVIDIA theme\n", + " plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)\n", + " plt.style.use(\"default\") # Reset to default to avoid any pre-existing style\n", + "\n", + " # Create strip plot\n", + " p = sns.stripplot(\n", + " data=df,\n", + " x=x_col,\n", + " y=class_col,\n", + " hue=class_col,\n", + " order=unique_classes,\n", + " palette=[NVIDIA_GREEN, \"red\"],\n", + " size=6,\n", + " jitter=0.3,\n", + " alpha=0.6,\n", + " )\n", + "\n", + " # Add median indicators using boxplot\n", + " sns.boxplot(\n", + " showmeans=True,\n", + " meanline=True,\n", + " meanprops={\"visible\": False},\n", + " medianprops={\"color\": \"black\", \"ls\": \"-\", \"lw\": 2},\n", + " whiskerprops={\"visible\": False},\n", + " zorder=10,\n", + " x=x_col,\n", + " y=class_col,\n", + " data=df,\n", + " order=unique_classes,\n", + " showfliers=False,\n", + " showbox=False,\n", + " showcaps=False,\n", + " ax=p,\n", + " )\n", + "\n", + " # Customize plot appearance\n", + " plt.title(\n", + " \"Distribution of Delta Likelihoods Scores\\nComparing Evo 2 likelihood scores for different BRCA1 SNV classes\",\n", + " color=FONT_COLOR,\n", + " fontsize=12,\n", + " loc=\"left\",\n", + " )\n", + " plt.xlabel(\"Delta Likelihood Score, Evo 2\", color=FONT_COLOR)\n", + " plt.ylabel(\"BRCA1 SNV Class\", color=FONT_COLOR)\n", + "\n", + " # Customize grid and tick colors\n", + " plt.grid(color=GRID_COLOR, axis=\"x\", linestyle=\"--\", linewidth=0.5)\n", + " plt.tick_params(colors=FONT_COLOR)\n", + "\n", + " # Set background color\n", + " plt.gca().set_facecolor(BACKGROUND_COLOR)\n", + " plt.gcf().set_facecolor(BACKGROUND_COLOR)\n", + "\n", + " plt.tight_layout()\n", + "\n", + " # return plt.gcf()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_strip_with_means(brca1_df, x_col=\"evo2_delta_score\", class_col=\"class\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method. Note that the results are nearly random unless you are on one of the following configurations:\n", + "* `--fp8` on an fp8 enabled GPU with either the 1b or 7b models. The 40b likely works as well.\n", + "* the 7b model uniquely seems to work well without `--fp8` so if you are on an older device, the 7b model should produce\n", + " robust results. Change the `MODEL_SIZE` earlier in this tutorial and rerun for good results in that case.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Zero-shot prediction AUROC: 0.73\n" + ] + } + ], + "source": [ + "# Calculate AUROC of zero-shot predictions\n", + "# class 1 is LOF which is the bad thing. That means we expect this to be more negative.\n", + "y_true = brca1_df[\"class\"] == \"LOF\"\n", + "auroc = roc_auc_score(y_true, -brca1_df[\"evo2_delta_score\"])\n", + "print(f\"Zero-shot prediction AUROC: {auroc:.2}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "def plot_roc_curve(df):\n", + " \"\"\"Plots an ROC curve using Seaborn with a light NVIDIA-themed design.\n", + "\n", + " The function assumes:\n", + " - `class` column as the true labels (binary, 'LOF' = 1, else 0).\n", + " - `evo2_delta_score` as the prediction score.\n", + "\n", + " Parameters:\n", + " - df (pd.DataFrame): DataFrame containing `class` and `evo2_delta_score`.\n", + "\n", + " Returns:\n", + " - matplotlib Figure: ROC Curve Visualization.\n", + " \"\"\"\n", + " # NVIDIA theme colors\n", + " NVIDIA_GREEN = \"#76B900\" # noqa: N806\n", + " BACKGROUND_COLOR = \"#F8F8F8\" # noqa: N806\n", + " GRID_COLOR = \"#DDDDDD\" # noqa: N806\n", + " FONT_COLOR = \"#333333\" # noqa: N806\n", + "\n", + " # Validate required columns\n", + " if \"class\" not in df.columns or \"evo2_delta_score\" not in df.columns:\n", + " raise ValueError(\"DataFrame must contain 'class' and 'evo2_delta_score' columns.\")\n", + "\n", + " # Convert 'class' to binary labels: Assume 'LOF' = 1, anything else = 0\n", + " y_true = (df[\"class\"] == \"LOF\").astype(int)\n", + "\n", + " # Compute ROC curve\n", + " fpr, tpr, _ = roc_curve(y_true, -df[\"evo2_delta_score\"]) # Negative to align with previous logic\n", + " roc_auc = auc(fpr, tpr)\n", + "\n", + " # Set up the plot with NVIDIA theme\n", + " plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)\n", + " plt.style.use(\"default\") # Reset to default to avoid any pre-existing style\n", + "\n", + " # Plot ROC curve\n", + " plt.plot(fpr, tpr, color=NVIDIA_GREEN, lw=3, label=f\"ROC curve (AUROC = {roc_auc:.2f})\")\n", + "\n", + " # Plot diagonal reference line for random guessing\n", + " plt.plot([0, 1], [0, 1], color=\"gray\", lw=2, linestyle=\"--\")\n", + "\n", + " # Customize plot appearance\n", + " plt.xlim([0.0, 1.0])\n", + " plt.ylim([0.0, 1.05])\n", + " plt.xlabel(\"False Positive Rate\", color=FONT_COLOR, fontsize=12)\n", + " plt.ylabel(\"True Positive Rate\", color=FONT_COLOR, fontsize=12)\n", + " plt.title(\n", + " \"Zeroshot ROC Curve\\nEvaluating the discriminative performance of Evo 2 predictions\",\n", + " color=FONT_COLOR,\n", + " fontsize=16,\n", + " loc=\"left\",\n", + " )\n", + "\n", + " # Customize grid and tick colors\n", + " plt.grid(color=GRID_COLOR, linestyle=\"--\", linewidth=0.5)\n", + " plt.tick_params(colors=FONT_COLOR)\n", + "\n", + " # Set background color\n", + " plt.gca().set_facecolor(BACKGROUND_COLOR)\n", + "\n", + " # Add legend\n", + " plt.legend(loc=\"lower right\", frameon=True, facecolor=BACKGROUND_COLOR, edgecolor=GRID_COLOR)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_roc_curve(brca1_df)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if the AUC is a reasonable value for our CI suite when we run the full model\n", + "assert FAST_CI_MODE or auroc >= 0.7" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Full Sample Performance\n", + "\n", + "The above analysis may have been performed on a subset of the available data.\n", + "\n", + "For comparison, the table below presents the AUROC scores for different model sizes trained on the *full dataset* (100% sample fraction).\n", + "\n", + "| Model Size | Dataset Sample Fraction | AUROC |\n", + "|------------|------------------------|-------|\n", + "| Evo 2 1B | 100% | 0.74 |\n", + "| Evo 2 7B | 100% | 0.87 |\n" ] - }, - "metadata": {}, - "output_type": "display_data" } - ], - "source": [ - "plot_roc_curve(brca1_df)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "# Check if the AUC is a reasonable value for our CI suite when we run the full model\n", - "assert FAST_CI_MODE or auroc >= 0.73" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Full Sample Performance\n", - "\n", - "The above analysis may have been performed on a subset of the available data.\n", - "\n", - "For comparison, the table below presents the AUROC scores for different model sizes trained on the *full dataset* (100% sample fraction).\n", - "\n", - "| Model Size | Dataset Sample Fraction | AUROC |\n", - "|------------|------------------------|-------|\n", - "| Evo 2 1B | 100% | 0.74 |\n", - "| Evo 2 7B | 100% | 0.87 |\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index 355d808460..ce454bf692 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -35,9 +35,9 @@ test = [] [project.scripts] torchrun = "torch.distributed.run:main" -#infer_evo2 = "bionemo.evo2.run.infer:main" +infer_evo2 = "bionemo.evo2.run.infer:main" train_evo2 = "bionemo.evo2.run.train:main" -#predict_evo2 = "bionemo.evo2.run.predict:main" +predict_evo2 = "bionemo.evo2.run.predict:main" preprocess_evo2 = "bionemo.evo2.data.preprocess:main" splice_evo2 = "bionemo.evo2.data.transcript_extraction:main" evo2_convert_nemo2_to_mbridge = "bionemo.evo2.utils.checkpoint.nemo2_to_mbridge:main" diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py index dbd61f3a5c..d73366ee73 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py @@ -511,8 +511,20 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreHy rotary_percent=self.rotary_percent, rotary_base=self.rotary_base, seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=pre_process or parallel_state.is_pipeline_first_stage(), - post_process=post_process or parallel_state.is_pipeline_last_stage(), + # Note: When self.pre_process/self.post_process is explicitly False (e.g., for embedding + # extraction), we must use that value regardless of what the caller passes. This is because + # _create_model in megatron.bridge always passes the pipeline stage values, but we want to + # disable post-processing when extracting embeddings. + pre_process=( + False + if self.pre_process is False + else (pre_process if pre_process is not None else parallel_state.is_pipeline_first_stage()) + ), + post_process=( + False + if self.post_process is False + else (post_process if post_process is not None else parallel_state.is_pipeline_last_stage()) + ), share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, hyena_init_method=self.hyena_init_method, hyena_output_layer_init_method=self.hyena_output_layer_init_method, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py index 7ec67a4fba..93acf70e57 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_block.py @@ -199,6 +199,9 @@ def __init__( hidden_size=self.transformer_config.hidden_size, eps=self.transformer_config.layernorm_epsilon, ) + else: + # Ensure final_norm is always defined to avoid AttributeError when post_process=False + self.final_norm = None # Required for activation recomputation self.num_layers_per_pipeline_rank = len(self.layers) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py index d0e452ef6d..955ea7fc5a 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py @@ -1005,9 +1005,11 @@ def get_filter_state(filter_name): ) # y = rearrange(y, "b d l -> b l d") else: - x1 = rearrange(x1, "1 d l -> l d") - x2 = rearrange(x2, "1 d l -> l d") - v = rearrange(v, "1 d l -> l d") + # Decode path: handle arbitrary batch size + # Input shapes: [b, d, l] where l=1 during decode + x1 = rearrange(x1, "b d l -> (b l) d") + x2 = rearrange(x2, "b d l -> (b l) d") + v = rearrange(v, "b d l -> (b l) d") x1, x2 = x2, x1 # TODO: figure why it is swapped y, iir_state = engine.step_iir( x2=x2, @@ -1301,6 +1303,8 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): "conv_bias": 0, }, # parameters sharded across TP sharded_offsets=sharded_offsets, + tp_group=self.pg_collection.tp, + dp_cp_group=self.pg_collection.dp_cp, ) # Submodules for name, module in self.named_children(): diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py index 947d00e15b..837286df94 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py @@ -29,7 +29,7 @@ TokenizerConfig, TrainingConfig, ) -from megatron.bridge.training.mixed_precision import MixedPrecisionConfig +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config from typing_extensions import TypedDict, Unpack from bionemo.evo2.data.evo2_dataset_provider import Evo2DatasetProvider @@ -166,6 +166,8 @@ def _evo2_common( checkpoint_dir = os.path.join(run_output_dir, "checkpoints") tensorboard_dir = os.path.join(run_output_dir, "tb_logs") wandb_save_dir = os.path.join(run_output_dir, "wandb") + if isinstance(precision_config, str): + precision_config = get_mixed_precision_config(precision_config) if mock: dataset_cfg_or_provider = MockEvo2DatasetProvider( random_seed=dataset_seed, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py index d68a102977..d6ebd79b13 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py @@ -16,222 +16,613 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FIXME get this working with megatron bridge - -# import argparse -# import sys -# import time -# from typing import Literal, Optional - -# import nemo.lightning as nl -# import torch -# from megatron.core.inference.common_inference_params import CommonInferenceParams -# from megatron.core.inference.inference_request import InferenceRequest -# from nemo.collections.llm import inference -# from nemo.utils import logging - - -# CheckpointFormats = Literal["torch_dist", "zarr"] - - -# def parse_args(): -# """Parse arguments for Evo2 inference.""" -# ap = argparse.ArgumentParser() - -# # generation args: -# default_prompt = ( -# "|d__Bacteria;" -# + "p__Pseudomonadota;" -# + "c__Gammaproteobacteria;" -# + "o__Enterobacterales;" -# + "f__Enterobacteriaceae;" -# + "g__Escherichia;" -# + "s__Escherichia|" -# ) -# ap.add_argument( -# "--prompt", -# type=str, -# default=default_prompt, -# help="Prompt to generate text from Evo2. Defaults to a phylogenetic lineage tag for E coli.", -# ) -# ap.add_argument( -# "--ckpt-dir", type=str, required=True, help="Path to checkpoint directory containing pre-trained Evo2 model." -# ) -# ap.add_argument("--temperature", type=float, default=1.0, help="Temperature during sampling for generation.") -# ap.add_argument("--top-k", type=int, default=0, help="Top K during sampling for generation.") -# ap.add_argument("--top-p", type=float, default=0.0, help="Top P during sampling for generation.") -# ap.add_argument("--max-new-tokens", type=int, default=1024, help="Maximum number of tokens to generate.") -# ap.add_argument("--seed", type=int, default=None, help="Random seed for generation.") -# # compute args: -# ap.add_argument("--tensor-parallel-size", type=int, default=1, help="Order of tensor parallelism. Defaults to 1.") -# ap.add_argument( -# "--pipeline-model-parallel-size", type=int, default=1, help="Order of pipeline parallelism. Defaults to 1." -# ) -# ap.add_argument( -# "--context-parallel-size", type=int, default=1, help="Order of context parallelism. Defaults to 1." -# ) -# # output args: -# ap.add_argument( -# "--output-file", -# type=str, -# default=None, -# help="Output file containing the generated text produced by the Evo2 model. If not provided, the output will be logged.", -# ) -# # extra: -# ap.add_argument( -# "--ckpt-format", -# type=str, -# choices=["torch_dist", "zarr"], -# default="torch_dist", -# help="Specify checkpoint format to use. Defaults to 'torch_dist', as 'zarr' is deprecated.", -# ) -# ap.add_argument( -# "--fp8", -# action="store_true", -# default=False, -# help="Whether to use vortex style FP8. Defaults to False.", -# ) -# ap.add_argument( -# "--flash-decode", -# action="store_true", -# default=False, -# help="Whether to use flash decode. Defaults to True.", -# ) -# return ap.parse_args() - - -# def infer( -# prompt: str, -# ckpt_dir: str, -# temperature: float, -# top_k: int, -# top_p: float, -# max_new_tokens: int, -# tensor_parallel_size: int, -# pipeline_model_parallel_size: int, -# context_parallel_size: int, -# output_file: Optional[str] = None, -# ckpt_format: CheckpointFormats = "torch_dist", -# seed: Optional[int] = None, -# vortex_style_fp8: bool = False, -# flash_decode: bool = False, -# return_log_probs: bool = False, -# ) -> list[InferenceRequest]: -# """Inference workflow for Evo2. - -# Args: -# prompt (str): Prompt to generate text from Evo2. -# ckpt_dir (str): Path to checkpoint directory containing pre-trained Evo2 model. -# temperature (float): Temperature during sampling for generation. -# top_k (int): Top K during sampling for generation. -# top_p (float): Top P during sampling for generation. -# max_new_tokens (int): Maximum number of tokens to generate. -# tensor_parallel_size (int): Order of tensor parallelism. -# pipeline_model_parallel_size (int): Order of pipeline parallelism. -# context_parallel_size (int): Order of context parallelism. -# output_file (str): Output file containing the generated text produced by the Evo2 model. -# ckpt_format (CheckpointFormats): Checkpoint format to use. -# seed (int): Random seed for generation. -# vortex_style_fp8 (bool): Whether to use vortex style FP8. -# flash_decode (bool): Whether to use flash decode. -# return_log_probs (bool): Whether to return log probabilities. - -# Returns: -# None -# """ -# model_parallel_size = tensor_parallel_size * pipeline_model_parallel_size * context_parallel_size -# if model_parallel_size > torch.cuda.device_count(): -# raise ValueError( -# f"Requested model parallel size {model_parallel_size} is greater than the " -# f"number of available CUDA devices {torch.cuda.device_count()}" -# ) -# # Create PTL trainer. -# trainer = nl.Trainer( -# accelerator="gpu", -# devices=model_parallel_size, -# strategy=nl.MegatronStrategy( -# tensor_model_parallel_size=tensor_parallel_size, -# pipeline_model_parallel_size=pipeline_model_parallel_size, -# context_parallel_size=context_parallel_size, -# pipeline_dtype=torch.bfloat16, -# ckpt_load_optimizer=False, # Needs to be false for a normal model checkpoint. -# ckpt_save_optimizer=False, -# ckpt_async_save=False, -# save_ckpt_format=ckpt_format, -# ckpt_load_strictness="log_all", -# ), -# log_every_n_steps=1, -# limit_val_batches=10, -# num_sanity_val_steps=0, -# plugins=nl.MegatronMixedPrecision( -# precision="bf16-mixed", -# params_dtype=torch.bfloat16, -# ), -# ) -# inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer( -# path=ckpt_dir, -# trainer=trainer, -# params_dtype=torch.bfloat16, -# inference_batch_times_seqlen_threshold=8192, # TODO -# inference_max_seq_length=8192, # TODO -# recompute_granularity=None, -# recompute_num_layers=None, -# recompute_method=None, -# vortex_style_fp8=vortex_style_fp8, -# flash_decode=flash_decode, -# enable_flash_decode=flash_decode, -# ) -# t0 = time.perf_counter_ns() -# # TODO: fix return type in NeMo inference.generate (it is a list[InferenceRequest] not a dict) -# results: list[InferenceRequest] = inference.generate( -# model=inference_wrapped_model, -# max_batch_size=1, # vortex only supports batch size 1 -# tokenizer=mcore_tokenizer, -# prompts=[prompt], -# random_seed=seed, -# inference_params=CommonInferenceParams( -# temperature=temperature, -# top_k=top_k, -# top_p=top_p, -# return_log_probs=return_log_probs, -# num_tokens_to_generate=max_new_tokens, -# ), -# ) -# dt = (time.perf_counter_ns() - t0) / 1e9 # seconds -# tokens_per_sec = (len(results[0].generated_text) + 1) / dt # +1 for the prompt - -# print(f"Inference time: {dt} seconds, {tokens_per_sec} tokens/sec", file=sys.stderr) -# if torch.distributed.get_rank() == 0: -# if output_file is None: -# logging.info(results) -# else: -# with open(output_file, "w") as f: -# f.write(f"{results[0]}\n") - -# return results - - -# def main(): -# """Main function for Evo2 inference.""" -# # Parse args. -# args = parse_args() -# infer( -# prompt=args.prompt, -# ckpt_dir=args.ckpt_dir, -# temperature=args.temperature, -# top_k=args.top_k, -# top_p=args.top_p, -# max_new_tokens=args.max_new_tokens, -# tensor_parallel_size=args.tensor_parallel_size, -# pipeline_model_parallel_size=args.pipeline_model_parallel_size, -# context_parallel_size=args.context_parallel_size, -# output_file=args.output_file, -# ckpt_format=args.ckpt_format, -# seed=args.seed, -# vortex_style_fp8=args.fp8, # Vortex only applied FP8 to some layers. -# flash_decode=args.flash_decode, -# ) - - -# if __name__ == "__main__": -# main() +r"""Text generation (inference) workflow for Evo2 using Megatron Core. + +This module provides autoregressive text generation for Evo2 models using the +MCore inference infrastructure (StaticInferenceEngine, TextGenerationController). + +Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/examples/inference/gpt/gpt_static_inference.py + +Usage (CLI): + torchrun --nproc_per_node 1 -m bionemo.evo2.run.infer \ + --ckpt-dir /path/to/mbridge/checkpoint \ + --prompt "|d__Bacteria;p__Pseudomonadota|" \ + --max-new-tokens 100 + +Usage (Python API): + from bionemo.evo2.run.infer import setup_inference_engine, generate + + # Setup engine (loads model, creates inference components) + engine, tokenizer = setup_inference_engine(ckpt_dir) + + # Generate text + results = generate(engine, prompts=["ATCGATCG"], max_new_tokens=100) +""" + +import argparse +import logging +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint +from megatron.bridge.training.config import DistributedInitConfig, RNGConfig +from megatron.bridge.training.mixed_precision import get_mixed_precision_config +from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer +from megatron.bridge.training.utils.checkpoint_utils import ( + file_exists, + get_checkpoint_run_config_filename, + read_run_config, +) +from megatron.bridge.utils.instantiate_utils import instantiate +from megatron.core import parallel_state +from megatron.core.inference.contexts import StaticInferenceContext +from megatron.core.inference.engines.static_engine import StaticInferenceEngine +from megatron.core.inference.inference_request import InferenceRequest +from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( + AbstractModelInferenceWrapper, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.sampling_params import SamplingParams +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) +from megatron.core.transformer.module import Float16Module +from megatron.core.utils import get_model_config + +from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH +from bionemo.evo2.models.evo2_provider import HyenaInferenceContext +from bionemo.evo2.run.predict import initialize_inference_distributed, resolve_checkpoint_path + + +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# ============================================================================= +# Evo2 Model Inference Wrapper +# ============================================================================= + + +class Evo2ModelInferenceWrapper(AbstractModelInferenceWrapper): + """Inference wrapper for Evo2 models. + + Extends the abstract wrapper to provide Evo2-specific input preparation + and forward pass handling for autoregressive text generation. + """ + + def __init__( + self, + model: torch.nn.Module, + inference_wrapper_config: InferenceWrapperConfig, + inference_context: Optional[StaticInferenceContext] = None, + ): + """Initialize the Evo2 inference wrapper. + + Args: + model: The Evo2 model to wrap for inference. + inference_wrapper_config: Configuration with hidden size, vocab size, etc. + inference_context: Context for managing state and sequence offsets. + """ + super().__init__(model, inference_wrapper_config, inference_context) + + def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: + """Prepare the inference input data. + + Args: + prompts_tokens: A tensor of shape [batch_size, max_seq_len] + + Returns: + Dict with tokens, attention_mask, and position_ids + """ + batch_size, seq_len = prompts_tokens.shape + device = prompts_tokens.device + + # For Evo2/Hyena models, position_ids are sequential + position_ids = torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1) + + # Evo2 uses causal attention - for flash attention backend, mask is None + attention_mask = None + + return { + "tokens": prompts_tokens, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + def get_batch_for_context_window( + self, + inference_input: Dict[str, Any], + context_start_position: int, + context_end_position: int, + ) -> Dict[str, Any]: + """Extract batch for a specific context window. + + Called iteratively during autoregressive generation. + + Args: + inference_input: Full inference input dict + context_start_position: Start of context window + context_end_position: End of context window + + Returns: + Dict with sliced tokens, positions, and attention mask + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + + tokens2use = tokens[:, context_start_position:context_end_position] + positions2use = position_ids[:, context_start_position:context_end_position] + + if attention_mask is not None: + attention_mask2use = attention_mask[ + ..., context_start_position:context_end_position, :context_end_position + ] + else: + attention_mask2use = None + + return { + "tokens": tokens2use, + "position_ids": positions2use, + "attention_mask": attention_mask2use, + } + + def _forward(self, inference_input: Dict[str, Any]) -> torch.Tensor: + """Run a forward pass of the model. + + Override to pass HyenaInferenceContext properly. + + Args: + inference_input: The input data dict. + + Returns: + The model output logits. + """ + tokens = inference_input["tokens"] + position_ids = inference_input["position_ids"] + attention_mask = inference_input["attention_mask"] + + return self.model( + tokens, + position_ids, + attention_mask, + inference_context=self.inference_context, + runtime_gather_output=True, + ) + + +# ============================================================================= +# Inference Components Container +# ============================================================================= + + +@dataclass +class Evo2InferenceComponents: + """Container for Evo2 inference components. + + This dataclass holds all the components needed for text generation, + making it easy to pass around and reuse. + """ + + inference_engine: StaticInferenceEngine + tokenizer: _HuggingFaceTokenizer + inference_wrapper: Evo2ModelInferenceWrapper + inference_context: HyenaInferenceContext + model: torch.nn.Module + + +# ============================================================================= +# Public API: Setup and Generate Functions +# ============================================================================= + + +def setup_inference_engine( + ckpt_dir: Path, + *, + max_seq_length: int = 8192, + max_batch_size: int = 1, + tensor_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + mixed_precision_recipe: Optional[str] = None, + random_seed: int = 1234, +) -> Evo2InferenceComponents: + """Setup the Evo2 inference engine and related components. + + This function loads the model, creates the inference wrapper, and sets up + all necessary components for text generation. + + Args: + ckpt_dir: Path to MBridge checkpoint directory. + max_seq_length: Maximum sequence length for generation. + max_batch_size: Maximum batch size for inference. + tensor_parallel_size: Tensor parallelism degree. + pipeline_model_parallel_size: Pipeline parallelism degree (must be 1). + context_parallel_size: Context parallelism degree. + mixed_precision_recipe: Override mixed precision recipe. + random_seed: Random seed for reproducibility. + + Returns: + Evo2InferenceComponents containing all inference components. + + Example: + >>> components = setup_inference_engine(Path("/path/to/checkpoint"), max_batch_size=4) + >>> results = generate(components, prompts=["ATCG", "GCTA"], max_new_tokens=100) + """ + if pipeline_model_parallel_size != 1: + raise ValueError("Pipeline parallelism > 1 is not supported for inference.") + + # ------------------------------------------------------------------------- + # Step 1: Load configuration from checkpoint + # ------------------------------------------------------------------------- + resolved_ckpt_dir = resolve_checkpoint_path(ckpt_dir) + logger.info(f"Loading configuration from checkpoint: {resolved_ckpt_dir}") + + run_config_filename = get_checkpoint_run_config_filename(str(resolved_ckpt_dir)) + if not file_exists(run_config_filename): + raise FileNotFoundError(f"run_config.yaml not found at {run_config_filename}") + + run_config = read_run_config(run_config_filename) + model_provider = instantiate(run_config["model"]) + logger.info(f"Instantiated model provider: {type(model_provider).__name__}") + + # ------------------------------------------------------------------------- + # Step 2: Configure parallelism and precision + # ------------------------------------------------------------------------- + model_provider.tensor_model_parallel_size = tensor_parallel_size + model_provider.pipeline_model_parallel_size = pipeline_model_parallel_size + model_provider.context_parallel_size = context_parallel_size + model_provider.sequence_parallel = tensor_parallel_size > 1 + + # Enable flash decode for inference + model_provider.flash_decode = True + + # Use bf16_mixed for inference to avoid FP8 issues + if mixed_precision_recipe is not None: + mp_config = get_mixed_precision_config(mixed_precision_recipe) + else: + mp_config = get_mixed_precision_config("bf16_mixed") + + mp_config.finalize() + mp_config.setup(model_provider) + + # ------------------------------------------------------------------------- + # Step 3: Load tokenizer + # ------------------------------------------------------------------------- + tokenizer_dir = resolved_ckpt_dir / "tokenizer" + if tokenizer_dir.exists(): + tokenizer = _HuggingFaceTokenizer(tokenizer_dir) + else: + tokenizer = _HuggingFaceTokenizer(DEFAULT_HF_TOKENIZER_MODEL_PATH) + + model_provider.vocab_size = tokenizer.vocab_size + model_provider.should_pad_vocab = True + + # ------------------------------------------------------------------------- + # Step 4: Initialize distributed environment + # ------------------------------------------------------------------------- + rng_config = instantiate(run_config.get("rng")) if run_config.get("rng") else RNGConfig(seed=random_seed) + dist_config = instantiate(run_config.get("dist")) if run_config.get("dist") else DistributedInitConfig() + + from megatron.bridge.utils.common_utils import get_world_size_safe + + model_parallel_size = tensor_parallel_size * pipeline_model_parallel_size * context_parallel_size + world_size = get_world_size_safe() + data_parallel_size = world_size // model_parallel_size + + initialize_inference_distributed( + tensor_model_parallel_size=tensor_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + micro_batch_size=max_batch_size, + global_batch_size=max_batch_size * data_parallel_size, + rng_config=rng_config, + dist_config=dist_config, + ) + logger.info("Initialized distributed environment") + + # ------------------------------------------------------------------------- + # Step 5: Create model and load weights + # ------------------------------------------------------------------------- + logger.info("Creating model...") + model_provider.finalize() + + raw_model = model_provider.provide(pre_process=True, post_process=True).eval().cuda() + + logger.info(f"Loading weights from: {resolved_ckpt_dir}") + _load_model_weights_from_checkpoint( + checkpoint_path=str(resolved_ckpt_dir), + model=[raw_model], + dist_ckpt_strictness="ignore_all", + ) + logger.info("Weights loaded successfully") + + # Wrap with Float16Module + model = Float16Module(model_provider, raw_model) + + # ------------------------------------------------------------------------- + # Step 6: Setup MCore inference infrastructure + # ------------------------------------------------------------------------- + # Create inference wrapper config + model_config = get_model_config(raw_model) + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=model_config.hidden_size, + inference_max_requests=max_batch_size, + inference_max_seq_length=max_seq_length, + inference_batch_times_seqlen_threshold=max_seq_length * max_batch_size, + params_dtype=torch.bfloat16, + padded_vocab_size=tokenizer.vocab_size, + ) + + # Create Hyena-specific inference context + inference_context = HyenaInferenceContext( + max_batch_size=max_batch_size, + max_sequence_length=max_seq_length, + ) + # Don't materialize only last token - we need full logits for sampling + inference_context.materialize_only_last_token_logits = False + + # Create the inference wrapper + inference_wrapper = Evo2ModelInferenceWrapper( + model=model, + inference_wrapper_config=inference_wrapper_config, + inference_context=inference_context, + ) + + # Create the text generation controller + text_generation_controller = TextGenerationController( + inference_wrapped_model=inference_wrapper, + tokenizer=tokenizer, + ) + + # Create the static inference engine (using legacy mode for simplicity) + inference_engine = StaticInferenceEngine( + text_generation_controller=text_generation_controller, + max_batch_size=max_batch_size, + random_seed=random_seed, + legacy=True, # Use legacy static engine + ) + + return Evo2InferenceComponents( + inference_engine=inference_engine, + tokenizer=tokenizer, + inference_wrapper=inference_wrapper, + inference_context=inference_context, + model=model, + ) + + +def generate( + components: Evo2InferenceComponents, + prompts: List[str], + *, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + return_log_probs: bool = False, +) -> List[InferenceRequest]: + """Generate text using the Evo2 inference engine. + + Args: + components: Inference components from setup_inference_engine. + prompts: List of prompt strings to generate from. + max_new_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature (higher = more random). + top_k: Top-k sampling parameter (0 = disabled, 1 = greedy). + top_p: Nucleus sampling parameter (0 = disabled). + return_log_probs: Whether to return log probabilities. + + Returns: + List of InferenceRequest objects containing generated text and metadata. + + Example: + >>> components = setup_inference_engine(ckpt_dir) + >>> results = generate(components, ["ATCGATCG"], max_new_tokens=50, top_k=1) + >>> print(results[0].generated_text) + """ + # Reset inference context before generation + components.inference_context.reset() + + sampling_params = SamplingParams( + temperature=temperature, + top_k=max(0, top_k), + top_p=top_p if top_p > 0 else 0.0, + num_tokens_to_generate=max_new_tokens, + return_log_probs=return_log_probs, + ) + + results = components.inference_engine.generate( + prompts=prompts, + sampling_params=sampling_params, + ) + + # Reset context after generation + components.inference_context.reset() + + return results + + +# ============================================================================= +# CLI: Full Inference Workflow +# ============================================================================= + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments for Evo2 inference. + + Returns: + Parsed arguments namespace + """ + default_prompt = ( + "|d__Bacteria;" + + "p__Pseudomonadota;" + + "c__Gammaproteobacteria;" + + "o__Enterobacterales;" + + "f__Enterobacteriaceae;" + + "g__Escherichia;" + + "s__Escherichia|" + ) + + ap = argparse.ArgumentParser( + description="Generate text with Evo2 models using MCore inference", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Required arguments + ap.add_argument( + "--ckpt-dir", + type=Path, + required=True, + help="Path to MBridge checkpoint directory", + ) + + # Generation arguments + ap.add_argument( + "--prompt", + type=str, + default=default_prompt, + help="Prompt text for generation", + ) + ap.add_argument("--max-new-tokens", type=int, default=100, help="Maximum tokens to generate") + ap.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") + ap.add_argument("--top-k", type=int, default=0, help="Top-k sampling (0 = disabled)") + ap.add_argument("--top-p", type=float, default=0.0, help="Top-p nucleus sampling (0 = disabled)") + ap.add_argument("--seed", type=int, default=None, help="Random seed") + + # Parallelism arguments + ap.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallelism") + ap.add_argument("--pipeline-model-parallel-size", type=int, choices=[1], default=1, help="Pipeline parallelism") + ap.add_argument("--context-parallel-size", type=int, default=1, help="Context parallelism") + + # Output arguments + ap.add_argument("--output-file", type=Path, default=None, help="Save generated text to file") + + # Precision arguments + ap.add_argument("--mixed-precision-recipe", type=str, default=None, help="Override precision recipe") + + # Model arguments + ap.add_argument("--max-seq-length", type=int, default=8192, help="Max sequence length") + + return ap.parse_args() + + +def infer( + prompt: str, + ckpt_dir: Path, + *, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + seed: Optional[int] = None, + tensor_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + output_file: Optional[Path] = None, + mixed_precision_recipe: Optional[str] = None, + max_seq_length: int = 8192, +) -> str: + """Run autoregressive text generation with Evo2 using MCore inference. + + This is the main CLI entry point that sets up everything and runs inference. + For programmatic usage, prefer setup_inference_engine + generate. + + Args: + prompt: Input text prompt for generation. + ckpt_dir: Path to MBridge checkpoint directory. + max_new_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature (higher = more random). + top_k: Top-k sampling parameter (0 = disabled). + top_p: Nucleus sampling parameter (0 = disabled). + seed: Random seed for reproducibility. + tensor_parallel_size: Tensor parallelism degree. + pipeline_model_parallel_size: Pipeline parallelism degree (must be 1). + context_parallel_size: Context parallelism degree. + output_file: Optional path to save generated text. + mixed_precision_recipe: Override mixed precision recipe. + max_seq_length: Maximum sequence length. + + Returns: + The generated text string. + """ + random_seed = seed or 1234 + + # Setup inference components + components = setup_inference_engine( + ckpt_dir=ckpt_dir, + max_seq_length=max_seq_length, + tensor_parallel_size=tensor_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + mixed_precision_recipe=mixed_precision_recipe, + random_seed=random_seed, + ) + + logger.info(f"Generating from prompt: {prompt[:50]}...") + + # Generate + results = generate( + components, + prompts=[prompt], + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + # Extract generated text + generated_text = results[0].generated_text if results else "" + + # Output results + if parallel_state.get_data_parallel_rank() == 0: + print(f"\n=== Generated Text ===\n{generated_text}\n", file=sys.stdout) + + if output_file is not None: + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, "w") as f: + f.write(generated_text) + logger.info(f"Saved generated text to: {output_file}") + + logger.info("Inference complete!") + + # Cleanup + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + return generated_text + + +# ============================================================================= +# Entry Point +# ============================================================================= + + +def main() -> None: + """CLI entry point for Evo2 text generation.""" + args = parse_args() + infer( + prompt=args.prompt, + ckpt_dir=args.ckpt_dir, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + seed=args.seed, + tensor_parallel_size=args.tensor_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + context_parallel_size=args.context_parallel_size, + output_file=args.output_file, + mixed_precision_recipe=args.mixed_precision_recipe, + max_seq_length=args.max_seq_length, + ) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer_example_simple.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer_example_simple.py new file mode 100644 index 0000000000..187daf1306 --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer_example_simple.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple autoregressive generation example for Evo2 models. + +This module provides a straightforward implementation of autoregressive text generation +that directly calls the model forward pass without using the full MCore inference +infrastructure. This is useful for: + +1. Understanding how autoregressive generation works at a low level +2. Debugging and testing model behavior +3. Custom generation workflows that don't fit the MCore API + +For production use, prefer the MCore-based inference in `bionemo.evo2.run.infer`. + +Example: + >>> from bionemo.evo2.run.infer_example_simple import generate_tokens_simple + >>> from bionemo.evo2.models.evo2_provider import HyenaInferenceContext + >>> + >>> # Assuming model and tokenizer are already loaded + >>> ctx = HyenaInferenceContext(max_batch_size=1, max_sequence_length=8192) + >>> prompt_tokens = tokenizer.text_to_ids("ATCGATCG") + >>> tokens = generate_tokens_simple(model, prompt_tokens, max_new_tokens=100, inference_context=ctx) +""" + +from typing import List, Optional + +import torch + +from bionemo.evo2.models.evo2_provider import HyenaInferenceContext + + +@torch.inference_mode() +def generate_tokens_simple( + model: torch.nn.Module, + prompt_tokens: torch.Tensor, + max_new_tokens: int, + temperature: float = 1.0, + top_k: int = 0, + inference_context: Optional[HyenaInferenceContext] = None, +) -> List[int]: + """Generate tokens autoregressively using direct model forward passes. + + This function implements autoregressive generation by repeatedly calling + the model's forward pass with the previously generated token. It properly + manages the HyenaInferenceContext to cache SSM state between steps. + + Unlike the MCore-based inference in `bionemo.evo2.run.infer`, this function: + - Directly calls model.forward() instead of using inference wrappers + - Manually manages sequence_len_offset and decode_mode + - Does not use TextGenerationController or StaticInferenceEngine + + Args: + model: The Evo2 model (typically Float16Module wrapped). + prompt_tokens: Input prompt token IDs as a tensor of shape [1, seq_len]. + max_new_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature. Higher values (e.g., 1.0) make output + more random, lower values make it more deterministic. Default is 1.0. + top_k: Top-k sampling parameter. If > 0, only the top k tokens are + considered for sampling. Use top_k=1 for greedy decoding. Default is 0. + inference_context: Hyena-specific context for SSM state caching. + If provided, enables efficient autoregressive generation by caching + filter states between decode steps. + + Returns: + List of generated token IDs (excluding the prompt). + + Example: + >>> # Setup + >>> ctx = HyenaInferenceContext(max_batch_size=1, max_sequence_length=8192) + >>> prompt = torch.tensor([[65, 84, 67, 71]], device="cuda") # "ATCG" + >>> + >>> # Generate with greedy decoding + >>> tokens = generate_tokens_simple( + ... model, prompt, max_new_tokens=10, top_k=1, inference_context=ctx + ... ) + >>> print(tokens) # [65, 84, 67, 71, ...] (continues the pattern) + + Note: + For production use, prefer `bionemo.evo2.run.infer.generate()` which uses + the MCore inference infrastructure with proper batching, sampling, and + distributed support. + """ + device = prompt_tokens.device + generated_tokens: List[int] = [] + prompt_len = prompt_tokens.shape[1] + + # Ensure context starts in prefill mode + if inference_context is not None: + inference_context.enable_prefill_mode() + + # Process the full prompt first (prefill phase) + # This computes and caches the SSM states for all prompt tokens + logits = model( + input_ids=prompt_tokens, + position_ids=None, + attention_mask=None, + labels=None, + runtime_gather_output=True, + inference_context=inference_context, + ) + + # Update sequence_len_offset after prefill (MCore wrapper does this automatically) + if inference_context is not None: + inference_context.increment_sequence_len_offset(prompt_len) + # Switch to decode mode after prefill is complete + inference_context.enable_decode_mode() + + # Get next token from last position logits + next_token_logits = logits[0, -1, :].clone() + + # Generate tokens autoregressively (decode phase) + for _ in range(max_new_tokens): + # Apply temperature scaling + if temperature > 0: + next_token_logits = next_token_logits / temperature + + # Apply top-k filtering + if top_k > 0: + indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] + next_token_logits[indices_to_remove] = float("-inf") + + # Sample or argmax + if temperature > 0 and top_k != 1: + probs = torch.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).item() + else: + next_token = torch.argmax(next_token_logits).item() + + generated_tokens.append(next_token) + + # Prepare next input (single token) + next_input = torch.tensor([[next_token]], dtype=torch.long, device=device) + + # Forward pass with cached state + logits = model( + input_ids=next_input, + position_ids=None, + attention_mask=None, + labels=None, + runtime_gather_output=True, + inference_context=inference_context, + ) + + # Update sequence_len_offset after each decode step + if inference_context is not None: + inference_context.increment_sequence_len_offset(1) + + next_token_logits = logits[0, -1, :].clone() + + return generated_tokens diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py index 968a9f52f8..e65d4cdf40 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py @@ -16,697 +16,1407 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FIXME get this working with megatron bridge - -# import argparse -# import functools -# import tempfile -# from pathlib import Path -# from typing import Any, Literal - -# import nemo.lightning as nl -# import torch -# from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset -# from bionemo.evo2.models.llama import LLAMA_MODEL_OPTIONS - -# # Add import for Mamba models -# from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel -# from bionemo.evo2.models.peft import Evo2LoRA -# from bionemo.evo2.run.utils import infer_model_type, patch_eden_tokenizer -# from bionemo.llm.data import collate -# from bionemo.llm.lightning import LightningPassthroughPredictionMixin -# from bionemo.llm.utils.callbacks import PredictionWriter -# from lightning.pytorch import LightningDataModule -# from megatron.core import parallel_state -# from megatron.core.enums import Fp8Recipe -# from megatron.core.tensor_parallel.mappings import _gather_along_last_dim -# from megatron.core.utils import get_batch_on_this_cp_rank -# from nemo.collections.llm.gpt.data.megatron.hyena.evo2_dataset import Evo2Dataset -# from nemo.collections.llm.gpt.model.base import GPTModel, get_packed_seq_params -# from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS, HyenaModel -# from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer -# from nemo.lightning import NeMoLogger -# from nemo.lightning.data import WrappedDataLoader -# from nemo.utils import logging as logger -# from torch import Tensor - - -# CheckpointFormats = Literal["torch_dist", "zarr"] - -# SHUFFLE_MESSAGE = ( -# "Per token log probabilities are not supported when using context parallelism. The results will be " -# "zigzag shuffled along the sequence dimension. Raise a feature request if you need this and do " -# "not want to manually do the unshuffling yourself. You need to undo the shuffling that happened in " -# "`megatron.core.utils.get_batch_on_this_cp_rank`." -# ) - - -# def parse_args(): -# """Parse arguments for Evo2 inference.""" -# ap = argparse.ArgumentParser() -# ap.add_argument("--num-nodes", type=int, default=1, help="Number of nodes to use for prediction, defaults to 1.") -# ap.add_argument( -# "--devices", -# type=int, -# help="Number of devices to use for prediction, defaults to tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size.", -# ) -# ap.add_argument( -# "--eden-tokenizer", -# action="store_true", -# help="Patch the tokenizer to work with the one used in training the Eden model.", -# ) -# ap.add_argument("--fasta", type=Path, required=True, help="Fasta path from which to generate logit predictions.") -# ap.add_argument("--ckpt-dir", type=Path, required=True, help="NeMo2 checkpoint directory for inference.") -# ap.add_argument("--min-length", type=int, required=False, help="Minimum sequence length for padding.") -# ap.add_argument("--prepend-bos", action="store_true", help="Prepend BOS token to sequences. Defaults to False.") -# ap.add_argument( -# "--mask-phylogenetic-tags", -# action="store_true", -# help="Mask phylogenetic tags in loss computation. Defaults to False.", -# ) -# ap.add_argument("--tensor-parallel-size", type=int, default=1, help="Order of tensor parallelism. Defaults to 1.") -# ap.add_argument( -# "--pipeline-model-parallel-size", -# type=int, -# choices=[1], -# default=1, -# help="Order of pipeline parallelism. Defaults to 1 and currently only 1 is supported.", -# ) -# ap.add_argument( -# "--context-parallel-size", type=int, default=1, help="Order of context parallelism. Defaults to 1." -# ) -# ap.add_argument( -# "--fp8-recipe", -# type=str, -# default="delayed", -# choices=list(Fp8Recipe.__members__.keys()), -# help="FP8 recipe to use for FP8 tensors in the forward and backward pass. Note that some recipes are only " -# "supported by certain architectures. For example 'mxfp8' requires at least blackwell, and 'blockwise' is only " -# "implemented for hopper (but not blackwell). 'tensorwise' and 'delayed' are currently supported by all " -# "architectures, but 'tensorwise' is preferred over 'delayed' which is the default for historical reasons.", -# ) -# ap.add_argument( -# "--no-sequence-parallel", -# action="store_true", -# help="When using TP, skip sequence parallelism. Otherwise sequence parallelism is used whenever tensor " -# "parallelism is used. sequence parallelism should save a small amount of GPU memory so it's on" -# " by default.", -# ) -# ap.add_argument("--micro-batch-size", type=int, default=1, help="Batch size for prediction. Defaults to 1.") -# ap.add_argument( -# "--write-interval", -# type=str, -# default="epoch", -# choices=["epoch", "batch"], -# help="Interval to write predictions to disk. If doing very large predictions, you may want to set this to 'batch'.", -# ) -# ap.add_argument( -# "--model-size", -# type=str, -# default="7b_arc_longcontext", -# choices=sorted( -# list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(LLAMA_MODEL_OPTIONS.keys()) -# ), -# help="Model size to use. Defaults to '7b_arc_longcontext'.", -# ) -# # output args: -# ap.add_argument( -# "--output-dir", -# type=Path, -# default=None, -# help="Output dir that will contain the generated text produced by the Evo2 model. If not provided, the output will be logged.", -# ) -# ap.add_argument( -# "--files-per-subdir", -# type=int, -# help="Number of files to write to each subdirectory. If provided, subdirectories with N files each will be created. Ignored unless --write-interval is 'batch'.", -# ) -# ap.add_argument( -# "--full-fp8", -# action="store_true", -# help="Use full FP8 precision (faster but less accurate) rather than vortex style which " -# "only applies FP8 to the projection layer of the hyena mixer, when using FP8.", -# ) -# ap.add_argument("--fp8", action="store_true", help="Use FP8 precision. Defaults to BF16.") -# # extra: -# ap.add_argument( -# "--ckpt-format", -# type=str, -# choices=["torch_dist", "zarr"], -# default="torch_dist", -# help="Specify checkpoint format to use. Defaults to 'torch_dist', as 'zarr' is deprecated.", -# ) -# ap.add_argument( -# "--output-log-prob-seqs", action="store_true", help="Output log probability of sequences. Defaults to False." -# ) -# ap.add_argument( -# "--log-prob-collapse-option", -# choices=["sum", "mean", "per_token"], -# default="mean", -# help="How to collapse the log probabilities across the sequence dimension.", -# ) -# ap.add_argument( -# "--hybrid-override-pattern", -# type=str, -# help="Override the hybrid override pattern in the config (specifies hyena layer ordering and type).", -# ) -# ap.add_argument( -# "--num-layers", type=int, help="If set, override the number of layers specified in the requested config." -# ) -# ap.add_argument( -# "--seq-len-interpolation-factor", -# type=int, -# help="If set, override the sequence length interpolation factor specified in the requested config. If you " -# "know a model was trained with a specific interpolation factor for ROPE, provide it here, it can make a big " -# "difference in accuracy.", -# ) -# ap.add_argument( -# "--lora-checkpoint-path", -# type=Path, -# required=False, -# default=None, -# help="Path to the lora states to restore from.", -# ) -# return ap.parse_args() - - -# def _gather_along_cp_dim(input_, seq_dim: int = 1): -# """Gather tensors and concatenate along the last dimension.""" -# world_size = parallel_state.get_context_parallel_world_size() -# # Bypass the function if we are using only 1 GPU. -# if world_size == 1: -# return input_ - -# dim_size = list(input_.size()) -# dim_size[0] = dim_size[0] * world_size - -# output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) -# # TODO: handle zigzag packing here. Currently this just gathers along ranks, but if you want to see the sequence in -# # the original order you need to undo the zigzag packing that happens in -# # `megatron.core.utils.get_batch_on_this_cp_rank`. -# torch.distributed.all_gather_into_tensor( -# output, input_.contiguous(), group=parallel_state.get_context_parallel_group() -# ) -# tensor_list = output.chunk(world_size, dim=0) -# output = torch.cat(tensor_list, dim=seq_dim).contiguous() - -# return output - - -# def _to_cpu(inputs: dict[str, Tensor]) -> dict[str, Tensor]: -# return {k: v.cpu() for k, v in inputs.items()} - - -# def _identity(inputs: dict[str, Tensor]) -> dict[str, Tensor]: -# return inputs - - -# class BasePredictor(LightningPassthroughPredictionMixin): -# """Base predictor for GPT-style models.""" - -# def __init__( -# self, -# *args, -# output_log_prob_seqs: bool = False, -# include_tokens_with_logprob_seqs: bool = False, -# log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean", -# **kwargs, -# ): -# """Initialize the base predictor with arguments needed for writing predictions.""" -# super().__init__(*args, **kwargs) -# self.output_log_prob_seqs = output_log_prob_seqs -# self.log_prob_collapse_option = log_prob_collapse_option -# self.include_tokens_with_logprob_seqs = include_tokens_with_logprob_seqs -# self.shuffle_warning_raised = False - -# def predict_step( -# self, batch, batch_idx: int | None = None, to_cpu: bool = True -# ) -> Tensor | dict[str, Tensor] | None: -# """Alias for forward_step, also log the pad mask since sequences may not all have the same length.""" -# if len(batch) == 0: -# return -# assert self.training is False, "predict_step should be called in eval mode" -# with torch.no_grad(): -# forward_out = self.forward_step(batch) -# if not parallel_state.is_pipeline_last_stage(): -# return None -# # Reminder: the model's predictions for input i land at output i+1. To get everything to align, we prepend the -# # EOS token to the input sequences and take the outputs for all but the first token. -# forward_out_tp_gathered = _gather_along_last_dim( -# forward_out, group=parallel_state.get_tensor_model_parallel_group() -# ) - -# forward_out_gathered = _gather_along_cp_dim(forward_out_tp_gathered) -# loss_mask_gathered = _gather_along_cp_dim(batch["loss_mask"]) -# tokens_gathered = _gather_along_cp_dim(batch["tokens"]) -# cp_group_size = max(parallel_state.get_context_parallel_world_size(), 1) -# assert self.tokenizer.vocab_size == forward_out_gathered.shape[-1] -# to_cpu_fn = _to_cpu if to_cpu else _identity -# if self.output_log_prob_seqs: -# if self.log_prob_collapse_option == "per_token" and cp_group_size > 1 and not self.shuffle_warning_raised: -# logger.warning(SHUFFLE_MESSAGE) -# self.shuffle_warning_raised = True -# softmax_logprobs = torch.log_softmax(forward_out_gathered, dim=-1) -# softmax_logprobs = softmax_logprobs[:, :-1] -# input_ids = tokens_gathered[:, 1:] -# if softmax_logprobs.shape[1] != input_ids.shape[1]: -# raise RuntimeError( -# f"Softmax logprobs shape {softmax_logprobs.shape} does not match input ids shape {input_ids.shape}" -# ) - -# logprobs = torch.gather( -# softmax_logprobs, # Gather likelihoods... -# 2, # along the vocab dimension... -# input_ids.unsqueeze(-1), # using the token ids to index. -# ).squeeze(-1) -# log_prob_per_token = logprobs * loss_mask_gathered[:, 1:].float() -# if self.log_prob_collapse_option == "per_token": -# return to_cpu_fn( -# { -# "log_probs_seqs": log_prob_per_token, -# "seq_idx": batch["seq_idx"], -# "loss_mask": loss_mask_gathered[:, 1:], -# } -# ) -# else: -# log_prob_seqs = torch.sum(log_prob_per_token, dim=1) -# if self.log_prob_collapse_option == "mean": -# log_prob_seqs = log_prob_seqs / torch.clamp(loss_mask_gathered[:, 1:].float().sum(dim=-1), min=1.0) -# return to_cpu_fn({"log_probs_seqs": log_prob_seqs, "seq_idx": batch["seq_idx"]}) -# else: -# # If the user wants to match back to logits, then they will need to do the offsetting logic themselves. -# if cp_group_size > 1 and not self.shuffle_warning_raised: -# logger.warning(SHUFFLE_MESSAGE) -# self.shuffle_warning_raised = True -# logprob_seqs_result = { -# "token_logits": forward_out_gathered, -# "pad_mask": loss_mask_gathered, -# "seq_idx": batch["seq_idx"], -# } -# if self.include_tokens_with_logprob_seqs: -# logprob_seqs_result["tokens"] = tokens_gathered -# # Note, to match up tokens with logprobs, you need to offset by 1. Eg something like this: -# # shifted_token_logits = token_logits[:, :-1] -# # shifted_pad_mask = pad_mask[:, 1:] -# # shifted_tokens = tokens[:, 1:] -# return to_cpu_fn(logprob_seqs_result) - - -# class HyenaPredictor(BasePredictor, HyenaModel): -# """A predictor for the Hyena model. This adds in the predict step and the passthrough method.""" - -# def configure_model(self, *args, **kwargs) -> None: -# """Configure the model.""" -# super().configure_model(*args, **kwargs) -# self.trainer.strategy._init_model_parallel = True - - -# class MambaPredictor(BasePredictor, MambaModel): -# """Mamba model for prediction with additional metrics.""" - - -# class LlamaPredictor(BasePredictor, GPTModel): -# """Llama model for prediction with additional metrics.""" - - -# def hyena_predict_forward_step(model, batch) -> torch.Tensor: -# """Performs a forward step for the Hyena model. - -# Args: -# model: The Hyena model -# batch: Dictionary containing input batch data with keys: -# - tokens: Input token IDs -# - position_ids: Position IDs -# - labels: Labels for loss computation -# - loss_mask: Mask for loss computation - -# Returns: -# torch.Tensor: Output from the model forward pass -# """ -# forward_args = { -# "input_ids": batch["tokens"], -# "position_ids": batch["position_ids"], -# # "labels": batch["labels"], -# # "loss_mask": batch["loss_mask"], -# } - -# forward_args["attention_mask"] = None -# if "cu_seqlens" in batch: -# forward_args["packed_seq_params"] = get_packed_seq_params(batch) -# return model(**forward_args) - - -# def hyena_predict_data_step(dataloader_iter) -> dict[str, torch.Tensor]: -# """Data step for the Hyena model prediction. Modified from the original gpt data step to include the seq_idx.""" -# from megatron.core import parallel_state - -# # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 -# # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842 - -# batch = next(dataloader_iter) - -# _batch: dict -# if isinstance(batch, tuple) and len(batch) == 3: -# _batch = batch[0] -# else: -# _batch = batch - -# required_device_keys = set() -# required_host_keys = set() - -# required_device_keys.add("attention_mask") -# if "cu_seqlens" in _batch: -# required_device_keys.add("cu_seqlens") -# required_host_keys.add("cu_seqlens_argmin") -# required_host_keys.add("max_seqlen") - -# if parallel_state.is_pipeline_first_stage(): -# required_device_keys.update(("tokens", "position_ids")) -# include_seq_idx = False -# if parallel_state.is_pipeline_last_stage(): -# include_seq_idx = True -# required_device_keys.update(("labels", "tokens", "loss_mask")) - -# _batch_required_keys = {} -# for key, val in _batch.items(): -# if key in required_device_keys: -# _batch_required_keys[key] = val.cuda(non_blocking=True) -# elif key in required_host_keys: -# _batch_required_keys[key] = val.cpu() -# else: -# _batch_required_keys[key] = None - -# # slice batch along sequence dimension for context parallelism -# output = get_batch_on_this_cp_rank(_batch_required_keys) -# if include_seq_idx: -# output["seq_idx"] = _batch["seq_idx"].cuda(non_blocking=True) -# return output - - -# class PredictDataModule(LightningDataModule): -# """Create a dataloader for prediction.""" - -# def __init__( -# self, -# dataset: torch.utils.data.Dataset, -# batch_size: int = 1, -# tokenizer=None, -# min_length: int | None = None, -# ): -# """Create a dataloader for prediction.""" -# super().__init__() -# self.dataset = dataset -# self.batch_size = batch_size -# self.tokenizer = tokenizer -# self.min_length = min_length -# default_pad_id = 0 -# self.pad_token_id = getattr(tokenizer, "pad_id", default_pad_id) if tokenizer is not None else default_pad_id - -# def setup(self, stage: str | None = None) -> None: -# """Set up the dataloader.""" -# pass - -# def predict_dataloader(self): -# """Create a dataloader for prediction.""" -# # need to use this to communicate that we are in predict mode and safe to not drop last batch -# return WrappedDataLoader( -# mode="predict", -# dataset=self.dataset, -# batch_size=self.batch_size, -# num_workers=8, -# shuffle=False, -# drop_last=False, -# collate_fn=functools.partial( -# collate.padding_collate_fn, -# padding_values={"tokens": self.pad_token_id, "position_ids": self.pad_token_id, "loss_mask": False}, -# min_length=self.min_length, -# max_length=None, -# ), -# ) - - -# def predict( -# fasta_path: Path, -# ckpt_dir: str, -# output_dir: Path, -# tensor_parallel_size: int, -# pipeline_model_parallel_size: int, -# context_parallel_size: int, -# num_nodes: int = 1, -# devices: int | None = None, -# eden_tokenizer: bool = False, -# model_size: str = "7b", -# ckpt_format: CheckpointFormats = "torch_dist", -# fp8: bool = False, -# full_fp8: bool = False, -# fp8_recipe: str = "delayed", -# work_dir: Path | None = None, -# micro_batch_size: int = 1, -# output_log_prob_seqs: bool = False, -# log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean", -# write_interval: Literal["epoch", "batch"] = "epoch", -# prepend_bos: bool = False, -# no_sequence_parallel: bool = False, -# hybrid_override_pattern: str | None = None, -# num_layers: int | None = None, -# seq_len_interpolation_factor: int | None = None, -# files_per_subdir: int | None = None, -# lora_checkpoint_path: Path | None = None, -# mask_phylogenetic_tags: bool = False, -# min_length: int | None = None, -# extra_callbacks: list | None = None, # use this for making testing the predict loop easier. -# ): -# """Inference workflow for Evo2. - -# Returns: -# None -# """ -# if fp8 and not full_fp8 and fp8_recipe != "delayed": -# logger.warning( -# "fp8_recipe is ignored when using fp8 and not full_fp8 since it is set inside of the layer " -# "config to match vortex style FP8." -# ) -# if work_dir is None: -# work_dir = Path(tempfile.mkdtemp()) -# if files_per_subdir is None and write_interval == "batch": -# logger.warning( -# "--files-per-subdir is not set with --write-interval batch, will write all predictions to a " -# "single directory. This may cause problems if you are predicting on a very large dataset." -# ) -# sequence_parallel = tensor_parallel_size > 1 and not no_sequence_parallel -# output_dir.mkdir(parents=True, exist_ok=True) # Make sure the output directory exists, files will be written here. -# model_parallel_size = tensor_parallel_size * pipeline_model_parallel_size * context_parallel_size -# if devices is None: -# devices = model_parallel_size -# world_size = num_nodes * devices -# if world_size % model_parallel_size != 0: -# raise ValueError( -# f"world_size must be divisible by model_parallel_size, got {world_size} and" -# f" {model_parallel_size}. Please set --num-nodes and --devices such that num_nodes * devices is divisible " -# "by model_parallel_size, which is TP * CP * PP." -# ) -# global_batch_size = micro_batch_size * world_size // model_parallel_size - -# callbacks = [ -# PredictionWriter( -# output_dir=output_dir, -# write_interval=write_interval, -# batch_dim_key_defaults={"token_logits": 0}, -# seq_dim_key_defaults={"token_logits": 1}, -# files_per_subdir=files_per_subdir, -# save_all_model_parallel_ranks=False, # only write one copy of predictions. -# ) -# ] -# if extra_callbacks is not None: -# callbacks.extend(extra_callbacks) - -# # The following two config options are really only used for testing, but may also be useful for getting output from -# # specific layers of the model. -# config_modifiers_init: dict[str, Any] = { -# "distribute_saved_activations": False if sequence_parallel and tensor_parallel_size > 1 else True, -# } -# if hybrid_override_pattern is not None: -# config_modifiers_init["hybrid_override_pattern"] = hybrid_override_pattern -# if num_layers is not None: -# config_modifiers_init["num_layers"] = num_layers -# if seq_len_interpolation_factor is not None: -# config_modifiers_init["seq_len_interpolation_factor"] = seq_len_interpolation_factor - -# tokenizer = get_nmt_tokenizer("byte-level") -# if eden_tokenizer: -# patch_eden_tokenizer(tokenizer) - -# model_type = infer_model_type(model_size) - -# # Select model config based on model type -# if model_type == "hyena": -# if model_size not in HYENA_MODEL_OPTIONS: -# raise ValueError(f"Invalid model size for Hyena: {model_size}") -# config = HYENA_MODEL_OPTIONS[model_size]( -# forward_step_fn=hyena_predict_forward_step, -# data_step_fn=hyena_predict_data_step, -# # Only use vortex style FP8 in the model config if using FP8 and not full FP8. This will only apply FP8 to -# # the projection layer of the hyena mixer. -# vortex_style_fp8=fp8 and not full_fp8, -# **config_modifiers_init, -# ) - -# if lora_checkpoint_path: -# model_transform = Evo2LoRA(peft_ckpt_path=str(lora_checkpoint_path)) -# callbacks.append(model_transform) -# else: -# model_transform = None - -# model = HyenaPredictor( -# config, -# tokenizer=tokenizer, -# output_log_prob_seqs=output_log_prob_seqs, -# log_prob_collapse_option=log_prob_collapse_option, -# model_transform=model_transform, -# ) -# elif model_type == "mamba": # mamba -# if model_size not in MAMBA_MODEL_OPTIONS: -# raise ValueError(f"Invalid model size for Mamba: {model_size}") -# config = MAMBA_MODEL_OPTIONS[model_size]( -# forward_step_fn=hyena_predict_forward_step, # Can reuse the same forward steps -# data_step_fn=hyena_predict_data_step, -# **config_modifiers_init, -# ) - -# model = MambaPredictor( -# config, -# tokenizer=tokenizer, -# output_log_prob_seqs=output_log_prob_seqs, -# log_prob_collapse_option=log_prob_collapse_option, -# ) -# elif model_type == "llama": -# if model_size not in LLAMA_MODEL_OPTIONS: -# raise ValueError(f"Invalid model size for Llama: {model_size}") -# config = LLAMA_MODEL_OPTIONS[model_size]( -# forward_step_fn=hyena_predict_forward_step, -# data_step_fn=hyena_predict_data_step, -# **config_modifiers_init, -# ) -# model = LlamaPredictor( -# config, -# tokenizer=tokenizer, -# output_log_prob_seqs=output_log_prob_seqs, -# log_prob_collapse_option=log_prob_collapse_option, -# ) -# else: -# # This shouldn't be possible to reach. -# raise ValueError(f"Invalid model type: {model_type}.") - -# # Create PTL trainer. -# trainer = nl.Trainer( -# accelerator="gpu", -# num_nodes=num_nodes, -# devices=devices, -# strategy=nl.MegatronStrategy( -# drop_last_batch=False, -# tensor_model_parallel_size=tensor_parallel_size, -# pipeline_model_parallel_size=pipeline_model_parallel_size, -# context_parallel_size=context_parallel_size, -# pipeline_dtype=torch.bfloat16, -# ckpt_load_optimizer=False, # Needs to be false for a normal model checkpoint. -# ckpt_save_optimizer=False, -# ckpt_async_save=False, -# sequence_parallel=sequence_parallel, -# save_ckpt_format=ckpt_format, -# ckpt_load_strictness="log_all", -# setup_optimizers=False, -# store_optimizer_states=False, -# configure_optimizers=False, -# data_sampler=nl.MegatronDataSampler( -# micro_batch_size=micro_batch_size, -# global_batch_size=global_batch_size, -# seq_len=8192, -# output_log=False, # this is needed for predict step to work -# ), -# ), -# log_every_n_steps=1, -# limit_val_batches=10, -# num_sanity_val_steps=0, -# callbacks=callbacks, -# plugins=nl.MegatronMixedPrecision( -# precision="bf16-mixed", -# params_dtype=torch.bfloat16, -# # Only use FP8 in this plugin when using full FP8 precision and FP8. -# # Otherwise use vortex_style_fp8 in the model config. -# fp8_recipe=fp8_recipe, -# fp8="hybrid" if fp8 and full_fp8 else None, -# fp8_amax_history_len=16 if fp8 and full_fp8 else 1, -# fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent", -# ), -# ) - -# nemo_logger = NeMoLogger(log_dir=str(work_dir)) -# nemo_logger.setup(trainer, resume_if_exists=True) -# resume = nl.AutoResume( -# resume_if_exists=True, -# resume_ignore_no_checkpoint=False, -# resume_past_end=False, -# resume_from_path=str(ckpt_dir), -# restore_config=None, -# ) - -# resume.setup(trainer, model) # this pulls weights from the starting checkpoint. - -# if mask_phylogenetic_tags: - -# def custom_loss_masker(tokens): -# # Run the evo2 dataset mask_phylogenetic_tags function -# return Evo2Dataset.mask_phylogenetic_tags( -# tokens, -# Evo2Dataset.TAG_BOUNDS, -# Evo2Dataset.TAG_CHARS, -# tokenizer.eod if tokenizer is not None else Evo2Dataset.DEFAULT_EOD, -# Evo2Dataset.MAX_TAG_LEN, -# ) -# else: -# custom_loss_masker = None - -# dataset = SimpleFastaDataset(fasta_path, tokenizer, prepend_bos=prepend_bos, custom_loss_masker=custom_loss_masker) -# datamodule = PredictDataModule(dataset, batch_size=micro_batch_size, tokenizer=tokenizer, min_length=min_length) -# trainer.predict(model, datamodule=datamodule) # TODO return_predictions=False -# dataset.write_idx_map( -# output_dir -# ) # Finally write out the index map so we can match the predictions to the original sequences. - - -# def main(): -# """Entrypoint for Evo2 prediction (single inference step, no new tokens).""" -# args = parse_args() -# predict( -# num_nodes=args.num_nodes, -# devices=args.devices, -# fasta_path=args.fasta, -# ckpt_dir=args.ckpt_dir, -# tensor_parallel_size=args.tensor_parallel_size, -# pipeline_model_parallel_size=args.pipeline_model_parallel_size, -# context_parallel_size=args.context_parallel_size, -# output_dir=args.output_dir, -# model_size=args.model_size, -# ckpt_format=args.ckpt_format, -# fp8=args.fp8, -# full_fp8=args.full_fp8, -# fp8_recipe=args.fp8_recipe, -# micro_batch_size=args.micro_batch_size, -# output_log_prob_seqs=args.output_log_prob_seqs, -# log_prob_collapse_option=args.log_prob_collapse_option, -# prepend_bos=args.prepend_bos, -# no_sequence_parallel=args.no_sequence_parallel, -# hybrid_override_pattern=args.hybrid_override_pattern, -# seq_len_interpolation_factor=args.seq_len_interpolation_factor, -# num_layers=args.num_layers, -# files_per_subdir=args.files_per_subdir, -# write_interval=args.write_interval, -# lora_checkpoint_path=args.lora_checkpoint_path, -# mask_phylogenetic_tags=args.mask_phylogenetic_tags, -# min_length=args.min_length, -# eden_tokenizer=args.eden_tokenizer, -# ) - - -# if __name__ == "__main__": -# main() +r"""Prediction (inference) workflow for Evo2 using Megatron Bridge. + +This module provides functionality to run inference on Evo2 models using MBridge checkpoints. +It supports various parallelism strategies (TP, CP, DP) and can output either full logits +or collapsed log probabilities. + +Usage (CLI): + # Single GPU inference + torchrun --nproc_per_node 1 -m bionemo.evo2.run.predict \ + --fasta input.fasta --ckpt-dir /path/to/mbridge/checkpoint \ + --output-dir /path/to/output + + # Multi-GPU with tensor parallelism + torchrun --nproc_per_node 2 -m bionemo.evo2.run.predict \ + --fasta input.fasta --ckpt-dir /path/to/mbridge/checkpoint \ + --output-dir /path/to/output --tensor-parallel-size 2 + + # With context parallelism for long sequences + torchrun --nproc_per_node 2 -m bionemo.evo2.run.predict \ + --fasta input.fasta --ckpt-dir /path/to/mbridge/checkpoint \ + --output-dir /path/to/output --context-parallel-size 2 + +Output Format: + Batch mode (--write-interval batch): + - predictions__rank_{global_rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt + - With --files-per-subdir: subdir_{N}/predictions__rank_... + - Each file includes batch_idx tensor for reconstruction + + Epoch mode (--write-interval epoch, default): + - predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt + - All batches collated into single file + + Both modes: + - seq_idx_map.json: Mapping from sequence names to indices in predictions + +Key Functions: + - predict(): Main prediction workflow + - batch_collator(): Collate predictions from multiple batches/ranks + - initialize_inference_distributed(): Set up distributed environment for inference +""" + +import argparse +import datetime +import logging +import os +from pathlib import Path +from typing import List, Literal, Optional, Tuple, TypeVar, Union + +import torch +import torch.distributed as dist +from megatron.bridge.data.samplers import build_pretraining_data_loader +from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint +from megatron.bridge.training.config import DistributedInitConfig, RNGConfig +from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config +from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer +from megatron.bridge.training.utils.checkpoint_utils import ( + file_exists, + get_checkpoint_run_config_filename, + read_run_config, +) +from megatron.bridge.utils.common_utils import ( + get_local_rank_preinit, + get_master_addr_safe, + get_master_port_safe, + get_rank_safe, + get_world_size_safe, +) +from megatron.bridge.utils.instantiate_utils import instantiate +from megatron.core import parallel_state, tensor_parallel +from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator +from megatron.core.tensor_parallel.mappings import _gather_along_last_dim +from megatron.core.transformer.module import Float16Module +from megatron.core.utils import get_batch_on_this_cp_rank +from torch import Tensor + +from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH +from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset + + +logger: logging.Logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# Type alias for recursive batch structures (dicts, lists, tuples of Tensors) +ReductionT = TypeVar("ReductionT", bound=Union[Tensor, dict, list, tuple]) + + +# ============================================================================= +# Checkpoint Path Resolution +# ============================================================================= + + +def resolve_checkpoint_path(checkpoint_path: Path) -> Path: + """Resolve a checkpoint path to the actual checkpoint directory. + + MBridge checkpoints can be organized in two ways: + 1. Direct checkpoint: A directory containing run_config.yaml directly + (e.g., after conversion or for single checkpoints) + 2. Training output: A parent directory containing iter_XXXXXXX subdirectories + + This function handles both cases: + - If run_config.yaml exists in the given path, return it as-is + - Otherwise, find the latest iter_XXXXXXX subdirectory and return that + + Args: + checkpoint_path: Path to either a direct checkpoint or a training output directory. + + Returns: + Path to the checkpoint directory containing run_config.yaml. + + Raises: + FileNotFoundError: If the path doesn't exist or no valid checkpoint is found. + NotADirectoryError: If the path is not a directory. + + Examples: + >>> # Direct checkpoint path + >>> resolve_checkpoint_path(Path("/checkpoints/evo2_1b_mbridge")) + PosixPath('/checkpoints/evo2_1b_mbridge') + + >>> # Training output with iter_* subdirectories + >>> resolve_checkpoint_path(Path("/training/output")) + PosixPath('/training/output/iter_0007000') # Returns latest + """ + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint path '{checkpoint_path}' does not exist.") + if not checkpoint_path.is_dir(): + raise NotADirectoryError(f"Checkpoint path '{checkpoint_path}' must be a directory.") + + # Check if run_config.yaml exists directly in this path + run_config_path = get_checkpoint_run_config_filename(str(checkpoint_path)) + if file_exists(run_config_path): + return checkpoint_path + + # Look for iter_* subdirectories + iter_dirs = [ + (child.name, child) for child in checkpoint_path.iterdir() if child.is_dir() and child.name.startswith("iter_") + ] + + if not iter_dirs: + raise FileNotFoundError( + f"No valid checkpoint found at '{checkpoint_path}'. " + "Expected either run_config.yaml in the directory or iter_* subdirectories." + ) + + # Find the latest iteration by parsing the iteration number + def _parse_iter_num(item: tuple[str, Path]) -> int: + try: + return int(item[0].replace("iter_", "")) + except ValueError: + return -1 + + _, latest_iter_path = max(iter_dirs, key=_parse_iter_num) + + # Verify the selected iter directory has run_config.yaml + run_config_path = get_checkpoint_run_config_filename(str(latest_iter_path)) + if not file_exists(run_config_path): + raise FileNotFoundError(f"Latest checkpoint directory '{latest_iter_path}' does not contain run_config.yaml.") + + logger.info(f"Resolved checkpoint path to: {latest_iter_path}") + return latest_iter_path + + +# ============================================================================= +# Batch Collation Utilities +# ============================================================================= + + +def batch_collator( + batches: Optional[Union[Tuple[ReductionT, ...], List[ReductionT]]], + batch_dim: int = 0, + seq_dim: int = 1, + batch_dim_key_defaults: Optional[dict[str, int]] = None, + seq_dim_key_defaults: Optional[dict[str, int]] = None, + preferred_gpu: int = 0, +) -> Optional[ReductionT]: + """Collate multiple batches into a single batch by concatenating along the batch dimension. + + This function handles nested structures (dicts, lists, tuples) containing tensors. + Unlike PyTorch's default_collate, this assumes the batch dimension already exists + (as when parallelizing across microbatches or DP ranks). + + Args: + batches: Sequence of batches to collate. Each batch can be a tensor, dict, list, or tuple. + The structure must be consistent across all batches. + batch_dim: Dimension along which to concatenate tensors. Default 0. + seq_dim: Sequence dimension, used for padding to max length. Default 1. + batch_dim_key_defaults: For dict batches, override batch_dim for specific keys. + Default: {"token_logits": 1} (legacy compatibility, recommend passing {}). + seq_dim_key_defaults: For dict batches, override seq_dim for specific keys. + Default: {"token_logits": 0} (legacy compatibility, recommend passing {}). + preferred_gpu: If any tensor is on GPU, move all to this device. Default 0. + + Returns: + Collated batch with same structure as input batches, or None if input contains None. + + Raises: + ValueError: If batches is empty or contains unsupported types. + + Examples: + >>> # Collate dict batches + >>> batch1 = {"logits": torch.randn(2, 10, 512), "mask": torch.ones(2, 10)} + >>> batch2 = {"logits": torch.randn(3, 10, 512), "mask": torch.ones(3, 10)} + >>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1, + ... batch_dim_key_defaults={}, seq_dim_key_defaults={}) + >>> result["logits"].shape # torch.Size([5, 10, 512]) + + >>> # Collate with padding (different sequence lengths) + >>> batch1 = {"tokens": torch.randn(2, 100)} + >>> batch2 = {"tokens": torch.randn(2, 150)} + >>> result = batch_collator([batch1, batch2], batch_dim=0, seq_dim=1, + ... batch_dim_key_defaults={}, seq_dim_key_defaults={}) + >>> result["tokens"].shape # torch.Size([4, 150]) - padded to max length + """ + # Apply defaults for backward compatibility + if batch_dim_key_defaults is None: + batch_dim_key_defaults = {"token_logits": 1} + if seq_dim_key_defaults is None: + seq_dim_key_defaults = {"token_logits": 0} + + match batches: + # Base case: list starting with None + case [None, *_]: + return None + + # Base case: list of tensors + case [Tensor(), *_]: + return _collate_tensors(batches, batch_dim=batch_dim, seq_dim=seq_dim, preferred_gpu=preferred_gpu) + + # Recursive case: list of dicts + case [dict(), *_]: + return { + key: batch_collator( + [batch[key] for batch in batches], + batch_dim=batch_dim_key_defaults.get(key, batch_dim), + seq_dim=seq_dim_key_defaults.get(key, seq_dim), + batch_dim_key_defaults=batch_dim_key_defaults, + seq_dim_key_defaults=seq_dim_key_defaults, + preferred_gpu=preferred_gpu, + ) + for key in batches[0] + } + + # Recursive case: list of tuples + case [tuple(), *_]: + return tuple( + batch_collator( + [batch[i] for batch in batches], + batch_dim=batch_dim, + seq_dim=seq_dim, + batch_dim_key_defaults=batch_dim_key_defaults, + seq_dim_key_defaults=seq_dim_key_defaults, + preferred_gpu=preferred_gpu, + ) + for i in range(len(batches[0])) + ) + + # Recursive case: list of lists + case [list(), *_]: + return [ + batch_collator( + [batch[i] for batch in batches], + batch_dim=batch_dim, + seq_dim=seq_dim, + batch_dim_key_defaults=batch_dim_key_defaults, + seq_dim_key_defaults=seq_dim_key_defaults, + preferred_gpu=preferred_gpu, + ) + for i in range(len(batches[0])) + ] + + # Error cases + case []: + raise ValueError("Cannot collate an empty sequence of batches") + case _: + raise ValueError(f"Unsupported batch type: {type(batches[0]) if batches else 'empty'}") + + +def _collate_tensors( + tensors: List[Tensor], + batch_dim: int, + seq_dim: int, + preferred_gpu: int, +) -> Tensor: + """Concatenate tensors along batch dimension, padding sequence dimension if needed. + + Args: + tensors: List of tensors to concatenate + batch_dim: Dimension to concatenate along + seq_dim: Dimension to pad to max length + preferred_gpu: GPU device to use if any tensor is on GPU + + Returns: + Concatenated tensor + """ + # Move all to same device if any is on GPU + if any(t.is_cuda for t in tensors): + device = torch.device(f"cuda:{preferred_gpu}") + tensors = [t.to(device) for t in tensors] + + # For 1D tensors, just concatenate (no sequence dimension) + if tensors[0].ndim == 1: + return torch.cat(tensors, dim=0) + + # Pad to max sequence length + max_seq_len = max(t.size(seq_dim) for t in tensors) + padded_tensors = [] + + for tensor in tensors: + pad_amount = max_seq_len - tensor.size(seq_dim) + if pad_amount > 0: + # Build padding tuple: [left_last, right_last, left_second_last, right_second_last, ...] + pad_spec = [0] * (2 * tensor.ndim) + # Pad on the right of the sequence dimension + pad_spec[2 * (tensor.ndim - 1 - seq_dim) + 1] = pad_amount + padded_tensor = torch.nn.functional.pad(tensor, tuple(pad_spec)) + else: + padded_tensor = tensor + padded_tensors.append(padded_tensor) + + return torch.cat(padded_tensors, dim=batch_dim) + + +# ============================================================================= +# Distributed Initialization +# ============================================================================= + + +def initialize_inference_distributed( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + micro_batch_size: int = 1, + global_batch_size: int = 1, + rng_config: Optional[RNGConfig] = None, + dist_config: Optional[DistributedInitConfig] = None, +) -> None: + """Initialize distributed environment for inference. + + Sets up the minimal distributed infrastructure needed for model-parallel inference: + 1. torch.distributed process group + 2. Model parallel groups (TP, PP, CP, DP) + 3. Microbatch calculator (for batch scheduling) + 4. Random seeds for reproducibility + + This is a lightweight alternative to full Megatron initialization, skipping + training-specific components like the rerun state machine. + + Args: + tensor_model_parallel_size: Tensor parallelism degree (splits model across GPUs) + pipeline_model_parallel_size: Pipeline parallelism degree (must be 1 for inference) + context_parallel_size: Context parallelism degree (splits sequence across GPUs) + micro_batch_size: Batch size per forward pass + global_batch_size: Total batch size across all DP ranks + rng_config: Random number generator configuration. Defaults to seed=1234. + dist_config: Distributed backend configuration. Defaults to NCCL backend. + + Note: + This function must be called before creating the model. It initializes + parallel_state which is used throughout the codebase. + """ + import random + + import numpy as np + + # Apply defaults + if rng_config is None: + rng_config = RNGConfig(seed=1234) + if dist_config is None: + dist_config = DistributedInitConfig() + + assert torch.cuda.is_available(), "Inference requires CUDA." + + device_count = torch.cuda.device_count() + world_size = get_world_size_safe() + model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + data_parallel_size = world_size // model_parallel_size + + # Initialize microbatch calculator + init_num_microbatches_calculator( + rank=get_rank_safe(), + rampup_batch_size=None, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + data_parallel_size=data_parallel_size, + decrease_batch_size_if_needed=False, + ) + + # Initialize torch.distributed + if not torch.distributed.is_initialized(): + if get_rank_safe() == 0: + print("> initializing torch distributed for inference ...", flush=True) + + if device_count > 0: + torch.cuda.set_device(get_local_rank_preinit()) + + # Ensure environment variables are set + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = get_master_addr_safe() + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(get_master_port_safe()) + + torch.distributed.init_process_group( + backend=dist_config.distributed_backend, + world_size=world_size, + rank=get_rank_safe(), + timeout=datetime.timedelta(minutes=dist_config.distributed_timeout_minutes), + ) + torch.distributed.barrier(device_ids=[get_local_rank_preinit()]) + else: + if get_rank_safe() == 0: + print("torch distributed is already initialized, skipping ...", flush=True) + + # Initialize model parallel groups + if device_count > 0 and not parallel_state.model_parallel_is_initialized(): + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + distributed_timeout_minutes=dist_config.distributed_timeout_minutes, + ) + if get_rank_safe() == 0: + print( + f"> initialized tensor model parallel with size {parallel_state.get_tensor_model_parallel_world_size()}" + ) + print( + f"> initialized pipeline model parallel with size {parallel_state.get_pipeline_model_parallel_world_size()}" + ) + print(f"> initialized data parallel with size {parallel_state.get_data_parallel_world_size()}") + elif get_rank_safe() == 0: + print("model parallel is already initialized", flush=True) + + # Set random seeds + if get_rank_safe() == 0: + print(f"> setting random seeds to {rng_config.seed} ...", flush=True) + + seed = rng_config.seed + (100 * parallel_state.get_pipeline_model_parallel_rank()) + if rng_config.data_parallel_random_init: + seed = seed + (10 * parallel_state.get_data_parallel_rank()) + + random.seed(seed) + np.random.seed(seed) # noqa: NPY002 + torch.manual_seed(seed) + + if device_count > 0: + tensor_parallel.model_parallel_cuda_manual_seed( + seed, + rng_config.te_rng_tracker, + rng_config.inference_rng_tracker, + ) + + +# ============================================================================= +# Context Parallelism Utilities +# ============================================================================= + + +def _gather_along_cp_dim(input_: Tensor, seq_dim: int = 1, unshuffle_zigzag: bool = True) -> Tensor: + """Gather tensors from all CP ranks and restore original sequence order. + + When using context parallelism (CP), sequences are split across multiple GPUs using a + "zigzag" pattern for load balancing. This function gathers the split tensors from all + CP ranks and optionally restores the original sequence order. + + Zigzag Pattern (CP=2 example): + Original sequence: [chunk0, chunk1, chunk2, chunk3] + CP rank 0 receives: [chunk0, chunk3] (positions 0 and 3) + CP rank 1 receives: [chunk1, chunk2] (positions 1 and 2) + + After gathering and unshuffling, the original order is restored. + + Args: + input_: Input tensor with shape [B, S/CP, ...] where S is full sequence length + seq_dim: Sequence dimension in the tensor. Default 1. + unshuffle_zigzag: If True, restore original sequence order after gathering. + Set to False only if you need the raw gathered order. Default True. + + Returns: + Gathered tensor with shape [B, S, ...] in original sequence order. + If CP=1, returns input unchanged. + + Note: + This function requires parallel_state to be initialized with CP groups. + """ + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size == 1: + return input_ + + # Gather from all CP ranks + # After all_gather: [B * cp_size, seq_len_per_rank, ...] + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * cp_size + + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=parallel_state.get_context_parallel_group() + ) + + # Chunk by batch dimension and concatenate by sequence dimension + # Result: [B, seq_len_per_rank * cp_size, ...] + tensor_list = output.chunk(cp_size, dim=0) + output = torch.cat(tensor_list, dim=seq_dim).contiguous() + + if not unshuffle_zigzag: + return output + + # Undo the zigzag pattern from get_batch_on_this_cp_rank + # The zigzag assigns chunk i and (2*cp_size - i - 1) to rank i + seq_len = output.shape[seq_dim] + num_chunks = 2 * cp_size + chunk_size = seq_len // num_chunks + + chunks = output.split(chunk_size, dim=seq_dim) + + # Build the order in which chunks appear after gathering: + # [rank0_first, rank0_second, rank1_first, rank1_second, ...] + # where rank_i has chunks (i, 2*cp_size - i - 1) + gathered_order = [] + for rank in range(cp_size): + gathered_order.append(rank) + gathered_order.append(2 * cp_size - rank - 1) + + # Create inverse mapping: original_position -> gathered_position + inverse_order = [0] * num_chunks + for pos, orig_idx in enumerate(gathered_order): + inverse_order[orig_idx] = pos + + # Reorder to original sequence order [0, 1, 2, ..., 2*cp_size-1] + reordered_chunks = [chunks[inverse_order[i]] for i in range(num_chunks)] + return torch.cat(reordered_chunks, dim=seq_dim).contiguous() + + +# ============================================================================= +# Argument Parsing +# ============================================================================= + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments for Evo2 inference. + + Returns: + Parsed arguments namespace + """ + ap = argparse.ArgumentParser( + description="Run inference on Evo2 models using MBridge checkpoints", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Required arguments + ap.add_argument( + "--fasta", + type=Path, + required=True, + help="Path to input FASTA file containing sequences for prediction", + ) + ap.add_argument( + "--ckpt-dir", + type=Path, + required=True, + help="Path to MBridge checkpoint directory (must contain run_config.yaml)", + ) + + # Output arguments + ap.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory for output predictions. If not set, predictions are discarded.", + ) + ap.add_argument( + "--write-interval", + type=str, + default="epoch", + choices=["epoch", "batch"], + help="When to write predictions: 'epoch' writes all at end, 'batch' writes after each batch", + ) + ap.add_argument( + "--files-per-subdir", + type=int, + help="Group output files into subdirectories. Only used with --write-interval batch.", + ) + + # Parallelism arguments + ap.add_argument("--num-nodes", type=int, default=1, help="Number of nodes for distributed inference") + ap.add_argument( + "--devices", + type=int, + help="Number of GPUs per node. Default: TP * PP * CP", + ) + ap.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallelism degree") + ap.add_argument( + "--pipeline-model-parallel-size", + type=int, + choices=[1], + default=1, + help="Pipeline parallelism degree (only 1 supported)", + ) + ap.add_argument("--context-parallel-size", type=int, default=1, help="Context parallelism degree") + ap.add_argument( + "--no-sequence-parallel", + action="store_true", + help="Disable sequence parallelism when using TP > 1", + ) + + # Model/precision arguments + ap.add_argument( + "--mixed-precision-recipe", + type=str, + choices=list(MIXED_PRECISION_RECIPES.keys()), + help="Override mixed precision recipe (default: use checkpoint setting)", + ) + ap.add_argument( + "--vortex-style-fp8", + action="store_true", + help="Use vortex-style FP8 (applies FP8 only to projection layers)", + ) + + # Batch/sequence arguments + ap.add_argument("--micro-batch-size", type=int, default=1, help="Batch size per forward pass") + ap.add_argument("--min-length", type=int, help="Minimum sequence length (pad shorter sequences)") + ap.add_argument("--prepend-bos", action="store_true", help="Prepend BOS token to sequences") + + # Output format arguments + ap.add_argument( + "--output-log-prob-seqs", + action="store_true", + help="Output log probabilities instead of raw logits", + ) + ap.add_argument( + "--log-prob-collapse-option", + choices=["sum", "mean", "per_token"], + default="mean", + help="How to aggregate per-token log probs: sum, mean, or keep per_token", + ) + + # Model configuration overrides (for testing) + ap.add_argument( + "--hybrid-override-pattern", + type=str, + help="Override hybrid layer pattern (e.g., 'SDH*' for testing)", + ) + ap.add_argument("--num-layers", type=int, help="Override number of layers (for testing)") + ap.add_argument( + "--seq-len-interpolation-factor", + type=int, + help="ROPE sequence length interpolation factor", + ) + + # Embedding extraction arguments + ap.add_argument( + "--embedding-layer", + type=int, + help="Extract embeddings from a specific transformer layer instead of logits. " + "Supports Python-style negative indexing (e.g., -1 for last layer, -2 for second-to-last). " + "For a 25-layer model, layer 24 and layer -1 both refer to the last layer.", + ) + + # Tokenizer arguments + ap.add_argument( + "--eden-tokenizer", + action="store_true", + help="Use Eden tokenizer patches", + ) + ap.add_argument( + "--mask-phylogenetic-tags", + action="store_true", + help="Mask phylogenetic tags in loss computation", + ) + + return ap.parse_args() + + +def on_writing_rank() -> bool: + """Returns True if the current rank is one that should own writing predictions.""" + return ( + (parallel_state.is_pipeline_last_stage()) + and (parallel_state.get_tensor_model_parallel_rank() == 0) + and (parallel_state.get_context_parallel_rank() == 0) + ) + + +# ============================================================================= +# Data Loading Utilities +# ============================================================================= + + +def _padding_collate_fn_factory( + pad_token_id: int = 0, + min_length: Optional[int] = None, +): + """Create a collate function that pads sequences to uniform length. + + Args: + pad_token_id: Token ID to use for padding + min_length: Minimum sequence length (pad shorter sequences to this) + + Returns: + Collate function compatible with DataLoader + """ + + def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Tensor]: + return _padding_collate_fn(batch, pad_token_id, min_length) + + return collate_fn + + +def _padding_collate_fn( + batch: list[dict[str, Tensor]], + pad_token_id: int = 0, + min_length: Optional[int] = None, +) -> dict[str, Tensor]: + """Pad sequences in a batch to the same length. + + Handles the following keys specially: + - tokens: Padded with pad_token_id + - position_ids: Extended with consecutive positions + - loss_mask: Padded with 0 (masked) + - seq_idx: Not padded (scalar per sample) + - Other keys: Padded with 0 + + Args: + batch: List of sample dictionaries from the dataset + pad_token_id: Token ID for padding + min_length: Minimum length to pad to + + Returns: + Dictionary with batched and padded tensors + """ + max_len = max(sample["tokens"].shape[0] for sample in batch) + if min_length is not None: + max_len = max(max_len, min_length) + + padded_batch: dict[str, list[Tensor]] = {key: [] for key in batch[0].keys()} + + for sample in batch: + seq_len = sample["tokens"].shape[0] + pad_len = max_len - seq_len + + for key, value in sample.items(): + if key == "tokens": + padded = torch.nn.functional.pad(value, (0, pad_len), value=pad_token_id) + elif key == "position_ids": + if pad_len > 0: + padded = torch.cat([value, torch.arange(seq_len, max_len, dtype=value.dtype)]) + else: + padded = value + elif key == "loss_mask": + padded = torch.nn.functional.pad(value, (0, pad_len), value=0) + elif key == "seq_idx": + padded = value # Scalar, no padding + else: + padded = torch.nn.functional.pad(value, (0, pad_len), value=0) + padded_batch[key].append(padded) + + return {key: torch.stack(values) for key, values in padded_batch.items()} + + +# ============================================================================= +# Prediction Step +# ============================================================================= + + +def _predict_step( + model: torch.nn.Module, + batch: dict[str, Tensor], + output_log_prob_seqs: bool = False, + log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean", + context_parallel_size: int = 1, + output_embeddings: bool = False, +) -> Optional[dict[str, Tensor]]: + """Run a single prediction step and gather outputs across parallel ranks. + + Args: + model: The Evo2 model to run inference with + batch: Input batch containing: + - tokens: Input token IDs [B, S] + - position_ids: Position indices [B, S] + - loss_mask: Mask indicating valid tokens [B, S] + - seq_idx: Original sequence indices [B] + output_log_prob_seqs: If True, return log probabilities instead of logits + log_prob_collapse_option: How to aggregate log probs ('sum', 'mean', or 'per_token') + context_parallel_size: CP size (for warning about per_token output) + output_embeddings: If True, return embeddings instead of logits (model must have + post_process=False) + + Returns: + Dictionary containing predictions: + - If output_embeddings=True: hidden_embeddings, pad_mask, seq_idx, tokens + - If output_log_prob_seqs=False: token_logits, pad_mask, seq_idx, tokens + - If output_log_prob_seqs=True with sum/mean: log_probs_seqs, seq_idx + - If output_log_prob_seqs=True with per_token: log_probs_seqs, seq_idx, loss_mask + Returns None if not on the last pipeline stage. + """ + if not parallel_state.is_pipeline_last_stage(): + return None + + # Forward pass + output_tensor = model( + input_ids=batch["tokens"], + position_ids=batch["position_ids"], + attention_mask=None, + ) + + # Gather across tensor parallel ranks + # For logits (post_process=True): gather along vocabulary dimension (last dim is sharded) + # For embeddings (post_process=False): hidden states are not sharded across TP, skip gathering + if output_embeddings: + # Hidden states are not sharded across TP ranks, just use the output directly + forward_out_tp_gathered = output_tensor + else: + # Logits have the vocab dimension sharded across TP ranks + forward_out_tp_gathered = _gather_along_last_dim( + output_tensor, group=parallel_state.get_tensor_model_parallel_group() + ) + + # Gather across context parallel ranks (sequence dimension) + forward_out_gathered = _gather_along_cp_dim(forward_out_tp_gathered) + loss_mask_gathered = _gather_along_cp_dim(batch["loss_mask"]) + tokens_gathered = _gather_along_cp_dim(batch["tokens"]) + + if output_embeddings: + # When extracting embeddings, the model output is hidden states, not logits + # Model outputs [S, B, H] (sequence-first format), transpose to [B, S, H] for consistency + hidden_embeddings = forward_out_gathered.transpose(0, 1).contiguous() + return { + "hidden_embeddings": hidden_embeddings, + "pad_mask": loss_mask_gathered, + "seq_idx": batch["seq_idx"], + "tokens": tokens_gathered, + } + elif output_log_prob_seqs: + return _compute_log_probs( + logits=forward_out_gathered, + tokens=tokens_gathered, + loss_mask=loss_mask_gathered, + seq_idx=batch["seq_idx"], + collapse_option=log_prob_collapse_option, + context_parallel_size=context_parallel_size, + ) + else: + return { + "token_logits": forward_out_gathered, + "pad_mask": loss_mask_gathered, + "seq_idx": batch["seq_idx"], + "tokens": tokens_gathered, + } + + +def _compute_log_probs( + logits: Tensor, + tokens: Tensor, + loss_mask: Tensor, + seq_idx: Tensor, + collapse_option: Literal["sum", "mean", "per_token"], + context_parallel_size: int, +) -> dict[str, Tensor]: + """Compute log probabilities from model logits. + + Computes P(token_i | token_0, ..., token_{i-1}) for each token. + + Args: + logits: Model output logits [B, S, V] + tokens: Input token IDs [B, S] + loss_mask: Mask for valid tokens [B, S] + seq_idx: Sequence indices [B] + collapse_option: How to aggregate: 'sum', 'mean', or 'per_token' + context_parallel_size: CP size (for per_token warning) + + Returns: + Dictionary with log_probs_seqs and seq_idx (and loss_mask if per_token) + """ + # Predictions for token i are at position i, labels are at i+1 + softmax_logprobs = torch.log_softmax(logits, dim=-1) + softmax_logprobs = softmax_logprobs[:, :-1] # [B, S-1, V] + target_tokens = tokens[:, 1:] # [B, S-1] + + if softmax_logprobs.shape[1] != target_tokens.shape[1]: + raise RuntimeError(f"Shape mismatch: logprobs {softmax_logprobs.shape} vs targets {target_tokens.shape}") + + # Gather log probs for actual tokens + log_probs_per_token = torch.gather(softmax_logprobs, 2, target_tokens.unsqueeze(-1)).squeeze(-1) + + # Apply loss mask (zero out padding) + loss_mask_shifted = loss_mask[:, 1:].float() + log_probs_per_token = log_probs_per_token * loss_mask_shifted + + if collapse_option == "per_token": + if context_parallel_size > 1: + logger.warning( + "Per-token log probabilities with CP>1 will have zigzag-shuffled order. " + "Use 'sum' or 'mean' to get correctly aggregated results." + ) + return { + "log_probs_seqs": log_probs_per_token, + "seq_idx": seq_idx, + "loss_mask": loss_mask_shifted.bool(), + } + + # Sum log probs across sequence + log_prob_seqs = torch.sum(log_probs_per_token, dim=1) + + if collapse_option == "mean": + # Divide by number of valid tokens + valid_token_count = torch.clamp(loss_mask_shifted.sum(dim=-1), min=1.0) + log_prob_seqs = log_prob_seqs / valid_token_count + + return {"log_probs_seqs": log_prob_seqs, "seq_idx": seq_idx} + + +# ============================================================================= +# Output Writing +# ============================================================================= + + +def _write_predictions_batch( + predictions: dict[str, Tensor], + output_dir: Path, + batch_idx: int, + global_rank: int, + dp_rank: int, + files_per_subdir: Optional[int] = None, + num_files_written: int = 0, + data_parallel_world_size: int = 1, +) -> tuple[Path, int, int]: + """Write predictions to disk as a PyTorch file (batch mode). + + File naming follows the original PredictionWriter convention: + predictions__rank_{global_rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt + + Subdirectory structure (when files_per_subdir is set): + subdir_{num}/predictions__rank_... + + The subdirectory numbering starts from 1 and increments when the number of files + written (across all DP ranks) reaches files_per_subdir. + + Args: + predictions: Dictionary of prediction tensors to save + output_dir: Base output directory + batch_idx: Batch index for file naming + global_rank: Global rank of this process + dp_rank: Data parallel rank (included in filename for multi-GPU) + files_per_subdir: If set, organize files into subdirectories + num_files_written: Number of files already written in current subdir + data_parallel_world_size: Number of data parallel ranks + + Returns: + Tuple of (output_path, updated_num_files_written, updated_num_subdirs) + """ + if (not predictions) or (not on_writing_rank()): + return output_dir, num_files_written, 0 + + output_dir.mkdir(parents=True, exist_ok=True) + + # Track subdirectory state + current_output_dir = output_dir + num_subdirs_written = 0 + + if files_per_subdir is not None: + # Calculate how many subdirs we've created based on total files written + # (counting all DP ranks) + effective_files = num_files_written * data_parallel_world_size + if effective_files >= files_per_subdir: + # Need a new subdirectory + num_subdirs_written = effective_files // files_per_subdir + 1 + current_output_dir = output_dir / f"subdir_{num_subdirs_written}" + current_output_dir.mkdir(parents=True, exist_ok=True) + num_files_written = 0 + + filename = f"predictions__rank_{global_rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt" + output_path = current_output_dir / filename + + # Add batch_idx to predictions (matching original PredictionWriter behavior) + predictions["batch_idx"] = torch.tensor([batch_idx], dtype=torch.int64) + + torch.save(predictions, output_path) + logger.info(f"Inference predictions are stored in {output_path}\n{predictions.keys()}") + + return output_path, num_files_written + 1, num_subdirs_written + + +def _write_predictions_epoch( + predictions: dict[str, Tensor], + output_dir: Path, + global_rank: int, + dp_rank: int, +) -> Path: + """Write predictions to disk as a PyTorch file (epoch mode). + + File naming follows the original PredictionWriter convention: + predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt + + Args: + predictions: Dictionary of prediction tensors to save + output_dir: Base output directory + global_rank: Global rank of this process + dp_rank: Data parallel rank + + Returns: + Path to the saved file + """ + if (not predictions) or (not on_writing_rank()): + return output_dir + + output_dir.mkdir(parents=True, exist_ok=True) + + filename = f"predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt" + output_path = output_dir / filename + + torch.save(predictions, output_path) + logger.info(f"Inference predictions are stored in {output_path}\n{predictions.keys()}") + + return output_path + + +# ============================================================================= +# Main Prediction Workflow +# ============================================================================= + + +def predict( + fasta_path: Path, + ckpt_dir: Path, + output_dir: Optional[Path] = None, + *, + # Parallelism settings + tensor_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + no_sequence_parallel: bool = False, + # Precision settings + mixed_precision_recipe: Optional[str] = None, + # Batch/sequence settings + micro_batch_size: int = 1, + min_length: Optional[int] = None, + prepend_bos: bool = False, + # Output settings + write_interval: Literal["epoch", "batch"] = "epoch", + files_per_subdir: Optional[int] = None, + output_log_prob_seqs: bool = False, + log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean", + # Embedding extraction + embedding_layer: Optional[int] = None, +) -> None: + """Run the complete Evo2 prediction workflow. + + This function orchestrates the full inference pipeline: + 1. Load model configuration from MBridge checkpoint + 2. Override parallelism and precision settings + 3. Initialize distributed environment + 4. Create and configure the model + 5. Load model weights + 6. Process FASTA sequences and write predictions + + Args: + fasta_path: Path to input FASTA file containing sequences for prediction. + ckpt_dir: Path to MBridge checkpoint directory (must contain run_config.yaml). + output_dir: Directory for output predictions. If None, predictions are discarded. + tensor_parallel_size: Tensor parallelism degree (splits model across GPUs). + pipeline_model_parallel_size: Pipeline parallelism degree (must be 1). + context_parallel_size: Context parallelism degree (splits sequence across GPUs). + no_sequence_parallel: Disable sequence parallelism when using TP > 1. + mixed_precision_recipe: Override mixed precision recipe (default: use checkpoint). + micro_batch_size: Batch size per forward pass. + min_length: Minimum sequence length (pad shorter sequences to this). + prepend_bos: Prepend BOS token to sequences. + write_interval: When to write predictions: 'epoch' or 'batch'. + files_per_subdir: Group output files into subdirectories (batch mode only). + output_log_prob_seqs: Output log probabilities instead of raw logits. + log_prob_collapse_option: How to aggregate log probs: 'sum', 'mean', 'per_token'. + embedding_layer: Extract embeddings from a specific layer instead of logits. + Supports Python-style negative indexing (-1 for last layer, -2 for second-to-last). + For a 25-layer model, layer 24 and -1 both refer to the last layer. + + Raises: + ValueError: If pipeline parallelism > 1 is requested. + FileNotFoundError: If checkpoint run_config.yaml is missing. + + Example: + >>> from pathlib import Path + >>> predict( + ... fasta_path=Path("sequences.fasta"), + ... ckpt_dir=Path("/path/to/mbridge/checkpoint"), + ... output_dir=Path("/path/to/output"), + ... tensor_parallel_size=2, + ... micro_batch_size=4, + ... ) + """ + if pipeline_model_parallel_size != 1: + raise ValueError("Pipeline parallelism > 1 is not currently supported for prediction.") + + # ------------------------------------------------------------------------- + # Step 1: Resolve and load configuration from checkpoint + # ------------------------------------------------------------------------- + # Handle both direct checkpoint paths and training output directories with iter_* subdirs + resolved_ckpt_dir = resolve_checkpoint_path(ckpt_dir) + logger.info(f"Loading configuration from checkpoint: {resolved_ckpt_dir}") + + run_config_filename = get_checkpoint_run_config_filename(str(resolved_ckpt_dir)) + + run_config = read_run_config(run_config_filename) + model_provider = instantiate(run_config["model"]) + logger.info(f"Instantiated model provider: {type(model_provider).__name__}") + + # ------------------------------------------------------------------------- + # Step 2: Override parallelism and precision settings + # ------------------------------------------------------------------------- + model_provider.tensor_model_parallel_size = tensor_parallel_size + model_provider.pipeline_model_parallel_size = pipeline_model_parallel_size + model_provider.context_parallel_size = context_parallel_size + model_provider.sequence_parallel = tensor_parallel_size > 1 and not no_sequence_parallel + + # Configure mixed precision + if mixed_precision_recipe is not None: + mp_config = get_mixed_precision_config(mixed_precision_recipe) + elif "mixed_precision" in run_config and run_config["mixed_precision"] is not None: + mp_value = run_config["mixed_precision"] + if isinstance(mp_value, str): + mp_config = get_mixed_precision_config(mp_value) + logger.info(f"Using mixed precision recipe from checkpoint: {mp_value}") + else: + mp_config = instantiate(mp_value) + logger.info("Using mixed precision config from checkpoint") + else: + mp_config = get_mixed_precision_config("bf16_mixed") + + mp_config.finalize() + mp_config.setup(model_provider) + + # ------------------------------------------------------------------------- + # Step 3: Load tokenizer + # ------------------------------------------------------------------------- + tokenizer_dir = resolved_ckpt_dir / "tokenizer" + if tokenizer_dir.exists(): + tokenizer = _HuggingFaceTokenizer(tokenizer_dir) + else: + tokenizer = _HuggingFaceTokenizer(DEFAULT_HF_TOKENIZER_MODEL_PATH) + + model_provider.vocab_size = tokenizer.vocab_size + model_provider.should_pad_vocab = True + + # ------------------------------------------------------------------------- + # Step 3.5: Handle embedding layer extraction + # ------------------------------------------------------------------------- + # Get the original number of layers from the checkpoint config + original_num_layers = model_provider.num_layers + output_embeddings = embedding_layer is not None + + if output_embeddings: + # Validate and resolve the embedding layer index + # Support Python-style negative indexing + if embedding_layer < 0: + # Convert negative index to positive (e.g., -1 -> last layer) + target_num_layers = original_num_layers + embedding_layer + 1 + else: + # Positive index: layer N means we need N+1 layers (0-indexed) + target_num_layers = embedding_layer + 1 + + if target_num_layers <= 0 or target_num_layers > original_num_layers: + raise ValueError( + f"Invalid embedding_layer={embedding_layer} for model with {original_num_layers} layers. " + f"Valid range: -{original_num_layers} to {original_num_layers - 1}." + ) + + # Set the model to use fewer layers and skip post-processing (output heads) + model_provider.num_layers = target_num_layers + model_provider.post_process = False + + # Also truncate the hybrid_override_pattern if it exists, since it must match num_layers + if hasattr(model_provider, "hybrid_override_pattern") and model_provider.hybrid_override_pattern is not None: + original_pattern = model_provider.hybrid_override_pattern + if len(original_pattern) > target_num_layers: + model_provider.hybrid_override_pattern = original_pattern[:target_num_layers] + logger.info( + f"Truncated hybrid_override_pattern from {len(original_pattern)} to {target_num_layers} chars" + ) + + # Disable remove_activation_post_first_layer if we only have 1 layer, since it requires at least 2 layers + if target_num_layers == 1 and hasattr(model_provider, "remove_activation_post_first_layer"): + if model_provider.remove_activation_post_first_layer: + model_provider.remove_activation_post_first_layer = False + logger.info("Disabled remove_activation_post_first_layer (requires at least 2 layers)") + + logger.info( + f"Embedding extraction mode: extracting from layer {embedding_layer} " + f"(using {target_num_layers} of {original_num_layers} layers, post_process=False)" + ) + + # Cannot use log prob output with embedding mode + if output_log_prob_seqs: + raise ValueError("Cannot use --output-log-prob-seqs with --embedding-layer. Embeddings are not logits.") + + # ------------------------------------------------------------------------- + # Step 4: Initialize distributed environment + # ------------------------------------------------------------------------- + rng_config = instantiate(run_config.get("rng")) if run_config.get("rng") else RNGConfig(seed=1234) + dist_config = instantiate(run_config.get("dist")) if run_config.get("dist") else DistributedInitConfig() + + model_parallel_size = tensor_parallel_size * pipeline_model_parallel_size * context_parallel_size + world_size = get_world_size_safe() + data_parallel_size = world_size // model_parallel_size + global_batch_size = micro_batch_size * data_parallel_size + + initialize_inference_distributed( + tensor_model_parallel_size=tensor_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + rng_config=rng_config, + dist_config=dist_config, + ) + logger.info("Initialized distributed environment") + + # ------------------------------------------------------------------------- + # Step 5: Create model and load weights + # ------------------------------------------------------------------------- + logger.info("Creating model...") + model_provider.finalize() + + model = model_provider.provide_distributed_model( + ddp_config=None, + wrap_with_ddp=False, + data_parallel_random_init=False, + bf16=mp_config.bf16, + fp16=mp_config.fp16, + mixed_precision_wrapper=Float16Module if (mp_config.bf16 or mp_config.fp16) else None, + ) + + for model_module in model: + model_module.eval() + + # Log model layer information + # Access the underlying model to get layer count + model_for_inspection = model[0] + if hasattr(model_for_inspection, "module"): + # Handle Float16Module wrapper + model_for_inspection = model_for_inspection.module + if hasattr(model_for_inspection, "decoder") and hasattr(model_for_inspection.decoder, "layers"): + actual_num_layers = len(model_for_inspection.decoder.layers) + logger.info(f"Model initialized with {actual_num_layers} layers") + if output_embeddings: + logger.info( + f"Embedding extraction: model has {actual_num_layers} layers " + f"(from original {original_num_layers} layers)" + ) + else: + logger.warning("Could not determine number of layers from model structure") + + logger.info(f"Loading weights from: {resolved_ckpt_dir}") + _load_model_weights_from_checkpoint( + checkpoint_path=str(resolved_ckpt_dir), + model=model, + dist_ckpt_strictness="ignore_all", + ) + logger.info("Weights loaded successfully") + + # ------------------------------------------------------------------------- + # Step 6: Create dataset and dataloader + # ------------------------------------------------------------------------- + logger.info(f"Loading dataset from: {fasta_path}") + dataset = SimpleFastaDataset( + fasta_path=fasta_path, + tokenizer=tokenizer, + prepend_bos=prepend_bos, + custom_loss_masker=None, + ) + + data_parallel_rank = parallel_state.get_data_parallel_rank() + data_parallel_size = parallel_state.get_data_parallel_world_size() + + dataloader = build_pretraining_data_loader( + dataset=dataset, + consumed_samples=0, + dataloader_type="single", + micro_batch_size=micro_batch_size, + num_workers=4, + data_sharding=False, + collate_fn=_padding_collate_fn_factory( + pad_token_id=getattr(tokenizer, "pad_id", 0), + min_length=min_length, + ), + pin_memory=True, + persistent_workers=False, + data_parallel_rank=data_parallel_rank, + data_parallel_size=data_parallel_size, + drop_last=False, + ) + + # ------------------------------------------------------------------------- + # Step 7: Run prediction loop + # ------------------------------------------------------------------------- + logger.info("Starting prediction loop...") + predictions: list[dict[str, Tensor]] = [] + + # Get ranks for file naming (matching original PredictionWriter behavior) + global_rank = get_rank_safe() + num_files_written = 0 + + with torch.no_grad(): + for batch_idx, batch_data in enumerate(dataloader): + # Move to GPU + batch_gpu = { + k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in batch_data.items() + } + + # Apply context parallel slicing (seq_idx must NOT be sliced) + if context_parallel_size > 1: + seq_idx = batch_gpu.pop("seq_idx", None) + batch_gpu = get_batch_on_this_cp_rank(batch_gpu) + if seq_idx is not None: + batch_gpu["seq_idx"] = seq_idx + + # Forward pass + result = _predict_step( + model=model[0], + batch=batch_gpu, + output_log_prob_seqs=output_log_prob_seqs, + log_prob_collapse_option=log_prob_collapse_option, + context_parallel_size=context_parallel_size, + output_embeddings=output_embeddings, + ) + + if result is not None: + predictions.append({k: v.cpu() for k, v in result.items()}) + + if (batch_idx + 1) % 10 == 0: + logger.info(f"Processed batch {batch_idx + 1}/{len(dataloader)}") + + # Write at batch interval + if write_interval == "batch" and output_dir is not None and predictions: + _, num_files_written, _ = _write_predictions_batch( + predictions=predictions[0], + output_dir=output_dir, + batch_idx=batch_idx, + global_rank=global_rank, + dp_rank=data_parallel_rank, + files_per_subdir=files_per_subdir, + num_files_written=num_files_written, + data_parallel_world_size=data_parallel_size, + ) + predictions = [] + + # Write at epoch end + if write_interval == "epoch" and output_dir is not None and predictions: + combined = batch_collator( + predictions, + batch_dim=0, + seq_dim=1, + batch_dim_key_defaults={}, + seq_dim_key_defaults={}, + ) + _write_predictions_epoch( + predictions=combined, + output_dir=output_dir, + global_rank=global_rank, + dp_rank=data_parallel_rank, + ) + + # Write sequence index map + if output_dir is not None: + output_dir.mkdir(parents=True, exist_ok=True) + dataset.write_idx_map(output_dir) + + logger.info("Prediction complete!") + + # Cleanup + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +# ============================================================================= +# Entry Point +# ============================================================================= + + +def main() -> None: + """CLI entry point for Evo2 prediction.""" + args = parse_args() + predict( + fasta_path=args.fasta, + ckpt_dir=args.ckpt_dir, + output_dir=args.output_dir, + # Parallelism settings + tensor_parallel_size=args.tensor_parallel_size, + pipeline_model_parallel_size=args.pipeline_model_parallel_size, + context_parallel_size=args.context_parallel_size, + no_sequence_parallel=args.no_sequence_parallel, + # Precision settings + mixed_precision_recipe=args.mixed_precision_recipe, + # Batch/sequence settings + micro_batch_size=args.micro_batch_size, + min_length=args.min_length, + prepend_bos=args.prepend_bos, + # Output settings + write_interval=args.write_interval, + files_per_subdir=args.files_per_subdir, + output_log_prob_seqs=args.output_log_prob_seqs, + log_prob_collapse_option=args.log_prob_collapse_option, + # Embedding extraction + embedding_layer=args.embedding_layer, + ) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py index 46289038e2..3561a51fc7 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py @@ -236,6 +236,8 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: # ) # FIXME not supported in megatron # parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode") # TODO implement parser.add_argument("--sequence-parallel", action="store_true", help="Set to enable sequence parallelism.") # DONE + parser.add_argument("--no-fp8-wgrad", action="store_true", help="Set to disable fp8 weight gradients.") + parser.add_argument("--no-fp8-param-gather", action="store_true", help="Set to disable fp8 parameter gathering.") parser.add_argument( "--mixed-precision-recipe", type=str, @@ -598,12 +600,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: # default=True, # help="Disable saving the last checkpoint.", # ) # TODO implement - parser.add_argument( - "--lora-finetune", action="store_true", help="Use LoRA fine-tuning", default=False - ) # TODO implement - parser.add_argument( - "--lora-checkpoint-path", type=str, default=None, help="LoRA checkpoint path" - ) # TODO implement + # parser.add_argument( + # "--lora-finetune", action="store_true", help="Use LoRA fine-tuning", default=False + # ) # TODO implement + # parser.add_argument( + # "--lora-checkpoint-path", type=str, default=None, help="LoRA checkpoint path" + # ) # TODO implement parser.add_argument( "--no-calculate-per-token-loss", action="store_true", @@ -647,7 +649,16 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: default=False, help="Enable NVIDIA fault tolerance. This only works on internal NVIDIA clusters.", ) # DONE - parser.add_argument( + + # Optimizer format + optimizer_fmt_group = parser.add_mutually_exclusive_group(required=False) + optimizer_fmt_group.add_argument( + "--optim-fmt-pre-mcore-014", + action="store_true", + default=False, + help="Use the pre-Megatron-Core-v0.14 optimizer format.", + ) + optimizer_fmt_group.add_argument( "--optim-full-reshardable", action="store_true", default=False, @@ -765,6 +776,15 @@ def train(args: argparse.Namespace) -> None: cfg.checkpoint.exit_on_missing_checkpoint = False cfg.checkpoint.dist_ckpt_strictness = "assume_ok_unexpected" + if args.no_fp8_wgrad: + # change if a change is requested to the mixed precision recipe + cfg.mixed_precision.fp8_wgrad = False + if args.grad_reduce_in_fp32: + cfg.mixed_precision.grad_reduce_in_fp32 = True + cfg.ddp.grad_reduce_in_fp32 = True + if args.no_fp8_param_gather: + cfg.mixed_precision.fp8_param_gather = False + # 3. Apply Manual Overrides (for settings not exposed in recipe kwargs) if args.no_renormalize_loss: cfg.model.to_upper = "weighted" # rather than "normalized_weighted" @@ -824,7 +844,10 @@ def train(args: argparse.Namespace) -> None: cfg.optimizer.log_num_zeros_in_grad = args.log_num_zeros_in_grad cfg.optimizer.clip_grad = args.clip_grad # Optimizer checkpointing resharding - cfg.checkpoint.dist_ckpt_optim_fully_reshardable = args.optim_full_reshardable + if args.optim_fmt_pre_mcore_014: + cfg.checkpoint.dist_ckpt_save_pre_mcore_014 = True + elif args.optim_full_reshardable: + cfg.checkpoint.dist_ckpt_optim_fully_reshardable = True cfg.dataset.num_workers = args.workers @@ -832,7 +855,6 @@ def train(args: argparse.Namespace) -> None: cfg.ddp.align_param_gather = args.align_param_gather cfg.ddp.overlap_param_gather = args.overlap_param_gather cfg.ddp.overlap_grad_reduce = args.overlap_grad_reduce - cfg.ddp.grad_reduce_in_fp32 = args.grad_reduce_in_fp32 cfg.ddp.check_for_nan_in_grad = not args.no_check_for_nan_in_grad if args.use_megatron_comm_overlap_llama3_8k: # Pick the floating point appropriate config. diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/__init__.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/__init__.py new file mode 100644 index 0000000000..018372aa39 --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evo2 tests package.""" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py index 8c5dbe4018..de8ace18ec 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py @@ -16,14 +16,17 @@ # conftest.py import gc +import os +import random +import signal +import time +from pathlib import Path +import numpy as np import pytest import torch -# from bionemo.testing.torch import get_device_and_memory_allocated - - def get_device_and_memory_allocated() -> str: """Get the current device index, name, and memory usage.""" current_device_index = torch.cuda.current_device() @@ -63,16 +66,150 @@ def pytest_sessionfinish(session, exitstatus): ) +def _cleanup_child_processes(): + """Kill any orphaned child processes that might be holding GPU memory. + + This is particularly important for tests that spawn subprocesses via torchrun. + """ + import subprocess + + current_pid = os.getpid() + try: + # Find child processes + result = subprocess.run( + ["pgrep", "-P", str(current_pid)], check=False, capture_output=True, text=True, timeout=5 + ) + child_pids = result.stdout.strip().split("\n") + for pid_str in child_pids: + if pid_str: + try: + pid = int(pid_str) + os.kill(pid, signal.SIGTERM) + except (ValueError, ProcessLookupError, PermissionError): + pass + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + + +def _thorough_gpu_cleanup(): + """Perform thorough GPU memory cleanup.""" + if not torch.cuda.is_available(): + return + + # Synchronize all CUDA streams to ensure all operations are complete + torch.cuda.synchronize() + + # Clear all cached memory + torch.cuda.empty_cache() + + # Reset peak memory stats + torch.cuda.reset_peak_memory_stats() + + # Run garbage collection multiple times to ensure all objects are collected + for _ in range(3): + gc.collect() + + # Another sync and cache clear after gc + torch.cuda.synchronize() + torch.cuda.empty_cache() + + # Small sleep to allow GPU memory to be fully released + time.sleep(0.1) + + +def _reset_random_seeds(): + """Reset random seeds to ensure reproducibility across tests. + + Some tests may modify global random state, which can affect subsequent tests + that depend on random splitting (like dataset preprocessing). + """ + # Reset Python's random module + random.seed(None) + + # Reset NumPy's random state (intentionally using legacy API to reset global state) + np.random.seed(None) # noqa: NPY002 + + # Reset PyTorch's random state + torch.seed() + if torch.cuda.is_available(): + torch.cuda.seed_all() + + @pytest.fixture(autouse=True) def cleanup_after_test(): - """Clean up GPU memory after each test.""" + """Clean up GPU memory and reset state after each test.""" + # Reset random seeds before the test to ensure reproducibility + _reset_random_seeds() + yield - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + + # After the test, perform thorough cleanup + _thorough_gpu_cleanup() + + # Clean up any orphaned child processes (important for subprocess tests) + _cleanup_child_processes() + + # Final garbage collection + gc.collect() def pytest_addoption(parser: pytest.Parser): """Pytest configuration for bionemo.evo2.run tests. Adds custom command line options for dataset paths.""" parser.addoption("--dataset-dir", action="store", default=None, help="Path to preprocessed dataset directory") parser.addoption("--training-config", action="store", default=None, help="Path to training data config YAML file") + + +# ============================================================================= +# Session-scoped checkpoint fixtures for sharing across test files +# ============================================================================= + + +@pytest.fixture(scope="session") +def mbridge_checkpoint_1b_8k_bf16(tmp_path_factory) -> Path: + """Session-scoped MBridge checkpoint for the 1b-8k-bf16 model. + + This fixture converts the NeMo2 checkpoint to MBridge format once per test session, + allowing it to be shared across multiple test files (test_infer.py, test_predict.py, etc.). + + Returns: + Path to the MBridge checkpoint iteration directory (e.g., .../iter_0000001) + """ + from bionemo.core.data.load import load + from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH_512 + from bionemo.evo2.utils.checkpoint.nemo2_to_mbridge import run_nemo2_to_mbridge + + try: + nemo2_ckpt_path = load("evo2/1b-8k-bf16:1.0") + except ValueError as e: + if e.args[0].endswith("does not have an NGC URL."): + pytest.skip( + "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " + "one or more files are missing from ngc." + ) + else: + raise e + + output_dir = tmp_path_factory.mktemp("mbridge_ckpt_1b_8k_bf16_session") + mbridge_ckpt_dir = run_nemo2_to_mbridge( + nemo2_ckpt_dir=nemo2_ckpt_path, + tokenizer_path=DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + mbridge_ckpt_dir=output_dir / "evo2_1b_mbridge", + model_size="1b", + seq_length=8192, + mixed_precision_recipe="bf16_mixed", + vortex_style_fp8=False, + ) + return mbridge_ckpt_dir / "iter_0000001" + + +@pytest.fixture(scope="module") +def mbridge_checkpoint_path(mbridge_checkpoint_1b_8k_bf16) -> Path: + """Module-scoped alias for the session-scoped 1b-8k-bf16 checkpoint. + + This provides backward compatibility for tests that use the name 'mbridge_checkpoint_path'. + The actual checkpoint is shared at session scope via mbridge_checkpoint_1b_8k_bf16. + + Returns: + Path to the MBridge checkpoint iteration directory + """ + return mbridge_checkpoint_1b_8k_bf16 diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_fasta_dataset.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_fasta_dataset.py index a9e9ed5cdd..851f9b7f8c 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_fasta_dataset.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_fasta_dataset.py @@ -16,80 +16,142 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FIXME bring back these tests -# from pathlib import Path +"""Tests for SimpleFastaDataset.""" -# import pytest -# import torch -# from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset -# from bionemo.testing.data.fasta import create_fasta_file +import json +from pathlib import Path +import pytest +import torch +from megatron.bridge.training.tokenizers.config import TokenizerConfig +from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer -# # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH +from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset +from bionemo.evo2.data.test_utils.create_fasta_file import create_fasta_file -# def get_nmt_tokenizer(tokenizer_type: str): -# """FIXME use an automodel HF tokenizer.""" -# raise NotImplementedError("FIXME use an automodel HF tokenizer.") +@pytest.fixture +def tokenizer(): + """Return a HuggingFace tokenizer for testing.""" + return build_tokenizer( + TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + hf_tokenizer_kwargs={"trust_remote_code": False}, + tokenizer_model=DEFAULT_HF_TOKENIZER_MODEL_PATH, + ) + ) -# @pytest.fixture -# def fasta_dataset(tmp_path: Path) -> None: -# """Fixture to create a SimpleFastaDataset for testing.""" -# test_fasta_file_path = create_fasta_file(tmp_path / "test.fasta", num_sequences=10, sequence_length=100) -# tokenizer = get_nmt_tokenizer("byte-level") -# return SimpleFastaDataset(test_fasta_file_path, tokenizer) +@pytest.fixture +def fasta_dataset(tmp_path: Path, tokenizer) -> SimpleFastaDataset: + """Fixture to create a SimpleFastaDataset for testing.""" + test_fasta_file_path = create_fasta_file(tmp_path / "test.fasta", num_sequences=10, sequence_length=100) + return SimpleFastaDataset(test_fasta_file_path, tokenizer) -# def test_simple_fasta_dataset_initialization(fasta_dataset: SimpleFastaDataset) -> None: -# """Test initialization of SimpleFastaDataset.""" -# # Check dataset length -# assert len(fasta_dataset) == 10, "Dataset length should match number of sequences" +def test_simple_fasta_dataset_initialization(fasta_dataset: SimpleFastaDataset) -> None: + """Test initialization of SimpleFastaDataset.""" + # Check dataset length + assert len(fasta_dataset) == 10, "Dataset length should match number of sequences" -# # Check seqids -# assert len(fasta_dataset.seqids) == 10, "Seqids should match number of sequences" + # Check seqids + assert len(fasta_dataset.seqids) == 10, "Seqids should match number of sequences" -# def test_simple_fasta_dataset_getitem(fasta_dataset: SimpleFastaDataset) -> None: -# """Test __getitem__ method of SimpleFastaDataset.""" -# # Test first item -# item = fasta_dataset[0] +def test_simple_fasta_dataset_getitem(fasta_dataset: SimpleFastaDataset) -> None: + """Test __getitem__ method of SimpleFastaDataset.""" + # Test first item + item = fasta_dataset[0] -# # Check keys -# expected_keys = {"tokens", "position_ids", "seq_idx", "loss_mask"} -# assert set(item.keys()) == expected_keys, "Item should have correct keys" + # Check keys + expected_keys = {"tokens", "position_ids", "seq_idx", "loss_mask"} + assert set(item.keys()) == expected_keys, "Item should have correct keys" -# # Check token type -# assert isinstance(item["tokens"], torch.Tensor), "Tokens should be a torch.Tensor" -# assert item["tokens"].dtype == torch.long, "Tokens should be long dtype" + # Check token type + assert isinstance(item["tokens"], torch.Tensor), "Tokens should be a torch.Tensor" + assert item["tokens"].dtype == torch.long, "Tokens should be long dtype" -# # Check position_ids -# assert isinstance(item["position_ids"], torch.Tensor), "Position IDs should be a torch.Tensor" -# assert item["position_ids"].dtype == torch.long, "Position IDs should be long dtype" + # Check position_ids + assert isinstance(item["position_ids"], torch.Tensor), "Position IDs should be a torch.Tensor" + assert item["position_ids"].dtype == torch.long, "Position IDs should be long dtype" -# # Validate sequence index -# assert isinstance(item["seq_idx"], torch.Tensor), "Seq_idx should be a torch.Tensor" -# assert item["seq_idx"].item() == 0, "First item should have seq_idx 0" + # Validate sequence index + assert isinstance(item["seq_idx"], torch.Tensor), "Seq_idx should be a torch.Tensor" + assert item["seq_idx"].item() == 0, "First item should have seq_idx 0" + # Check loss_mask + assert isinstance(item["loss_mask"], torch.Tensor), "Loss mask should be a torch.Tensor" + assert item["loss_mask"].dtype == torch.long, "Loss mask should be long dtype" -# def test_simple_fasta_dataset_write_idx_map(fasta_dataset: SimpleFastaDataset, tmp_path: Path) -> None: -# """Test write_idx_map method of SimpleFastaDataset.""" -# # Create output directory -# output_dir = tmp_path / "output" -# output_dir.mkdir(parents=True, exist_ok=True) + # With prepend_bos=True (default), the first token should be masked + assert item["loss_mask"][0].item() == 0, "First token (BOS) should be masked" -# # Write index map -# fasta_dataset.write_idx_map(output_dir) + # Tokens length should be sequence_length + 1 (for BOS) + # Since we create sequences of length 100, tokens should be 101 + assert len(item["tokens"]) == 101, "Tokens should include BOS token" + assert len(item["position_ids"]) == 101, "Position IDs should match tokens length" -# # Check if file was created -# idx_map_file = output_dir / "seq_idx_map.json" -# assert idx_map_file.exists(), "seq_idx_map.json should be created" -# import json +def test_simple_fasta_dataset_write_idx_map(fasta_dataset: SimpleFastaDataset, tmp_path: Path) -> None: + """Test write_idx_map method of SimpleFastaDataset.""" + # Create output directory + output_dir = tmp_path / "output" + output_dir.mkdir(parents=True, exist_ok=True) -# with open(idx_map_file, "r") as f: -# idx_map = json.load(f) + # Write index map + fasta_dataset.write_idx_map(output_dir) -# assert len(idx_map) == 10, "Index map should have an entry for each sequence" -# for idx, seqid in enumerate(fasta_dataset.seqids): -# assert idx_map[seqid] == idx, f"Index for {seqid} should match" + # Check if file was created + idx_map_file = output_dir / "seq_idx_map.json" + assert idx_map_file.exists(), "seq_idx_map.json should be created" + + with open(idx_map_file) as f: + idx_map = json.load(f) + + assert len(idx_map) == 10, "Index map should have an entry for each sequence" + for idx, seqid in enumerate(fasta_dataset.seqids): + assert idx_map[seqid] == idx, f"Index for {seqid} should match" + + +def test_simple_fasta_dataset_no_bos(tmp_path: Path, tokenizer) -> None: + """Test SimpleFastaDataset without BOS token prepending.""" + test_fasta_file_path = create_fasta_file(tmp_path / "test_no_bos.fasta", num_sequences=5, sequence_length=50) + dataset = SimpleFastaDataset(test_fasta_file_path, tokenizer, prepend_bos=False) + + item = dataset[0] + + # Without BOS, tokens length should equal sequence length + assert len(item["tokens"]) == 50, "Tokens should not include BOS token" + assert len(item["position_ids"]) == 50, "Position IDs should match tokens length" + + # All tokens should be unmasked (loss_mask all 1s) + assert item["loss_mask"].sum().item() == 50, "All tokens should be unmasked without BOS" + + +def test_simple_fasta_dataset_variable_lengths(tmp_path: Path, tokenizer) -> None: + """Test SimpleFastaDataset with variable sequence lengths.""" + sequence_lengths = [50, 100, 150, 200, 75] + test_fasta_file_path = create_fasta_file( + tmp_path / "test_variable.fasta", num_sequences=5, sequence_lengths=sequence_lengths + ) + dataset = SimpleFastaDataset(test_fasta_file_path, tokenizer) + + assert len(dataset) == 5, "Dataset should have 5 sequences" + + # Check each item has the correct length (sequence_length + 1 for BOS) + for i, expected_len in enumerate(sequence_lengths): + item = dataset[i] + assert len(item["tokens"]) == expected_len + 1, f"Sequence {i} should have length {expected_len + 1}" + + +def test_simple_fasta_dataset_iteration(fasta_dataset: SimpleFastaDataset) -> None: + """Test that we can iterate through the entire dataset.""" + count = 0 + for i in range(len(fasta_dataset)): + item = fasta_dataset[i] + assert item is not None, f"Item {i} should not be None" + assert "tokens" in item, f"Item {i} should have 'tokens' key" + count += 1 + + assert count == 10, "Should iterate through all 10 items" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_sharded_eden_dataset_provider.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_sharded_eden_dataset_provider.py index cffed080fb..f7b8acd6a6 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_sharded_eden_dataset_provider.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_sharded_eden_dataset_provider.py @@ -26,8 +26,6 @@ from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH, DEFAULT_HF_TOKENIZER_MODEL_PATH_512 - -# FIXME revive this since it might make some tests/training runs easier. from bionemo.evo2.data.sharded_eden_dataset_provider import ( DatasetBuildContext, ShardedEdenDataset, diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_tokenizer.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_tokenizer.py index e914c989ab..ba526aa944 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_tokenizer.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/data/test_tokenizer.py @@ -56,6 +56,45 @@ def test_tokenizer_vocab_size(tokenizer_path: Path, expected_vocab_size: int) -> assert tokenizer.vocab_size == expected_vocab_size +@pytest.mark.parametrize( + "tokenizer_path", + [ + DEFAULT_HF_TOKENIZER_MODEL_PATH, + DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + ], +) +def test_tokenizer_roundtrip_without_spaces(tokenizer_path: Path) -> None: + """Verifies tokenization followed by detokenization returns the original sequence. + + This is critical for character-level tokenizers used in DNA sequence modeling. + The tokenizer should NOT add spaces between tokens during detokenization. + """ + tokenizer = build_tokenizer( + TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + hf_tokenizer_kwargs={"trust_remote_code": False}, + tokenizer_model=tokenizer_path, + ) + ) + # Test basic DNA sequence + original = "ATCGATCGATCG" + token_ids = tokenizer.text_to_ids(original) + reconstructed = tokenizer.detokenize(token_ids) + assert reconstructed == original, f"Expected '{original}', got '{reconstructed}'" + + # Test longer sequence with all nucleotides + original_long = "AAAAACCCCCGGGGGTTTTTATCGATCGNNNNN" + token_ids_long = tokenizer.text_to_ids(original_long) + reconstructed_long = tokenizer.detokenize(token_ids_long) + assert reconstructed_long == original_long, f"Expected '{original_long}', got '{reconstructed_long}'" + + # Test sequence with special characters (pipe-delimited tags) + original_tagged = "|info|ATCG|end|" + token_ids_tagged = tokenizer.text_to_ids(original_tagged) + reconstructed_tagged = tokenizer.detokenize(token_ids_tagged) + assert reconstructed_tagged == original_tagged, f"Expected '{original_tagged}', got '{reconstructed_tagged}'" + + def test_tokenizer_handles_long_dna_sequence(tokenizer: Evo2DatasetTokenizer) -> None: """Verifies tokenizer correctly processes a long DNA sequence into expected token IDs. diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/__init__.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/__init__.py new file mode 100644 index 0000000000..709f3b03b0 --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hyena model tests package.""" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py index 3f4e36b091..2c18811de4 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_mixer_kernel.py @@ -32,6 +32,8 @@ from bionemo.evo2.models.megatron.hyena.hyena_mixer import HyenaMixer from bionemo.evo2.models.megatron.hyena.hyena_utils import ImplicitModalFilter +from ....utils import find_free_network_port + try: import subquadratic_ops_torch # noqa: F401 @@ -53,7 +55,7 @@ def init_distributed_parallel_state( if not dist.is_initialized(): # Setup minimal environment for single process distributed os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" + os.environ["MASTER_PORT"] = str(find_free_network_port()) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) @@ -388,7 +390,7 @@ def test_subquadratic_ops_kernel( # noqa: D103 mixer_kernel.zero_grad() # Compare results between PyTorch and CUDA kernel implementations - torch.testing.assert_close(output_features, output_features_kernel, msg=f"Output mismatch for {operator_type}") + torch.testing.assert_close(output_features, output_features_kernel, rtol=0.02, atol=2e-4) torch.testing.assert_close(loss, loss_kernel, msg=f"Loss mismatch for {operator_type}") # Compare gradients diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py index 59017a1f7c..2ba21709fa 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/models/megatron/hyena/test_hyena_utils.py @@ -546,11 +546,17 @@ def test_fallback_functions_import_error_messages(self): def test_einops_import_error(self): """Test that the einops import error is raised with the correct message.""" - # Mock the import to fail - with patch.dict("sys.modules", {"einops": None}): - # Re-import the module to trigger the import error - with pytest.raises(ImportError, match="einops is required by the Hyena model but cannot be imported"): - import bionemo.evo2.models.megatron.hyena.hyena_utils - - # Force a reload of the module to trigger the import error - importlib.reload(bionemo.evo2.models.megatron.hyena.hyena_utils) + import bionemo.evo2.models.megatron.hyena.hyena_utils + + try: + # Mock the import to fail + with patch.dict("sys.modules", {"einops": None}): + # Re-import the module to trigger the import error + with pytest.raises(ImportError, match="einops is required by the Hyena model but cannot be imported"): + # Force a reload of the module to trigger the import error + importlib.reload(bionemo.evo2.models.megatron.hyena.hyena_utils) + finally: + # CRITICAL: Always restore the module to its proper state after the test. + # The reload above leaves the module in a corrupted state, which can cause + # subsequent tests to fail (especially test_infer.py tests). + importlib.reload(bionemo.evo2.models.megatron.hyena.hyena_utils) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/common.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/common.py deleted file mode 100644 index 92ebd07afc..0000000000 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/common.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def small_training_cmd( - path, - max_steps, - val_check, - global_batch_size: int | None = None, - devices: int = 1, - additional_args: str = "", -): - """Command for training.""" - cmd = ( - f"train_evo2 --mock-data --result-dir {path} --devices {devices} " - "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 " - "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback " - f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} " - f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} " - f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}" - ) - return cmd - - -def small_training_finetune_cmd( - path, - max_steps, - val_check, - prev_ckpt, - devices: int = 1, - global_batch_size: int | None = None, - create_tflops_callback: bool = True, - additional_args: str = "", -): - """Command for finetuning.""" - cmd = ( - f"train_evo2 --mock-data --result-dir {path} --devices {devices} " - "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 " - "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger " - f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} " - f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt} " - f"{'--create-tflops-callback' if create_tflops_callback else ''} " - f"{'--global-batch-size ' + str(global_batch_size) if global_batch_size is not None else ''}" - ) - return cmd - - -def predict_cmd(ckpt_dir: str, output_dir: str, fasta_file_path: str, additional_args: str = ""): - """Command fro predict.""" - cmd = ( - f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {ckpt_dir} --output-dir {output_dir} " - "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* --tensor-parallel-size 1 " - f"--pipeline-model-parallel-size 1 --context-parallel-size 1 {additional_args}" - ) - - return cmd diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_finetune.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_finetune.py deleted file mode 100644 index fce04f2d24..0000000000 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_finetune.py +++ /dev/null @@ -1,201 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# FIXME bring back these tests -# import re - -# import pytest -# from bionemo.testing.subprocess_utils import run_command_in_subprocess - -# from .common import small_training_cmd, small_training_finetune_cmd - - -# def extract_val_losses(log_text: str, n: int): -# """ -# Extracts validation losses every n-th occurrence (starting at 0). -# Iteration index is derived by counting val_loss appearances. - -# Args: -# log_text (str): The log output as a string. -# n (int): Interval of occurrences (e.g., n=5 -> get val_loss at 0, 5, 10...). - -# Returns: -# List of tuples: (step, validation_loss_value). -# """ -# # Regex to capture val_loss values -# pattern = re.compile(r"val_loss: ([0-9.]+)") - -# results = [] -# for idx, match in enumerate(pattern.finditer(log_text)): -# if idx % n == 0: # take every n-th val_loss occurrence -# results.append((idx, float(match.group(1)))) - -# return results - - -# @pytest.mark.timeout(2048) # Optional: fail if the test takes too long. -# @pytest.mark.slow -# @pytest.mark.parametrize("with_peft", [True, False]) -# def test_train_evo2_finetune_runs(tmp_path, with_peft: bool): -# """ -# This test runs the `train_evo2` command with mock data in a temporary directory. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. -# """ -# num_steps = 25 -# val_steps = 10 -# global_batch_size = 128 - -# # Note: The command assumes that `train_evo2` is in your PATH. -# command = small_training_cmd( -# tmp_path / "pretrain", -# max_steps=num_steps, -# val_check=val_steps, -# global_batch_size=global_batch_size, -# additional_args=" --lr 0.1 ", -# ) -# stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain - -# log_dir = tmp_path / "pretrain" / "evo2" -# checkpoints_dir = log_dir / "checkpoints" -# tensorboard_dir = log_dir / "dev" - -# # Check if logs dir exists -# assert log_dir.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps * global_batch_size}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist." - -# event_files = list(tensorboard_dir.rglob("events.out.tfevents*")) -# assert len(event_files) == 1, f"No or multiple TensorBoard event files found under {tensorboard_dir}" - -# val_losses = extract_val_losses(stdout_pretrain, val_steps) - -# for i in range(1, len(val_losses)): -# assert val_losses[i][1] <= val_losses[i - 1][1], ( -# f"Validation loss increased at step {val_losses[i][0]}: {val_losses[i][1]} > {val_losses[i - 1][1]}" -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir}" -# assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found." -# if with_peft: -# result_dir = tmp_path / "lora_finetune" -# additional_args = "--lora-finetune --lr 0.1 " -# else: -# result_dir = tmp_path / "finetune" -# additional_args = " --lr 0.1 " - -# command_finetune = small_training_finetune_cmd( -# result_dir, -# max_steps=num_steps, -# val_check=val_steps, -# global_batch_size=global_batch_size, -# prev_ckpt=matching_subfolders[0], -# create_tflops_callback=not with_peft, -# additional_args=additional_args, -# ) -# stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune - -# log_dir_ft = result_dir / "evo2" -# checkpoints_dir_ft = log_dir_ft / "checkpoints" -# tensorboard_dir_ft = log_dir_ft / "dev" - -# # Check if logs dir exists -# assert log_dir_ft.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps * global_batch_size}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders_finetune = [ -# p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders_finetune, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files_ft = list(tensorboard_dir_ft.rglob("events.out.tfevents*")) -# assert len(event_files_ft) == 1, f"No or multiple TensorBoard event files found under {tensorboard_dir_ft}" - -# val_losses_ft = extract_val_losses(stdout_finetune, val_steps) - -# # Check that each validation loss is less than or equal to the previous one -# for i in range(1, len(val_losses_ft)): -# assert val_losses_ft[i][1] <= val_losses_ft[i - 1][1], ( -# f"Validation loss increased at step {val_losses_ft[i][0]}: {val_losses_ft[i][1]} > {val_losses_ft[i - 1][1]}" -# ) - -# assert len(matching_subfolders_finetune) == 1, "Only one checkpoint subfolder should be found." - -# # With LoRA, test resuming from a saved LoRA checkpoint -# if with_peft: -# result_dir = tmp_path / "lora_finetune_resume" - -# # Resume from LoRA checkpoint -# command_resume_finetune = small_training_finetune_cmd( -# result_dir, -# max_steps=num_steps, -# val_check=val_steps, -# global_batch_size=global_batch_size, -# prev_ckpt=matching_subfolders[0], -# create_tflops_callback=False, -# additional_args=f"--lora-finetune --lora-checkpoint-path {matching_subfolders_finetune[0]} --lr 0.1 ", -# ) -# stdout_finetune: str = run_command_in_subprocess(command=command_resume_finetune, path=str(tmp_path)) - -# log_dir_ft = result_dir / "evo2" -# checkpoints_dir_ft = log_dir_ft / "checkpoints" -# tensorboard_dir_ft = log_dir_ft / "dev" - -# # Check if logs dir exists -# assert log_dir_ft.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist." - -# # Recursively search for files with tensorboard logger -# event_files_ft = list(tensorboard_dir_ft.rglob("events.out.tfevents*")) -# assert len(event_files_ft) == 1, f"No or multiple TensorBoard event files found under {tensorboard_dir_ft}" - -# val_losses_ft = extract_val_losses(stdout_finetune, val_steps) - -# # Check that each validation loss is less than or equal to the previous one -# for i in range(1, len(val_losses_ft)): -# assert val_losses_ft[i][1] <= val_losses_ft[i - 1][1], ( -# f"Validation loss increased at step {val_losses_ft[i][0]}: {val_losses_ft[i][1]} > {val_losses_ft[i - 1][1]}" -# ) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py index 3c4de4d918..944e899cf4 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py @@ -16,63 +16,438 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for Evo2 text generation (inference) using MBridge. -# import pytest -# import torch -# from bionemo.core.data.load import load -# from bionemo.evo2.run.infer import infer -# from bionemo.testing.megatron_parallel_state_utils import clean_parallel_state_context -# from bionemo.testing.torch import check_fp8_support - - -# RANDOM_SEED = 42 -# FIXME bring back these tests - -# @pytest.mark.parametrize("fast", [True, False]) -# def test_run_infer(fast: bool): -# # Create PTL trainer. -# tensor_parallel_size = 1 -# pipeline_model_parallel_size = 1 -# context_parallel_size = 1 -# temperature = 1.0 -# top_k = 0 -# top_p = 0.0 -# max_new_tokens = 1 - -# # Generation args. -# default_prompt = ( -# "|d__Bacteria;" -# + "p__Pseudomonadota;" -# + "c__Gammaproteobacteria;" -# + "o__Enterobacterales;" -# + "f__Enterobacteriaceae;" -# + "g__Escherichia;" -# + "s__Escherichia|" -# ) -# try: -# checkpoint_path = load("evo2/1b-8k:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e - -# is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device()) - -# with clean_parallel_state_context(): -# infer( -# prompt=default_prompt, -# ckpt_dir=checkpoint_path, -# temperature=temperature, -# top_k=top_k, -# top_p=top_p, -# max_new_tokens=max_new_tokens, -# tensor_parallel_size=tensor_parallel_size, -# pipeline_model_parallel_size=pipeline_model_parallel_size, -# context_parallel_size=context_parallel_size, -# vortex_style_fp8=is_fp8_supported, -# flash_decode=fast, -# ) +NOTE: Autoregressive generation tests may fail due to: +1. FP8 execution requires sequence dimensions divisible by 8/16 +2. The vortex flash_decode path needs additional integration work + +The core forward pass (predict.py) and HyenaInferenceContext are tested +in test_evo2.py which has working test_forward_manual and test_forward_ckpt_conversion. +""" + +import copy +import os +import subprocess + +import pytest +import torch + +from bionemo.evo2.models.evo2_provider import HyenaInferenceContext + +from ..utils import find_free_network_port + + +# Capture environment at import time (consistent with test_predict.py) +PRETEST_ENV = copy.deepcopy(os.environ) + +# Note: mbridge_checkpoint_path fixture is provided by conftest.py at session scope + + +def test_infer_runs(mbridge_checkpoint_path, tmp_path): + """Test that infer.py runs without errors.""" + output_file = tmp_path / "output.txt" + + # Use a longer DNA prompt to meet FP8 dimension requirements (divisible by 8) + # 64 characters should be safe + prompt = "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG" + open_port = find_free_network_port() + + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt", + prompt, + "--max-new-tokens", + "10", + "--output-file", + str(output_file), + "--temperature", + "1.0", # Non-zero temperature required by MCore + "--top-k", + "1", # Top-k=1 for greedy decoding + ] + + env = copy.deepcopy(PRETEST_ENV) + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=300, # 5 minutes + env=env, + ) + + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + assert output_file.exists(), "Output file was not created" + + # Check that output contains generated text + generated = output_file.read_text() + assert len(generated) > 0, "Generated text is empty" + + +@pytest.mark.parametrize("temperature", [0.5, 1.0]) +def test_infer_temperature(mbridge_checkpoint_path, tmp_path, temperature): + """Test that different temperatures produce output.""" + output_file = tmp_path / f"output_temp_{temperature}.txt" + # Use a longer prompt for FP8 compatibility + prompt = "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG" + open_port = find_free_network_port() + + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt", + prompt, + "--max-new-tokens", + "5", + "--temperature", + str(temperature), + "--output-file", + str(output_file), + ] + + env = copy.deepcopy(PRETEST_ENV) + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=300, # 5 minutes + env=env, + ) + + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + + +def test_infer_top_k(mbridge_checkpoint_path, tmp_path): + """Test top-k sampling.""" + output_file = tmp_path / "output_topk.txt" + # Use a longer prompt for FP8 compatibility + prompt = "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG" + open_port = find_free_network_port() + + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt", + prompt, + "--max-new-tokens", + "5", + "--top-k", + "4", # Only sample from top 4 tokens (A, C, G, T) + "--output-file", + str(output_file), + ] + + env = copy.deepcopy(PRETEST_ENV) + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=300, # 5 minutes + env=env, + ) + + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + + +def test_infer_phylogenetic_prompt(mbridge_checkpoint_path, tmp_path): + """Test generation with a phylogenetic lineage prompt. + + Evo2 is trained with phylogenetic tags, so generation should work + well when conditioned on these tags. Using a longer prompt for FP8. + """ + output_file = tmp_path / "output_phylo.txt" + + # Phylogenetic prompt (padded to be longer for FP8 compatibility) + prompt = ( + "|d__Bacteria;" + "p__Pseudomonadota;" + "c__Gammaproteobacteria;" + "o__Enterobacterales;" + "f__Enterobacteriaceae;" + "g__Escherichia;" + "s__Escherichia|" + ) + open_port = find_free_network_port() + + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt", + prompt, + "--max-new-tokens", + "20", + "--temperature", + "1.0", # Non-zero temperature required by MCore + "--top-k", + "1", # Top-k=1 for greedy decoding + "--output-file", + str(output_file), + ] + + env = copy.deepcopy(PRETEST_ENV) + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=300, # 5 minutes + env=env, + ) + + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + assert output_file.exists(), "Output file was not created" + + generated = output_file.read_text() + assert len(generated) > 0, "Generated text is empty" + + +# DNA prompts for reproducibility tests (from test_prompt.py) +PROMPT_1 = "GAATAGGAACAGCTCCGGTCTACAGCTCCCAGCGTGAGCGACGCAGAAGACGGTGATTTCTGCATTTCCATCTGAGGTACCGGGTTCATCTCACTAGGGAGTGCCAGACAGTGGGCGCAGGCCAGTGTGTGTGCGCACCGTGCGCGAGCCGAAGCAGGG" +PROMPT_2 = "GATCACAGGTCTATCACCCTATTAACCACTCACGGGAGCTCTCCATGCATTTGGTATTTTCGTCTGGGGGGTATGCACGCGATAGCATTGCGAGACGCTGGAGCCGGAGCACCCTATGTCGCAGTATCTGTCTTTGATTCCTGCCTCATCCTATTATTT" + + +def run_infer_subprocess( + mbridge_checkpoint_path, + prompt: str, + output_file, + max_new_tokens: int = 10, + temperature: float = 1.0, + top_k: int = 1, + seed: int = 42, +): + """Helper function to run inference as a subprocess. + + Args: + mbridge_checkpoint_path: Path to the MBridge checkpoint + prompt: Input prompt for the model + output_file: Path to write output + max_new_tokens: Maximum number of tokens to generate + temperature: Sampling temperature + top_k: Top-k sampling parameter (1 for greedy) + seed: Random seed for reproducibility + + Returns: + The generated text from the output file + """ + open_port = find_free_network_port() + + cmd = [ + "torchrun", + "--nproc_per_node", + "1", + "--nnodes", + "1", + "--master_port", + str(open_port), + "-m", + "bionemo.evo2.run.infer", + "--ckpt-dir", + str(mbridge_checkpoint_path), + "--prompt", + prompt, + "--max-new-tokens", + str(max_new_tokens), + "--output-file", + str(output_file), + "--temperature", + str(temperature), + "--top-k", + str(top_k), + "--seed", + str(seed), + ] + + env = copy.deepcopy(PRETEST_ENV) + + result = subprocess.run( + cmd, + check=False, + capture_output=True, + text=True, + timeout=300, # 5 minutes + env=env, + ) + + assert result.returncode == 0, f"infer command failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + assert output_file.exists(), "Output file was not created" + + return output_file.read_text() + + +def test_identical_prompts_should_be_identical(mbridge_checkpoint_path, tmp_path): + """Test that identical prompts produce identical sequences. + + With greedy decoding (top_k=1) and the same seed, identical prompts + should produce identical outputs. + """ + output_file_1 = tmp_path / "output_prompt1_run1.txt" + output_file_2 = tmp_path / "output_prompt1_run2.txt" + + # Run inference twice with the same prompt + generated_1 = run_infer_subprocess( + mbridge_checkpoint_path, + prompt=PROMPT_1, + output_file=output_file_1, + max_new_tokens=20, + temperature=1.0, + top_k=1, # Greedy decoding for determinism + seed=42, + ) + + generated_2 = run_infer_subprocess( + mbridge_checkpoint_path, + prompt=PROMPT_1, + output_file=output_file_2, + max_new_tokens=20, + temperature=1.0, + top_k=1, # Greedy decoding for determinism + seed=42, + ) + + assert len(generated_1) > 0, "First generation produced empty output" + assert len(generated_2) > 0, "Second generation produced empty output" + assert generated_1 == generated_2, ( + f"Identical prompts with same seed and greedy decoding produced different outputs:\n" + f"Run 1: {generated_1}\n" + f"Run 2: {generated_2}" + ) + + +def test_different_prompts_produce_different_outputs(mbridge_checkpoint_path, tmp_path): + """Test that different prompts produce different sequences. + + Different input prompts should produce different outputs, demonstrating + that the model is actually responding to the prompt content. + """ + output_file_1 = tmp_path / "output_prompt1.txt" + output_file_2 = tmp_path / "output_prompt2.txt" + + # Run inference with two different prompts + generated_1 = run_infer_subprocess( + mbridge_checkpoint_path, + prompt=PROMPT_1, + output_file=output_file_1, + max_new_tokens=20, + temperature=1.0, + top_k=1, # Greedy decoding + seed=42, + ) + + generated_2 = run_infer_subprocess( + mbridge_checkpoint_path, + prompt=PROMPT_2, + output_file=output_file_2, + max_new_tokens=20, + temperature=1.0, + top_k=1, # Greedy decoding + seed=42, + ) + + assert len(generated_1) > 0, "First generation produced empty output" + assert len(generated_2) > 0, "Second generation produced empty output" + + # The outputs should be different since the prompts are different + # We check that the generated portions (after the prompt) are not identical + assert generated_1 != generated_2, ( + f"Different prompts produced identical outputs:\n" + f"Prompt 1 output: {generated_1}\n" + f"Prompt 2 output: {generated_2}" + ) + + +class TestHyenaInferenceContext: + """Unit tests for the Hyena-specific inference context.""" + + def test_context_initialization(self): + """Test that HyenaInferenceContext can be initialized.""" + context = HyenaInferenceContext(max_batch_size=1, max_sequence_length=8192) + assert context is not None + assert context.max_batch_size == 1 + assert context.max_sequence_length == 8192 + + def test_context_reset(self): + """Test that context reset works without error.""" + context = HyenaInferenceContext(max_batch_size=1, max_sequence_length=8192) + # Add some fake filter state (simulating what hyena layers do) + context.filter_state_dict_layer_0 = {"key": torch.zeros(10)} + context.filter_state_dict_layer_1 = {"key": torch.ones(10)} + + # Verify the state was added + assert hasattr(context, "filter_state_dict_layer_0") + assert hasattr(context, "filter_state_dict_layer_1") + + # Reset should remove all filter_state_dict attributes + context.reset() + + assert not hasattr(context, "filter_state_dict_layer_0") + assert not hasattr(context, "filter_state_dict_layer_1") + + def test_context_materialize_logits_setting(self): + """Test that materialize_only_last_token_logits can be configured.""" + context = HyenaInferenceContext(max_batch_size=1, max_sequence_length=8192) + + # Default should be True for efficiency + # We can set it to False if we need full sequence logits + context.materialize_only_last_token_logits = False + assert context.materialize_only_last_token_logits is False + + context.materialize_only_last_token_logits = True + assert context.materialize_only_last_token_logits is True + + def test_context_multiple_batches(self): + """Test context with different batch sizes.""" + for batch_size in [1, 2, 4]: + context = HyenaInferenceContext(max_batch_size=batch_size, max_sequence_length=4096) + assert context.max_batch_size == batch_size + context.reset() # Should not error + + def test_context_different_sequence_lengths(self): + """Test context with different max sequence lengths.""" + for seq_len in [1024, 8192, 16384]: + context = HyenaInferenceContext(max_batch_size=1, max_sequence_length=seq_len) + assert context.max_sequence_length == seq_len + context.reset() diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_inference.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_inference.py deleted file mode 100644 index 5f48d20888..0000000000 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_inference.py +++ /dev/null @@ -1,104 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# FIXME bring back these tests -# import nemo.lightning as nl -# import torch -# from bionemo.core.data.load import load -# from bionemo.testing.megatron_parallel_state_utils import clean_parallel_state_context -# from megatron.core.inference.common_inference_params import CommonInferenceParams -# from nemo.collections.llm import generate - - -# RANDOM_SEED = 42 - - -# def test_infer_model_generates_expected_single_token_output(): -# # Create PTL trainer. -# TENSOR_PARALLEL_SIZE = 1 -# PIPELINE_MODEL_PARALLEL_SIZE = 1 -# CONTEXT_PARALLEL_SIZE = 1 -# NUM_GPUS = 1 -# NUM_NODES = 1 - -# strategy = nl.MegatronStrategy( -# tensor_model_parallel_size=TENSOR_PARALLEL_SIZE, -# pipeline_model_parallel_size=PIPELINE_MODEL_PARALLEL_SIZE, -# context_parallel_size=CONTEXT_PARALLEL_SIZE, -# pipeline_dtype=torch.bfloat16, -# ckpt_load_optimizer=False, # Needs to be false for a normal model checkpoint. -# ckpt_save_optimizer=False, -# ckpt_async_save=False, -# save_ckpt_format="torch_dist", -# ckpt_load_strictness="log_all", -# ) -# trainer = nl.Trainer( -# accelerator="gpu", -# num_nodes=NUM_NODES, -# devices=NUM_GPUS, -# strategy=strategy, -# log_every_n_steps=1, -# limit_val_batches=10, -# num_sanity_val_steps=0, -# plugins=nl.MegatronMixedPrecision( -# precision="bf16-mixed", -# params_dtype=torch.bfloat16, -# ), -# ) - -# prompt = ( -# "|d__Bacteria;" -# + "p__Pseudomonadota;" -# + "c__Gammaproteobacteria;" -# + "o__Enterobacterales;" -# + "f__Enterobacteriaceae;" -# + "g__Escherichia;" -# + "s__Escherichia|" -# ) -# temperature = 1.0 -# top_k = 0 -# top_p = 0.0 -# max_new_tokens = 1 -# try: -# checkpoint_path = load("evo2/1b-8k:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e -# with clean_parallel_state_context(): -# results = generate( -# path=checkpoint_path, -# prompts=[prompt], -# trainer=trainer, -# inference_params=CommonInferenceParams( -# temperature, -# top_k, -# top_p, -# return_log_probs=False, -# num_tokens_to_generate=max_new_tokens, -# ), -# random_seed=RANDOM_SEED, -# text_only=True, -# ) - -# assert isinstance(results, list) -# assert results == ["T"] diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py index 6121f660b2..e56b2f088e 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py @@ -16,498 +16,651 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FIXME bring back these tests -# import glob -# import json -# import os -# import subprocess -# import sys -# import tempfile -# from pathlib import Path - -# # import lightning as pl -# import pytest -# import torch -# from bionemo.core.data.load import load -# from bionemo.llm.lightning import batch_collator -# from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file -# from bionemo.testing.subprocess_utils import run_command_in_subprocess -# from bionemo.testing.torch import check_fp8_support - -# # FIXME copy this out of lightning. This is a useful utility. -# # from lightning.fabric.plugins.environments.lightning import find_free_network_port -# from .common import predict_cmd, small_training_finetune_cmd - - -# def find_free_network_port(*args, **kwargs): -# raise NotImplementedError("FIXME find_free_network_port is not implemented Find it in megatron bridge") - - -# def is_a6000_gpu() -> bool: -# # Check if any of the visible GPUs is an A6000 -# for i in range(torch.cuda.device_count()): -# device_name = torch.cuda.get_device_name(i) -# if "A6000" in device_name: -# return True -# return False - - -# @pytest.fixture(scope="module") -# def checkpoint_1b_8k_bf16_path() -> Path: -# try: -# checkpoint_path = load("evo2/1b-8k-bf16:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e -# yield checkpoint_path - - -# @pytest.fixture(scope="module") -# def checkpoint_7b_1m_path() -> Path: -# try: -# checkpoint_path = load("evo2/7b-1m:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e -# yield checkpoint_path - - -# # FIXME rewrite this test once we have megatron bridge running. We may not need callbacks but if we do rewrite that. -# # def test_predict_does_not_instantiate_optimizer(tmp_path: Path, checkpoint_1b_8k_bf16_path: Path): -# # output_dir = tmp_path / "test_output" -# # fasta_file_path = tmp_path / "test.fasta" -# # create_fasta_file( -# # fasta_file_path, -# # 1, -# # sequence_lengths=[512], -# # repeating_dna_pattern=ALU_SEQUENCE, -# # ) - -# # class AssertNoOptimizerCallback(Callback): -# # def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0): -# # assert not trainer.optimizers, ( -# # f"Optimizer should not be instantiated for prediction, got {trainer.optimizers}" -# # ) -# # trainer_model_opt = getattr(trainer.model, "optim", None) -# # assert trainer_model_opt is None or not trainer_model_opt.state_dict(), ( -# # f"Model optimizer found, got {trainer_model_opt} with state {trainer_model_opt.state_dict()}" -# # ) - -# # with clean_parallel_state_context(): -# # predict( -# # fasta_path=fasta_file_path, -# # ckpt_dir=str(checkpoint_1b_8k_bf16_path), -# # output_dir=output_dir, -# # tensor_parallel_size=1, -# # pipeline_model_parallel_size=1, -# # context_parallel_size=1, -# # num_nodes=1, -# # devices=1, -# # model_size="1b", -# # ckpt_format="torch_dist", -# # fp8=False, -# # full_fp8=False, -# # work_dir=tmp_path, -# # micro_batch_size=1, -# # output_log_prob_seqs=True, -# # log_prob_collapse_option="mean", -# # write_interval="epoch", -# # prepend_bos=False, -# # no_sequence_parallel=False, -# # hybrid_override_pattern="SDH*", -# # num_layers=4, -# # seq_len_interpolation_factor=None, -# # files_per_subdir=None, -# # lora_checkpoint_path=None, -# # extra_callbacks=[ -# # AssertNoOptimizerCallback(), -# # ], # use this for making testing the loop easier. -# # ) - - -# @pytest.mark.parametrize( -# "ddp,pp,wi", -# [ -# pytest.param(1, 1, "epoch", id="ddp=1,pp=1,wi=epoch"), -# pytest.param(2, 1, "epoch", id="ddp=2,pp=1,wi=epoch"), -# pytest.param(2, 1, "batch", id="ddp=2,pp=1,wi=batch"), -# pytest.param( -# 1, -# 2, -# "epoch", -# id="ddp=1,pp=2,wi=epoch", -# marks=pytest.mark.skip("Pipeline parallelism test currently hangs."), -# ), -# ], -# ) -# def test_predict_evo2_runs( -# tmp_path, -# ddp: int, -# pp: int, -# wi: str, -# checkpoint_1b_8k_bf16_path: Path, -# num_sequences: int = 5, -# target_sequence_lengths: list[int] = [3149, 3140, 1024, 3148, 3147], -# ): -# """ -# This test runs the `predict_evo2` command with mock data in a temporary directory. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. - -# Since it's the full output this does not support CP, so we only test with TP=1. We also want coverage of the -# case where the sequence lengths are different and not necessarily divisible by CP. -# """ -# world_size = ddp * pp -# if world_size > torch.cuda.device_count(): -# pytest.skip(f"World size {world_size} is less than the number of GPUs {torch.cuda.device_count()}") -# fasta_file_path = tmp_path / "test.fasta" -# create_fasta_file( -# fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE -# ) -# # Create a mock data directory. -# # a local copy of the environment -# env = dict(**os.environ) -# if is_a6000_gpu(): -# # Fix hanging issue on A6000 GPUs with multi-gpu tests -# env["NCCL_P2P_DISABLE"] = "1" - -# # Build the command string. -# # Note: The command assumes that `train_evo2` is in your PATH. -# output_dir = tmp_path / "test_output" -# command = ( -# f"torchrun --nproc_per_node {world_size} --nnodes 1 --no-python " -# f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_1b_8k_bf16_path} " -# f"--output-dir {output_dir} --model-size 1b " -# f"--micro-batch-size 3 --write-interval {wi} " -# f"--pipeline-model-parallel-size {pp} --num-nodes 1 --devices {world_size}" -# ) - -# # Run the command in a subshell, using the temporary directory as the current working directory. -# open_port = find_free_network_port() -# env["MASTER_PORT"] = str(open_port) -# result = subprocess.run( -# command, -# check=False, -# shell=True, # Use the shell to interpret wildcards (e.g. SDH*) -# cwd=tmp_path, # Run in the temporary directory -# capture_output=True, # Capture stdout and stderr for debugging -# env=env, # Pass in the env where we override the master port. -# text=True, # Decode output as text -# ) - -# # For debugging purposes, print the output if the test fails. -# if result.returncode != 0: -# sys.stderr.write("STDOUT:\n" + result.stdout + "\n") -# sys.stderr.write("STDERR:\n" + result.stderr + "\n") - -# # Assert that the command completed successfully. -# assert result.returncode == 0, "train_evo2 command failed." - -# # Assert that the output directory was created. -# pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt")) -# if wi == "batch": -# assert len(pred_files) == 2, f"Expected 2 prediction file (for this test), got {len(pred_files)}" -# else: -# assert len(pred_files) == ddp, f"Expected {ddp} prediction file (for this test), got {len(pred_files)}" -# with open(output_dir / "seq_idx_map.json", "r") as f: -# seq_idx_map = json.load( -# f -# ) # This gives us the mapping from the sequence names to the indices in the predictions. -# preds = [torch.load(pf) for pf in pred_files] -# preds = batch_collator( -# [p for p in preds if p is not None], -# batch_dim_key_defaults={"token_logits": 0}, -# seq_dim_key_defaults={"token_logits": 1}, -# ) -# assert isinstance(preds, dict) -# assert "token_logits" in preds -# assert "pad_mask" in preds -# assert "seq_idx" in preds - -# assert len(preds["token_logits"]) == len(preds["pad_mask"]) == len(preds["seq_idx"]) == num_sequences -# assert len(seq_idx_map) == num_sequences -# for original_idx, pad_mask, token_logits in zip(preds["seq_idx"], preds["pad_mask"], preds["token_logits"]): -# # seq_idx is not sorted necessarily, so use the saved "seq_idx" to determine the original order. -# expected_len = target_sequence_lengths[original_idx] -# assert pad_mask.sum() == expected_len -# assert token_logits.shape == (max(target_sequence_lengths), 512) - - -# @pytest.fixture(scope="module") -# def baseline_predictions_7b_1m_results( -# checkpoint_7b_1m_path: Path, -# num_sequences: int = 5, -# target_sequence_lengths: list[int] = [2048, 2048, 2048, 2048, 2048], -# ) -> dict[int, float]: -# with tempfile.TemporaryDirectory() as tmp_dir: -# tmp_path = Path(tmp_dir) -# fasta_file_path = tmp_path / "test.fasta" -# create_fasta_file( -# fasta_file_path, -# num_sequences, -# sequence_lengths=target_sequence_lengths, -# repeating_dna_pattern=ALU_SEQUENCE, -# ) -# output_dir = tmp_path / "test_output" -# command = ( -# f"torchrun --nproc_per_node 1 --nnodes 1 --no-python " -# f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_7b_1m_path} " -# f"--num-layers 4 --hybrid-override-pattern SDH* " # subset of layers for testing -# # FIXME changing batch size from 3 to 1 required dropping rel=1e-6 to rel=1e-3 -# # even when model parallelism is not used. This should be investigated. -# f"--micro-batch-size 3 " -# f"--output-dir {output_dir} --model-size 7b_arc_longcontext " -# f"--num-nodes 1 --write-interval epoch " -# "--output-log-prob-seqs --log-prob-collapse-option sum" -# ) -# # Create a mock data directory. -# # a local copy of the environment -# env = dict(**os.environ) -# open_port = find_free_network_port() -# env["MASTER_PORT"] = str(open_port) -# result = subprocess.run( -# command, -# check=False, -# shell=True, # Use the shell to interpret wildcards (e.g. SDH*) -# cwd=tmp_path, # Run in the temporary directory -# capture_output=True, # Capture stdout and stderr for debugging -# env=env, # Pass in the env where we override the master port. -# text=True, # Decode output as text -# ) -# assert result.returncode == 0, "predict_evo2 command failed." -# # Assert that the output directory was created. -# pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt")) -# preds = [torch.load(pf) for pf in pred_files] -# preds = batch_collator( -# [p for p in preds if p is not None], -# ) -# yield dict(zip([i.item() for i in preds["seq_idx"]], [p.item() for p in preds["log_probs_seqs"]])) - - -# @pytest.mark.parametrize( -# "ddp,cp,pp,tp,fp8,wi", -# [ -# pytest.param(1, 1, 1, 1, False, "epoch", id="ddp=1,cp=1,pp=1,tp=1,fp8=False,wi=epoch"), -# pytest.param(2, 1, 1, 1, False, "epoch", id="ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=epoch"), -# pytest.param( -# 2, 1, 1, 1, False, "batch", id="ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=batch" -# ), # simulate a large prediction run with dp parallelism -# pytest.param(1, 2, 1, 1, False, "epoch", id="ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=epoch"), -# pytest.param(1, 2, 1, 1, False, "batch", id="ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=batch"), -# pytest.param( -# 1, -# 1, -# 2, -# 1, -# False, -# "epoch", -# id="ddp=1,cp=1,pp=2,tp=1,fp8=False,wi=epoch", -# marks=pytest.mark.skip("Pipeline parallelism test currently hangs."), -# ), -# pytest.param( -# 1, 1, 1, 2, True, "epoch", id="ddp=1,cp=1,pp=1,tp=2,fp8=True,wi=epoch" -# ), # Cover case where FP8 was not supported with TP=2 -# pytest.param(1, 1, 1, 2, False, "epoch", id="ddp=1,cp=1,pp=1,tp=2,fp8=False,wi=epoch"), -# ], -# ids=lambda x: f"ddp={x[0]},cp={x[1]},pp={x[2]},tp={x[3]},fp8={x[4]},wi={x[5]}", -# ) -# def test_predict_evo2_equivalent_with_log_probs( -# tmp_path, -# ddp: int, -# cp: int, -# pp: int, -# tp: int, -# fp8: bool, -# wi: str, -# checkpoint_7b_1m_path: Path, -# baseline_predictions_7b_1m_results: dict[int, float], -# num_sequences: int = 5, -# target_sequence_lengths: list[int] = [2048, 2048, 2048, 2048, 2048], -# ): -# """ -# This test runs the `predict_evo2` command with mock data in a temporary directory. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. - -# For this test, we want coverage of CP, so we make sure sequence lengths are all the same and divisible by CP. - -# The other thing this test does is check that the log probabilities are equivalent to the baseline predictions -# without model parallelism. -# """ - -# world_size = ddp * cp * pp * tp -# mp_size = cp * pp * tp -# if world_size > torch.cuda.device_count(): -# pytest.skip(f"World size {world_size} is less than the number of GPUs {torch.cuda.device_count()}") -# is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device()) -# if not is_fp8_supported and fp8: -# pytest.skip("FP8 is not supported on this GPU.") - -# fasta_file_path = tmp_path / "test.fasta" -# create_fasta_file( -# fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE -# ) -# # Create a mock data directory. -# # a local copy of the environment -# env = dict(**os.environ) -# if is_a6000_gpu(): -# # Fix hanging issue on A6000 GPUs with multi-gpu tests -# env["NCCL_P2P_DISABLE"] = "1" - -# fp8_option = "--fp8" if fp8 else "" -# # Build the command string. -# # Note: The command assumes that `train_evo2` is in your PATH. -# output_dir = tmp_path / "test_output" -# command = ( -# f"torchrun --nproc_per_node {world_size} --nnodes 1 --no-python " -# f"predict_evo2 --fasta {fasta_file_path} --ckpt-dir {checkpoint_7b_1m_path} " -# f"--micro-batch-size 3 --write-interval {wi} " -# f"--num-layers 4 --hybrid-override-pattern SDH* " # subset of layers for testing -# f"--output-dir {output_dir} --model-size 7b_arc_longcontext --tensor-parallel-size {tp} {fp8_option} " -# f"--pipeline-model-parallel-size {pp} --context-parallel-size {cp} --num-nodes 1 --devices {world_size} " -# "--output-log-prob-seqs --log-prob-collapse-option sum" -# ) - -# # Run the command in a subshell, using the temporary directory as the current working directory. -# open_port = find_free_network_port() -# env["MASTER_PORT"] = str(open_port) -# result = subprocess.run( -# command, -# check=False, -# shell=True, # Use the shell to interpret wildcards (e.g. SDH*) -# cwd=tmp_path, # Run in the temporary directory -# capture_output=True, # Capture stdout and stderr for debugging -# env=env, # Pass in the env where we override the master port. -# text=True, # Decode output as text -# ) - -# # For debugging purposes, print the output if the test fails. -# if result.returncode != 0: -# sys.stderr.write("STDOUT:\n" + result.stdout + "\n") -# sys.stderr.write("STDERR:\n" + result.stderr + "\n") - -# # Assert that the command completed successfully. -# assert result.returncode == 0, "train_evo2 command failed." - -# # Assert that the output directory was created. -# pred_files = glob.glob(os.path.join(output_dir, "predictions__rank_*.pt")) -# if wi == "batch": -# assert len(pred_files) == 2, f"Expected 2 prediction file (for this test), got {len(pred_files)}" -# else: -# assert len(pred_files) == ddp, f"Expected {ddp} prediction file (for this test), got {len(pred_files)}" -# with open(output_dir / "seq_idx_map.json", "r") as f: -# seq_idx_map = json.load( -# f -# ) # This gives us the mapping from the sequence names to the indices in the predictions. -# preds = [torch.load(pf) for pf in pred_files] -# preds = batch_collator( -# [p for p in preds if p is not None], -# ) -# assert isinstance(preds, dict) -# assert "log_probs_seqs" in preds -# assert "seq_idx" in preds -# assert len(preds["log_probs_seqs"]) == len(preds["seq_idx"]) == num_sequences -# assert len(seq_idx_map) == num_sequences -# for original_idx, log_probs in zip(preds["seq_idx"], preds["log_probs_seqs"]): -# if mp_size > 1 and not fp8: -# # FIXME changing batch size so it doesn't match also required dropping rel=1e-6 to rel=1e-3. -# # This should be investigated. -# rel = 1e-3 -# elif fp8: -# # NOTE: This is hand-tuned on a b300 to pass for now as of 9/10/2025. -# rel = 1e-2 -# else: -# rel = 1e-6 -# assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel) - - +"""Tests for Evo2 prediction (inference) workflow using Megatron Bridge.""" + +import copy +import glob +import json +import os +import re +import shlex +import subprocess +from pathlib import Path + +import pytest +import torch + +from bionemo.evo2.data.test_utils.create_fasta_file import ALU_SEQUENCE, create_fasta_file +from bionemo.evo2.run.predict import batch_collator + +from ..utils import check_fp8_support, find_free_network_port, is_a6000_gpu + + +# Do this at collection time before we run any tests. +PRETEST_ENV = copy.deepcopy(os.environ) + + +@pytest.fixture(scope="module") +def mbridge_checkpoint_1b_8k_bf16_path(mbridge_checkpoint_1b_8k_bf16) -> Path: + """Module-scoped alias for the session-scoped 1b-8k-bf16 checkpoint. + + The actual checkpoint conversion is done once per session in conftest.py via + the mbridge_checkpoint_1b_8k_bf16 fixture, and shared across all test files. + + Returns: + Path to the MBridge checkpoint iteration directory (e.g., .../iter_0000001) + """ + return mbridge_checkpoint_1b_8k_bf16 + + +@pytest.mark.parametrize( + "ddp,pp,wi", + [ + pytest.param(1, 1, "epoch", id="ddp=1,pp=1,wi=epoch"), + pytest.param(2, 1, "epoch", id="ddp=2,pp=1,wi=epoch"), + pytest.param(2, 1, "batch", id="ddp=2,pp=1,wi=batch"), + pytest.param( + 1, + 2, + "epoch", + id="ddp=1,pp=2,wi=epoch", + marks=pytest.mark.skip("Pipeline parallelism test currently hangs."), + ), + ], +) +@pytest.mark.slow +def test_predict_evo2_runs( + tmp_path, + ddp: int, + pp: int, + wi: str, + mbridge_checkpoint_1b_8k_bf16_path: Path, + num_sequences: int = 5, + target_sequence_lengths: list[int] | None = None, +): + """Test that the predict_evo2 command runs successfully with MBridge checkpoints. + + This test runs the `predict_evo2` command with mock data in a temporary directory. + It uses the temporary directory provided by pytest as the working directory. + The command is run in a subshell, and we assert that it returns an exit code of 0. + + Since it's the full output this does not support CP, so we only test with TP=1. We also want coverage of the + case where the sequence lengths are different and not necessarily divisible by CP. + """ + if target_sequence_lengths is None: + target_sequence_lengths = [3149, 3140, 1024, 3148, 3147] + + world_size = ddp * pp + if world_size > torch.cuda.device_count(): + pytest.skip(f"World size {world_size} is greater than the number of GPUs {torch.cuda.device_count()}") + + fasta_file_path = tmp_path / "test.fasta" + create_fasta_file( + fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE + ) + + # Create a local copy of the environment + env = copy.deepcopy(PRETEST_ENV) + if is_a6000_gpu(): + # Fix hanging issue on A6000 GPUs with multi-gpu tests + env["NCCL_P2P_DISABLE"] = "1" + + # Build the command string + output_dir = tmp_path / "test_output" + open_port = find_free_network_port() + command = ( + f"torchrun --nproc_per_node {world_size} --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path} " + f"--output-dir {output_dir} " + f"--micro-batch-size 3 --write-interval {wi} " + f"--pipeline-model-parallel-size {pp} --num-nodes 1 --devices {world_size}" + ) + + # Run the command in a subshell + cmd_parts = shlex.split(command) + result = subprocess.run( + cmd_parts, + check=False, + cwd=tmp_path, + capture_output=True, + env=env, + text=True, + ) + + # For debugging purposes, print the output if the test fails + if result.returncode != 0: + print("STDOUT:\n" + result.stdout) + print("STDERR:\n" + result.stderr) + + # Assert that the command completed successfully + assert result.returncode == 0, f"predict_evo2 command failed with code {result.returncode}" + + # Assert that the output directory was created and contains predictions + # With DDP, each DP rank produces its own file with dp_rank in the filename + # File naming convention: + # Batch mode: predictions__rank_{global_rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt + # Epoch mode: predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt + if wi == "batch": + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*__batch_*.pt"))) + # With batch write interval, we expect multiple files (batches * dp_ranks) + assert len(pred_files) >= ddp, f"Expected at least {ddp} prediction files, got {len(pred_files)}" + else: + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*.pt"))) + # With epoch write interval, we expect one file per DP rank + assert len(pred_files) == ddp, f"Expected {ddp} prediction files (one per DP rank), got {len(pred_files)}" + + # Check sequence index map exists + seq_idx_map_path = output_dir / "seq_idx_map.json" + assert seq_idx_map_path.exists(), f"seq_idx_map.json not found at {seq_idx_map_path}" + + with open(seq_idx_map_path) as f: + seq_idx_map = json.load(f) + + # Load and collate predictions + # Note: predict.py outputs are all batch-first (batch_dim=0), seq-second (seq_dim=1) + preds = [torch.load(pf) for pf in pred_files] + preds = batch_collator( + [p for p in preds if p is not None], + batch_dim=0, + seq_dim=1, + batch_dim_key_defaults={}, + seq_dim_key_defaults={}, + ) + assert isinstance(preds, dict) + assert "token_logits" in preds + assert "pad_mask" in preds + assert "seq_idx" in preds + + assert len(preds["token_logits"]) == len(preds["pad_mask"]) == len(preds["seq_idx"]) == num_sequences + assert len(seq_idx_map) == num_sequences + + for original_idx, pad_mask, token_logits in zip(preds["seq_idx"], preds["pad_mask"], preds["token_logits"]): + # seq_idx is not sorted necessarily, so use the saved "seq_idx" to determine the original order + expected_len = target_sequence_lengths[original_idx] + assert pad_mask.sum() == expected_len + # Vocab size should be 512 for the nucleotide tokenizer + assert token_logits.shape[-1] == 512 + + +@pytest.fixture(scope="module") +def mbridge_checkpoint_7b_1m_path(tmp_path_factory) -> Path: + """Create or load a MBridge checkpoint for 7b-1m model testing.""" + from bionemo.core.data.load import load + from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH_512 + from bionemo.evo2.utils.checkpoint.nemo2_to_mbridge import run_nemo2_to_mbridge + + try: + nemo2_checkpoint_path = load("evo2/7b-1m:1.0") + except ValueError as e: + if e.args[0].endswith("does not have an NGC URL."): + pytest.skip( + "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " + "one or more files are missing from ngc." + ) + else: + raise e + + # Create a temporary directory for the MBridge checkpoint + tmp_dir = tmp_path_factory.mktemp("mbridge_ckpt_7b") + # Note: run_nemo2_to_mbridge uses full model config from model_size + # For testing we use the full 7b model but with shorter sequences + mbridge_ckpt_dir = run_nemo2_to_mbridge( + nemo2_ckpt_dir=nemo2_checkpoint_path, + tokenizer_path=DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + mbridge_ckpt_dir=tmp_dir / "mbridge_checkpoint", + model_size="7b_arc_longcontext", + seq_length=8192, # Use shorter seq length for tests + mixed_precision_recipe="bf16_mixed", + vortex_style_fp8=False, + ) + return mbridge_ckpt_dir / "iter_0000001" + + +@pytest.fixture(scope="module") +def baseline_predictions_7b_1m_results( + mbridge_checkpoint_7b_1m_path: Path, + tmp_path_factory, + num_sequences: int = 5, + target_sequence_lengths: list[int] | None = None, +) -> dict[int, float]: + """Generate baseline predictions for 7b-1m model comparison.""" + if target_sequence_lengths is None: + target_sequence_lengths = [2048, 2048, 2048, 2048, 2048] + + tmp_path = tmp_path_factory.mktemp("baseline_preds") + fasta_file_path = tmp_path / "test.fasta" + create_fasta_file( + fasta_file_path, + num_sequences, + sequence_lengths=target_sequence_lengths, + repeating_dna_pattern=ALU_SEQUENCE, + ) + output_dir = tmp_path / "test_output" + open_port = find_free_network_port() + command = ( + f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {mbridge_checkpoint_7b_1m_path} " + f"--micro-batch-size 3 " + f"--output-dir {output_dir} " + f"--num-nodes 1 --write-interval epoch " + "--output-log-prob-seqs --log-prob-collapse-option sum" + ) + + env = copy.deepcopy(PRETEST_ENV) + cmd_parts = shlex.split(command) + result = subprocess.run( + cmd_parts, + check=False, + cwd=tmp_path, + capture_output=True, + env=env, + text=True, + ) + assert result.returncode == 0, f"predict_evo2 command failed: {result.stderr}" + + # Use the updated glob pattern matching the new naming convention + # Epoch mode: predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt + pred_files = glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*.pt")) + preds = [torch.load(pf) for pf in pred_files] + preds = batch_collator( + [p for p in preds if p is not None], + batch_dim=0, + seq_dim=1, + batch_dim_key_defaults={}, + seq_dim_key_defaults={}, + ) + return dict(zip([i.item() for i in preds["seq_idx"]], [p.item() for p in preds["log_probs_seqs"]])) + + +@pytest.mark.parametrize( + "ddp,cp,pp,tp,fp8,wi", + [ + pytest.param(1, 1, 1, 1, False, "epoch", id="ddp=1,cp=1,pp=1,tp=1,fp8=False,wi=epoch"), + pytest.param(2, 1, 1, 1, False, "epoch", id="ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=epoch"), + pytest.param( + 2, 1, 1, 1, False, "batch", id="ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=batch" + ), # simulate a large prediction run with dp parallelism + pytest.param(1, 2, 1, 1, False, "epoch", id="ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=epoch"), + pytest.param(1, 2, 1, 1, False, "batch", id="ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=batch"), + pytest.param( + 1, + 1, + 2, + 1, + False, + "epoch", + id="ddp=1,cp=1,pp=2,tp=1,fp8=False,wi=epoch", + marks=pytest.mark.skip("Pipeline parallelism test currently hangs."), + ), + pytest.param( + 1, 1, 1, 2, True, "epoch", id="ddp=1,cp=1,pp=1,tp=2,fp8=True,wi=epoch" + ), # Cover case where FP8 was not supported with TP=2 + pytest.param(1, 1, 1, 2, False, "epoch", id="ddp=1,cp=1,pp=1,tp=2,fp8=False,wi=epoch"), + ], +) +@pytest.mark.slow +@pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip 7b-1m checkpoint tests in CI due to disk space") +def test_predict_evo2_equivalent_with_log_probs( + tmp_path, + ddp: int, + cp: int, + pp: int, + tp: int, + fp8: bool, + wi: str, + mbridge_checkpoint_7b_1m_path: Path, + baseline_predictions_7b_1m_results: dict[int, float], + num_sequences: int = 5, + target_sequence_lengths: list[int] | None = None, +): + """Test that predict_evo2 produces equivalent log probabilities with different parallelism settings. + + This test runs the `predict_evo2` command with mock data in a temporary directory. + It uses the temporary directory provided by pytest as the working directory. + The command is run in a subshell, and we assert that it returns an exit code of 0. + + For this test, we want coverage of CP, so we make sure sequence lengths are all the same and divisible by CP. + + The other thing this test does is check that the log probabilities are equivalent to the baseline predictions + without model parallelism. + """ + if target_sequence_lengths is None: + target_sequence_lengths = [2048, 2048, 2048, 2048, 2048] + + world_size = ddp * cp * pp * tp + mp_size = cp * pp * tp + if world_size > torch.cuda.device_count(): + pytest.skip(f"World size {world_size} is greater than the number of GPUs {torch.cuda.device_count()}") + is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device()) + if not is_fp8_supported and fp8: + pytest.skip("FP8 is not supported on this GPU.") + + fasta_file_path = tmp_path / "test.fasta" + create_fasta_file( + fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE + ) + + # Create a local copy of the environment + env = copy.deepcopy(PRETEST_ENV) + if is_a6000_gpu(): + # Fix hanging issue on A6000 GPUs with multi-gpu tests + env["NCCL_P2P_DISABLE"] = "1" + + fp8_option = "--mixed-precision-recipe bf16_with_fp8_current_scaling_mixed" if fp8 else "" + output_dir = tmp_path / "test_output" + open_port = find_free_network_port() + command = ( + f"torchrun --nproc_per_node {world_size} --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {mbridge_checkpoint_7b_1m_path} " + f"--micro-batch-size 3 --write-interval {wi} " + f"--output-dir {output_dir} --tensor-parallel-size {tp} {fp8_option} " + f"--pipeline-model-parallel-size {pp} --context-parallel-size {cp} --num-nodes 1 --devices {world_size} " + "--output-log-prob-seqs --log-prob-collapse-option sum" + ) + + cmd_parts = shlex.split(command) + result = subprocess.run( + cmd_parts, + check=False, + cwd=tmp_path, + capture_output=True, + env=env, + text=True, + ) + + # For debugging purposes, print the output if the test fails + if result.returncode != 0: + print("STDOUT:\n" + result.stdout) + print("STDERR:\n" + result.stderr) + + # Assert that the command completed successfully + assert result.returncode == 0, f"predict_evo2 command failed with code {result.returncode}" + + # Assert that the output directory was created + # With DDP, each DP rank produces its own file with dp_rank in the filename + # File naming convention: + # Batch mode: predictions__rank_{global_rank}__dp_rank_{dp_rank}__batch_{batch_idx}.pt + # Epoch mode: predictions__rank_{global_rank}__dp_rank_{dp_rank}.pt + if wi == "batch": + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*__batch_*.pt"))) + # With batch write interval, we expect multiple files (batches * dp_ranks) + assert len(pred_files) >= ddp, f"Expected at least {ddp} prediction files, got {len(pred_files)}" + else: + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*.pt"))) + # With epoch write interval, we expect one file per DP rank + assert len(pred_files) == ddp, f"Expected {ddp} prediction files (one per DP rank), got {len(pred_files)}" + + with open(output_dir / "seq_idx_map.json") as f: + seq_idx_map = json.load(f) + + # Load and collate predictions from all DP ranks + preds = [torch.load(pf) for pf in pred_files] + preds = batch_collator( + [p for p in preds if p is not None], + batch_dim=0, + seq_dim=1, + batch_dim_key_defaults={}, + seq_dim_key_defaults={}, + ) + assert isinstance(preds, dict) + assert "log_probs_seqs" in preds + assert "seq_idx" in preds + assert len(preds["log_probs_seqs"]) == len(preds["seq_idx"]) == num_sequences + assert len(seq_idx_map) == num_sequences + + for original_idx, log_probs in zip(preds["seq_idx"], preds["log_probs_seqs"]): + if mp_size > 1 and not fp8: + # FIXME changing batch size so it doesn't match also required dropping rel=1e-6 to rel=1e-3. + # This should be investigated. TP=2 on some GPUs needs even more tolerance. + rel = 2e-3 + elif fp8: + # NOTE: This is hand-tuned on a b300 to pass for now as of 9/10/2025. + rel = 1e-2 + else: + rel = 1e-6 + assert log_probs.item() == pytest.approx(baseline_predictions_7b_1m_results[original_idx.item()], rel=rel) + + +# Note: The PEFT/LoRA test is commented out as it requires training infrastructure and LoRA support +# which may need additional updates for the Megatron Bridge API # @pytest.mark.timeout(512) # @pytest.mark.slow # def test_different_results_with_without_peft(tmp_path): -# try: -# base_model_checkpoint_path = load("evo2/1b-8k:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e - -# num_steps = 2 - -# result_dir = tmp_path / "lora_finetune" - -# # Note: The command assumes that `train_evo2` is in your PATH. -# command_finetune = small_training_finetune_cmd( -# result_dir, -# max_steps=num_steps, -# val_check=num_steps, -# prev_ckpt=base_model_checkpoint_path, -# create_tflops_callback=False, -# additional_args="--lora-finetune", -# ) -# stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune -# assert "Loading adapters from" not in stdout_finetune - -# # Check if checkpoints dir exists -# checkpoints_dir = result_dir / "evo2" / "checkpoints" -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." - -# # Create a sample FASTA file to run predictions -# fasta_file_path = tmp_path / "test.fasta" -# create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE) - -# result_dir_original = tmp_path / "results_original" -# cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_original, fasta_file_path) -# stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path)) - -# # Assert that the output directory was created. -# pred_files_original = glob.glob(str(result_dir_original / "predictions__rank_*.pt")) -# assert len(pred_files_original) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_original)}" - -# # Find the checkpoint dir generated by finetuning -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}." -# ) - -# result_dir_peft = tmp_path / "results_peft" -# additional_args = f"--lora-checkpoint-path {matching_subfolders[0]}" -# cmd_predict = predict_cmd(base_model_checkpoint_path, result_dir_peft, fasta_file_path, additional_args) -# stdout_predict: str = run_command_in_subprocess(command=cmd_predict, path=str(tmp_path)) -# assert "Loading adapters from" in stdout_predict - -# pred_files_peft = glob.glob(str(result_dir_peft / "predictions__rank_*.pt")) -# assert len(pred_files_peft) == 1, f"Expected 1 prediction file (for this test), got {len(pred_files_peft)}" - -# results_original = torch.load(f"{result_dir_original}/predictions__rank_0__dp_rank_0.pt") -# results_peft = torch.load(f"{result_dir_peft}/predictions__rank_0__dp_rank_0.pt") - -# seq_idx_original = results_original["seq_idx"] -# seq_idx_peft = results_peft["seq_idx"] -# assert torch.equal(seq_idx_original, seq_idx_peft), f"Tensors differ: {seq_idx_original} vs {seq_idx_peft}" - -# logits_original = results_original["token_logits"] -# logits_peft = results_peft["token_logits"] -# assert (logits_original != logits_peft).any() -# assert logits_original.shape == logits_peft.shape, ( -# f"Shapes don't match: {logits_original.shape} vs {logits_peft.shape}" -# ) +# """Test that predictions differ when using PEFT/LoRA adapters.""" +# pass + + +@pytest.mark.parametrize( + "embedding_layer,expected_num_layers", + [ + pytest.param(-1, 25, id="embedding_layer=-1_expects_25_layers"), + pytest.param(-2, 24, id="embedding_layer=-2_expects_24_layers"), + pytest.param(0, 1, id="embedding_layer=0_expects_1_layer"), + pytest.param(5, 6, id="embedding_layer=5_expects_6_layers"), + ], +) +@pytest.mark.slow +def test_predict_evo2_embedding_extraction( + tmp_path, + embedding_layer: int, + expected_num_layers: int, + mbridge_checkpoint_1b_8k_bf16_path: Path, + num_sequences: int = 3, + target_sequence_lengths: list[int] | None = None, +): + """Test that embedding extraction produces outputs with expected shapes and keys. + + This test verifies: + 1. The model is initialized with the correct number of layers (logged and verified) + 2. Output contains 'hidden_embeddings' key instead of 'token_logits' + 3. Embeddings have expected shape [B, S, H] where H is hidden dimension + 4. Other expected keys (pad_mask, seq_idx, tokens) are present + + The 1b model has 25 layers, so: + - embedding_layer=-1 -> 25 layers (last layer) + - embedding_layer=-2 -> 24 layers (second-to-last) + - embedding_layer=0 -> 1 layer (first layer only) + - embedding_layer=5 -> 6 layers (layers 0-5) + """ + original_num_layers = 25 # 1b model has 25 layers + + if target_sequence_lengths is None: + target_sequence_lengths = [1024, 1024, 1024] + + world_size = 1 + if world_size > torch.cuda.device_count(): + pytest.skip(f"World size {world_size} is greater than the number of GPUs {torch.cuda.device_count()}") + + fasta_file_path = tmp_path / "test.fasta" + create_fasta_file( + fasta_file_path, num_sequences, sequence_lengths=target_sequence_lengths, repeating_dna_pattern=ALU_SEQUENCE + ) + + # Create a local copy of the environment + env = copy.deepcopy(PRETEST_ENV) + if is_a6000_gpu(): + env["NCCL_P2P_DISABLE"] = "1" + + output_dir = tmp_path / "test_output" + open_port = find_free_network_port() + command = ( + f"torchrun --nproc_per_node {world_size} --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path} " + f"--output-dir {output_dir} " + f"--micro-batch-size 2 --write-interval epoch " + f"--embedding-layer {embedding_layer}" + ) + + cmd_parts = shlex.split(command) + result = subprocess.run( + cmd_parts, + check=False, + cwd=tmp_path, + capture_output=True, + env=env, + text=True, + ) + + # For debugging purposes, print the output if the test fails + if result.returncode != 0: + print("STDOUT:\n" + result.stdout) + print("STDERR:\n" + result.stderr) + + # Assert that the command completed successfully + assert result.returncode == 0, f"predict_evo2 command failed with code {result.returncode}" + + # Combine stdout and stderr for log checking + combined_output = result.stdout + result.stderr + + # Verify logging about model layers is present and extract the layer count + assert "Model initialized with" in combined_output, "Expected logging about model layer count" + assert "Embedding extraction" in combined_output, "Expected logging about embedding extraction mode" + + # Parse and verify the actual number of layers from the log + # Look for pattern: "Model initialized with N layers" + layer_match = re.search(r"Model initialized with (\d+) layers", combined_output) + assert layer_match is not None, "Could not parse 'Model initialized with N layers' from output" + actual_num_layers = int(layer_match.group(1)) + assert actual_num_layers == expected_num_layers, ( + f"Expected model to have {expected_num_layers} layers for embedding_layer={embedding_layer}, " + f"but got {actual_num_layers} layers" + ) + + # Verify the embedding extraction log shows correct layer info + # Look for pattern: "using N of M layers" + extraction_match = re.search(r"using (\d+) of (\d+) layers", combined_output) + assert extraction_match is not None, "Could not parse 'using N of M layers' from output" + layers_used = int(extraction_match.group(1)) + layers_original = int(extraction_match.group(2)) + assert layers_used == expected_num_layers, ( + f"Expected 'using {expected_num_layers}' layers, but log shows 'using {layers_used}'" + ) + assert layers_original == original_num_layers, ( + f"Expected original model to have {original_num_layers} layers, but log shows {layers_original}" + ) + + # Load predictions + pred_files = sorted(glob.glob(str(output_dir / "predictions__rank_*__dp_rank_*.pt"))) + assert len(pred_files) == 1, f"Expected 1 prediction file, got {len(pred_files)}" + + preds = torch.load(pred_files[0]) + assert isinstance(preds, dict) + + # Verify expected keys for embedding extraction + assert "hidden_embeddings" in preds, "Expected 'hidden_embeddings' key in embedding extraction mode" + assert "token_logits" not in preds, "Should not have 'token_logits' in embedding extraction mode" + assert "pad_mask" in preds, "Expected 'pad_mask' key" + assert "seq_idx" in preds, "Expected 'seq_idx' key" + assert "tokens" in preds, "Expected 'tokens' key" + + # Verify shapes + hidden_embeddings = preds["hidden_embeddings"] + pad_mask = preds["pad_mask"] + tokens = preds["tokens"] + + # hidden_embeddings should be [B, S, H] where H is hidden dimension (1920 for 1b model) + assert len(hidden_embeddings.shape) == 3, f"Expected 3D tensor, got shape {hidden_embeddings.shape}" + batch_size, seq_len, hidden_dim = hidden_embeddings.shape + + assert batch_size == num_sequences, f"Expected batch size {num_sequences}, got {batch_size}" + # Sequence length should match padded length + max_seq_len = max(target_sequence_lengths) + assert seq_len == max_seq_len, f"Expected seq_len {max_seq_len}, got {seq_len}" + # Hidden dim should be 1920 for 1b model + assert hidden_dim == 1920, f"Expected hidden_dim 1920 for 1b model, got {hidden_dim}" + + # Verify pad_mask and tokens have matching shapes + assert pad_mask.shape == (batch_size, seq_len), f"pad_mask shape mismatch: {pad_mask.shape}" + assert tokens.shape == (batch_size, seq_len), f"tokens shape mismatch: {tokens.shape}" + + # Verify seq_idx has correct count + assert len(preds["seq_idx"]) == num_sequences, f"Expected {num_sequences} seq_idx entries" + + # Check sequence index map exists + seq_idx_map_path = output_dir / "seq_idx_map.json" + assert seq_idx_map_path.exists(), f"seq_idx_map.json not found at {seq_idx_map_path}" + + with open(seq_idx_map_path) as f: + seq_idx_map = json.load(f) + assert len(seq_idx_map) == num_sequences + + +@pytest.mark.slow +def test_predict_evo2_embedding_layer_validation( + tmp_path, + mbridge_checkpoint_1b_8k_bf16_path: Path, +): + """Test that invalid embedding layer values are rejected with appropriate errors.""" + fasta_file_path = tmp_path / "test.fasta" + create_fasta_file(fasta_file_path, 1, sequence_lengths=[512], repeating_dna_pattern=ALU_SEQUENCE) + + env = copy.deepcopy(PRETEST_ENV) + if is_a6000_gpu(): + env["NCCL_P2P_DISABLE"] = "1" + + output_dir = tmp_path / "test_output" + open_port = find_free_network_port() + + # Test with an invalid embedding layer (too large positive index) + # The 1b model has 25 layers, so layer 100 should be invalid + command = ( + f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path} " + f"--output-dir {output_dir} --embedding-layer 100" + ) + + cmd_parts = shlex.split(command) + result = subprocess.run( + cmd_parts, + check=False, + cwd=tmp_path, + capture_output=True, + env=env, + text=True, + ) + + # Should fail with an error about invalid embedding layer + assert result.returncode != 0, "Expected command to fail with invalid embedding layer" + assert "Invalid embedding_layer" in result.stderr or "Invalid embedding_layer" in result.stdout, ( + "Expected error message about invalid embedding layer" + ) + + +@pytest.mark.slow +def test_predict_evo2_embedding_with_log_probs_rejected( + tmp_path, + mbridge_checkpoint_1b_8k_bf16_path: Path, +): + """Test that using both --embedding-layer and --output-log-prob-seqs is rejected.""" + fasta_file_path = tmp_path / "test.fasta" + create_fasta_file(fasta_file_path, 1, sequence_lengths=[512], repeating_dna_pattern=ALU_SEQUENCE) + + env = copy.deepcopy(PRETEST_ENV) + if is_a6000_gpu(): + env["NCCL_P2P_DISABLE"] = "1" + + output_dir = tmp_path / "test_output" + open_port = find_free_network_port() + + # Test combining embedding extraction with log prob output (should fail) + command = ( + f"torchrun --nproc_per_node 1 --nnodes 1 --master_port {open_port} " + f"-m bionemo.evo2.run.predict --fasta {fasta_file_path} --ckpt-dir {mbridge_checkpoint_1b_8k_bf16_path} " + f"--output-dir {output_dir} --embedding-layer -1 --output-log-prob-seqs" + ) + + cmd_parts = shlex.split(command) + result = subprocess.run( + cmd_parts, + check=False, + cwd=tmp_path, + capture_output=True, + env=env, + text=True, + ) + + # Should fail with an error about incompatible options + assert result.returncode != 0, "Expected command to fail with incompatible options" + assert "Cannot use --output-log-prob-seqs with --embedding-layer" in result.stderr or ( + "Cannot use --output-log-prob-seqs with --embedding-layer" in result.stdout + ), "Expected error message about incompatible options" diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_train.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_train.py index 126b11a2bd..65ae1d976f 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_train.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_train.py @@ -16,624 +16,633 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FIXME bring back these tests once we get mbridge running. -# import argparse -# import io -# import os -# import shlex -# from contextlib import redirect_stderr, redirect_stdout -# from typing import Tuple - -# import pytest -# import torch -# from bionemo.evo2.run.train import parse_args, train -# from bionemo.testing.assert_optimizer_grads_match import assert_optimizer_states_match -# from bionemo.testing.lightning import extract_global_steps_from_log -# from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state -# from bionemo.testing.subprocess_utils import run_command_in_subprocess - -# # from nemo import lightning as nl -# from transformer_engine.pytorch.fp8 import check_fp8_support - -# from .common import small_training_cmd, small_training_finetune_cmd - - -# fp8_available, reason_for_no_fp8 = check_fp8_support() - - -# def run_train_with_std_redirect(args: argparse.Namespace) -> Tuple[str, nl.Trainer]: -# """Run a function with output capture.""" -# stdout_buf, stderr_buf = io.StringIO(), io.StringIO() -# with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf): -# with distributed_model_parallel_state(): -# trainer: nl.Trainer = train(args) - -# train_stdout = stdout_buf.getvalue() -# train_stderr = stderr_buf.getvalue() -# print("Captured STDOUT:\n", train_stdout) -# print("Captured STDERR:\n", train_stderr) -# return train_stdout, trainer - - -# def distributed_training_cmd( -# path, -# max_steps, -# val_check, -# num_devices, -# dp, -# tp, -# cp, -# pp, -# dataset_dir=None, -# training_config=None, -# additional_args: str = "", -# ): -# """Create distributed training command with specified parallelism settings. - -# Args: -# path: Result directory path -# max_steps: Maximum training steps -# val_check: Validation check interval -# num_devices: Total number of devices -# dp: Data parallel size -# tp: Tensor parallel size -# cp: Context parallel size -# pp: Pipeline parallel size -# dataset_dir: Path to preprocessed dataset directory (if None, uses --mock-data) -# training_config: Path to training data config YAML file (required if dataset_dir is provided) -# additional_args: Additional command line arguments -# """ -# micro_batch_size = 1 if dp == 2 else 2 - -# # Use real dataset if provided, otherwise fall back to mock data -# if dataset_dir and training_config: -# data_args = f"-d {training_config} --dataset-dir {dataset_dir}" -# else: -# data_args = "--mock-data" - -# cmd = ( -# f"train_evo2 {data_args} --result-dir {path} --devices {num_devices} " -# f"--tensor-parallel-size {tp} --pipeline-model-parallel-size {pp} --context-parallel-size {cp} " -# "--model-size 7b --num-layers 4 --hybrid-override-pattern SDH* --limit-val-batches 1 " -# "--no-activation-checkpointing --add-bias-output --create-tensorboard-logger --create-tflops-callback " -# f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 " -# f"--seq-length 64 --hidden-dropout 0.0 --attention-dropout 0.0 --seed 42 --workers 0 " -# f"--micro-batch-size {micro_batch_size} --global-batch-size 2 " -# f"--adam-beta1 0 --adam-beta2 0 {additional_args}" -# ) -# return cmd - - -# def small_training_mamba_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""): -# cmd = ( -# f"train_evo2 --mock-data --result-dir {path} --devices {devices} " -# "--model-size hybrid_mamba_8b --num-layers 2 --hybrid-override-pattern M- --limit-val-batches 1 " -# "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback " -# f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 " -# f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}" -# ) -# return cmd - - -# def small_training_mamba_finetune_cmd( -# path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = "" -# ): -# cmd = ( -# f"train_evo2 --mock-data --result-dir {path} --devices {devices} " -# "--model-size hybrid_mamba_8b --num-layers 2 --hybrid-override-pattern M- --limit-val-batches 1 " -# "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback " -# f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 " -# f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}" -# ) -# return cmd - - -# def small_training_llama_cmd(path, max_steps, val_check, devices: int = 1, additional_args: str = ""): -# cmd = ( -# f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} " -# "--model-size 8B --num-layers 2 --limit-val-batches 1 " -# "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback " -# f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 " -# f"--seq-length 8 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args}" -# ) -# return cmd - - -# def small_training_llama_finetune_cmd( -# path, max_steps, val_check, prev_ckpt, devices: int = 1, additional_args: str = "" -# ): -# cmd = ( -# f"train_evo2 --no-fp32-residual-connection --mock-data --result-dir {path} --devices {devices} " -# "--model-size 8B --num-layers 2 --limit-val-batches 1 " -# "--no-activation-checkpointing --create-tensorboard-logger --create-tflops-callback " -# f"--max-steps {max_steps} --warmup-steps 1 --val-check-interval {val_check} --limit-val-batches 1 " -# f"--seq-length 16 --hidden-dropout 0.1 --attention-dropout 0.1 {additional_args} --ckpt-dir {prev_ckpt}" -# ) -# return cmd - - -# @pytest.mark.timeout(512) # Optional: fail if the test takes too long. -# @pytest.mark.slow -# def test_train_evo2_finetune_runs(tmp_path): -# """ -# This test runs the `train_evo2` command with mock data in a temporary directory. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. -# """ -# num_steps = 2 -# # Note: The command assumes that `train_evo2` is in your PATH. -# command = small_training_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps) -# stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain - -# log_dir = tmp_path / "pretrain" / "evo2" -# checkpoints_dir = log_dir / "checkpoints" -# tensorboard_dir = log_dir / "dev" - -# # Check if logs dir exists -# assert log_dir.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir}" -# assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found." -# command_finetune = small_training_finetune_cmd( -# tmp_path / "finetune", max_steps=num_steps, val_check=num_steps, prev_ckpt=matching_subfolders[0] -# ) -# stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune - -# log_dir_ft = tmp_path / "finetune" / "evo2" -# checkpoints_dir_ft = log_dir_ft / "checkpoints" -# tensorboard_dir_ft = log_dir_ft / "dev" - -# # Check if logs dir exists -# assert log_dir_ft.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders_ft = [ -# p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders_ft, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}" - -# assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found." - - -# @pytest.mark.timeout(512) # Optional: fail if the test takes too long. -# @pytest.mark.slow -# def test_train_evo2_mamba_finetune_runs(tmp_path): -# """ -# This test runs the `train_evo2` command with mock data in a temporary directory. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. -# """ -# num_steps = 2 -# # Note: The command assumes that `train_evo2` is in your PATH. -# command = small_training_mamba_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps) -# stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain - -# log_dir = tmp_path / "pretrain" / "evo2" -# checkpoints_dir = log_dir / "checkpoints" -# tensorboard_dir = log_dir / "dev" - -# # Check if logs dir exists -# assert log_dir.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir}" - -# assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found." -# command_finetune = small_training_mamba_finetune_cmd( -# tmp_path / "finetune", max_steps=num_steps, val_check=num_steps, prev_ckpt=matching_subfolders[0] -# ) -# stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune - -# log_dir_ft = tmp_path / "finetune" / "evo2" -# checkpoints_dir_ft = log_dir_ft / "checkpoints" -# tensorboard_dir_ft = log_dir_ft / "dev" - -# # Check if logs dir exists -# assert log_dir_ft.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders_ft = [ -# p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders_ft, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}" - -# assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found." - - -# @pytest.mark.timeout(512) # Optional: fail if the test takes too long. -# @pytest.mark.slow -# def test_train_evo2_llama_finetune_runs(tmp_path): -# """ -# This test runs the `train_evo2` command with mock data in a temporary directory using Llama model. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. -# """ -# num_steps = 2 -# # Note: The command assumes that `train_evo2` is in your PATH. -# command = small_training_llama_cmd(tmp_path / "pretrain", max_steps=num_steps, val_check=num_steps) -# stdout_pretrain: str = run_command_in_subprocess(command=command, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" not in stdout_pretrain - -# log_dir = tmp_path / "pretrain" / "evo2" -# checkpoints_dir = log_dir / "checkpoints" -# tensorboard_dir = log_dir / "dev" - -# # Check if logs dir exists -# assert log_dir.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# # Check if any subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir}" - -# assert len(matching_subfolders) == 1, "Only one checkpoint subfolder should be found." -# command_finetune = small_training_llama_finetune_cmd( -# tmp_path / "finetune", max_steps=num_steps, val_check=num_steps, prev_ckpt=matching_subfolders[0] -# ) -# stdout_finetune: str = run_command_in_subprocess(command=command_finetune, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout_finetune - -# log_dir_ft = tmp_path / "finetune" / "evo2" -# checkpoints_dir_ft = log_dir_ft / "checkpoints" -# tensorboard_dir_ft = log_dir_ft / "dev" - -# # Check if logs dir exists -# assert log_dir_ft.exists(), "Logs folder should exist." -# # Check if checkpoints dir exists -# assert checkpoints_dir_ft.exists(), "Checkpoints folder does not exist." - -# expected_checkpoint_suffix = f"{num_steps}.0-last" -# matching_subfolders_ft = [ -# p for p in checkpoints_dir_ft.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders_ft, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir_ft}." -# ) - -# # Check if directory with tensorboard logs exists -# assert tensorboard_dir_ft.exists(), "TensorBoard logs folder does not exist." -# # Recursively search for files with tensorboard logger -# event_files = list(tensorboard_dir_ft.rglob("events.out.tfevents*")) -# assert event_files, f"No TensorBoard event files found under {tensorboard_dir_ft}" - -# assert len(matching_subfolders_ft) == 1, "Only one checkpoint subfolder should be found." - - -# @pytest.mark.timeout(256) # Optional: fail if the test takes too long. -# @pytest.mark.slow -# def test_train_evo2_stops(tmp_path): -# """ -# This test runs the `train_evo2` command with mock data in a temporary directory. -# It uses the temporary directory provided by pytest as the working directory. -# The command is run in a subshell, and we assert that it returns an exit code of 0. -# """ -# max_steps = 500000 -# early_stop_steps = 4 -# val_check = 2 -# additional_args = f"--early-stop-on-step {early_stop_steps}" -# # Expected location of logs and checkpoints -# log_dir = tmp_path / "evo2" -# checkpoints_dir = log_dir / "checkpoints" - -# assert not log_dir.exists(), "Logs folder shouldn't exist yet." - -# # Note: The command assumes that `train_evo2` is in your PATH. -# command = small_training_cmd(tmp_path, max_steps=max_steps, val_check=val_check, additional_args=additional_args) -# command_parts_no_program = shlex.split(command)[1:] -# args = parse_args(args=command_parts_no_program) -# train_stdout, trainer = run_train_with_std_redirect(args) - -# assert f"Training epoch 0, iteration 0/{early_stop_steps - 1}" in train_stdout -# # Extract and validate global steps -# global_steps = extract_global_steps_from_log(train_stdout) -# assert global_steps[0] == 0 -# assert global_steps[-1] == (early_stop_steps - 1) -# assert trainer.global_step == early_stop_steps -# assert len(global_steps) == early_stop_steps - -# expected_checkpoint_suffix = f"{early_stop_steps}.0-last" -# # Check if checkpoints dir exists -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." - -# # Check if any subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_suffix}' found in {checkpoints_dir}." -# ) - -# assert "reduced_train_loss" in trainer.logged_metrics # validation logging on by default -# assert "TFLOPS_per_GPU" in trainer.logged_metrics # ensuring that tflops logger can be added -# assert "train_step_timing in s" in trainer.logged_metrics - - -# @pytest.mark.parametrize( -# "additional_args", -# [ -# pytest.param("", id="no_fp8"), -# pytest.param( -# "--fp8", -# marks=[ -# pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), -# ], -# id="fp8", -# ), -# ], -# ) -# @pytest.mark.timeout(512) # Optional: fail if the test takes too long. -# @pytest.mark.slow -# def test_train_evo2_stop_at_max_steps_and_continue(tmp_path, additional_args): -# max_steps_first_run = 4 -# max_steps_second_run = 6 -# val_check_interval = 2 -# # Expected location of logs and checkpoints -# log_dir = tmp_path / "evo2" -# checkpoints_dir = log_dir / "checkpoints" - -# command_first_run = small_training_cmd( -# tmp_path, max_steps_first_run, val_check_interval, additional_args=additional_args -# ) - -# # The first training command to finish at max_steps_first_run -# stdout_first_run = run_command_in_subprocess(command=command_first_run, path=str(tmp_path)) - -# assert f"Training epoch 0, iteration 0/{max_steps_first_run - 1}" in stdout_first_run -# # Extract and validate global steps -# global_steps_first_run = extract_global_steps_from_log(stdout_first_run) - -# assert global_steps_first_run[0] == 0 -# assert global_steps_first_run[-1] == max_steps_first_run - 1 -# assert len(global_steps_first_run) == max_steps_first_run - -# expected_checkpoint_first_run_suffix = f"{max_steps_first_run}.0-last" -# # Check if checkpoints dir exists -# assert checkpoints_dir.exists(), "Checkpoints folder does not exist." -# # Check if any ckpt subfolder ends with the expected suffix -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_first_run_suffix in p.name) -# ] -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_first_run_suffix}' found in {checkpoints_dir}." -# ) - -# # The second training command to continue from max_steps_first_run and finish at max_steps_second_run -# command_second_run = small_training_cmd( -# tmp_path, max_steps_second_run, val_check_interval, additional_args=additional_args -# ) -# stdout_second_run = run_command_in_subprocess(command=command_second_run, path=str(tmp_path)) -# global_steps_second_run = extract_global_steps_from_log(stdout_second_run) - -# assert global_steps_second_run[0] == max_steps_first_run -# assert global_steps_second_run[-1] == max_steps_second_run - 1 -# assert len(global_steps_second_run) == max_steps_second_run - max_steps_first_run - -# expected_checkpoint_second_run_suffix = f"{max_steps_second_run}.0-last" -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_second_run_suffix in p.name) -# ] -# assert matching_subfolders, ( -# f"No checkpoint subfolder ending with '{expected_checkpoint_second_run_suffix}' found in {checkpoints_dir}." -# ) - - -# @pytest.fixture(scope="session") -# def dataset_config(request): -# """Get dataset directory and training config from command line options or environment variables. - -# Users can provide dataset paths via: -# - Command line: pytest --dataset-dir=/path/to/data --training-config=/path/to/config.yaml -# - Environment: DATASET_DIR=/path/to/data TRAINING_CONFIG=/path/to/config.yaml pytest - -# If not provided, tests will fall back to --mock-data. -# """ -# # Try to get from pytest command line options first -# dataset_dir = request.config.getoption("--dataset-dir", default=None) -# training_config = request.config.getoption("--training-config", default=None) - -# # Fall back to environment variables -# if not dataset_dir: -# dataset_dir = os.environ.get("DATASET_DIR") -# if not training_config: -# training_config = os.environ.get("TRAINING_CONFIG") - -# return {"dataset_dir": dataset_dir, "training_config": training_config} - - -# @pytest.fixture(scope="session") -# def initial_checkpoint(): -# """Load the initial checkpoint for distributed training tests.""" -# from bionemo.core.data.load import load - -# return load("evo2/7b-8k:1.0") - - -# @pytest.fixture(scope="session") -# def base_checkpoint(tmp_path_factory, initial_checkpoint, dataset_config): -# """Create a base checkpoint by training one step with no parallelism. - -# This fixture is session-scoped, so it creates the checkpoint once and reuses it -# across all parametrized test cases, significantly improving test performance. -# """ -# num_steps = 1 -# tmp_path = tmp_path_factory.mktemp("base_checkpoint_session") -# base_path = tmp_path / "base_training" - -# # Create command with the initial checkpoint and dataset (if provided) -# cmd = distributed_training_cmd( -# path=base_path, -# max_steps=num_steps, -# val_check=num_steps, -# num_devices=1, -# dp=1, -# tp=1, -# cp=1, -# pp=1, -# dataset_dir=dataset_config["dataset_dir"], -# training_config=dataset_config["training_config"], -# additional_args=f"--ckpt-dir {initial_checkpoint}", -# ) - -# # Run training -# stdout = run_command_in_subprocess(command=cmd, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout - -# # Find the resulting checkpoint -# log_dir = base_path / "evo2" -# checkpoints_dir = log_dir / "checkpoints" -# # Lightning uses 0-indexed step counting, so after max_steps=1, we're at step 0 -# expected_checkpoint_suffix = "step=0" - -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert len(matching_subfolders) == 1, "Expected exactly one checkpoint subfolder" -# return matching_subfolders[0] - - -# @pytest.mark.parametrize( -# "dp,cp,tp,pp", -# [ -# pytest.param(2, 1, 1, 1, id="data_parallel"), -# pytest.param(1, 2, 1, 1, id="context_parallel"), -# pytest.param(1, 1, 2, 1, id="tensor_parallel"), -# pytest.param(1, 1, 1, 2, id="pipeline_parallel"), -# ], -# ) -# @pytest.mark.timeout(900) -# @pytest.mark.slow -# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Test requires at least 2 GPUs") -# def test_distributed_training_gradient_equivalence( -# tmp_path, initial_checkpoint, base_checkpoint, dataset_config, dp, cp, tp, pp -# ): -# """Test that gradients are equivalent across different distributed training strategies.""" -# # NOTE: Megatron Core is changing its distributed checkpoint format soon. This test needs to be updated after release 0.14. -# num_steps = 1 - -# # Calculate total devices needed -# num_devices = dp * cp * tp * pp -# assert num_devices == 2, ( -# f"Test is designed for 2 GPUs but got {num_devices} for dp={dp}, cp={cp}, tp={tp}, pp={pp}" -# ) - -# # Create parallel training checkpoint -# parallel_path = tmp_path / f"parallel_dp{dp}_cp{cp}_tp{tp}_pp{pp}" - -# cmd = distributed_training_cmd( -# path=parallel_path, -# max_steps=num_steps, -# val_check=num_steps, -# num_devices=num_devices, -# dp=dp, -# tp=tp, -# cp=cp, -# pp=pp, -# dataset_dir=dataset_config["dataset_dir"], -# training_config=dataset_config["training_config"], -# additional_args=f"--ckpt-dir {initial_checkpoint}", -# ) - -# # Run distributed training -# stdout = run_command_in_subprocess(command=cmd, path=str(tmp_path)) -# assert "Restoring model weights from RestoreConfig(path='" in stdout - -# # Find the resulting checkpoint -# log_dir = parallel_path / "evo2" -# checkpoints_dir = log_dir / "checkpoints" -# # Lightning uses 0-indexed step counting, so after max_steps=1, we're at step 0 -# expected_checkpoint_suffix = "step=0" - -# matching_subfolders = [ -# p for p in checkpoints_dir.iterdir() if p.is_dir() and (expected_checkpoint_suffix in p.name) -# ] - -# assert len(matching_subfolders) == 1, "Expected exactly one checkpoint subfolder" -# parallel_checkpoint = matching_subfolders[0] - -# # Compare gradients/optimizer states between base and parallel distributed training -# print(f"Base checkpoint: {base_checkpoint}") -# print(f"Parallel checkpoint (dp={dp}, cp={cp}, tp={tp}, pp={pp}): {parallel_checkpoint}") - -# # Ensure both checkpoints exist before comparison -# assert base_checkpoint.exists(), "Base checkpoint should exist" -# assert parallel_checkpoint.exists(), "Parallel checkpoint should exist" - -# # Use the custom gradient comparison logic to verify optimizer states match -# # This implements theorem 5.3 of https://www.arxiv.org/pdf/2506.09280 for gradient equivalence -# checkpoint_dirs = [str(base_checkpoint / "weights"), str(parallel_checkpoint / "weights")] -# assert_optimizer_states_match(checkpoint_dirs) + +import copy +import os +import re +import shlex +import shutil +import subprocess +from pathlib import Path +from typing import Dict, Iterable, Optional, Tuple, Union + +import pytest +import torch +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from torch.distributed.checkpoint.filesystem import FileSystemReader +from torch.distributed.checkpoint.state_dict_loader import load + +from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH + +from ..utils import find_free_network_port, is_a6000_gpu, is_fp4_supported, is_fp8_supported, is_mxfp8_supported + + +TensorLike = Union[torch.Tensor, Iterable[torch.Tensor]] + + +def _as_iter(x: TensorLike): + return x if (isinstance(x, Iterable) and not isinstance(x, torch.Tensor)) else [x] + + +def _fro_norm(x: TensorLike) -> torch.Tensor: + """Frobenius norm; supports sharded tensors (sum of shard ||·||_F^2).""" + it = list(_as_iter(x)) # Convert to list to avoid iterator consumption issues + if not it: + return torch.tensor(0.0, device="cpu") + s = torch.tensor(0.0, device=it[0].device) + for t in it: + s = s + t.float().pow(2).sum() + return torch.sqrt(s) + + +def machine_epsilon_for_dtype(dtype: torch.dtype) -> float: + """Return machine epsilon for dtype. For FP8, use BF16 epsilon per paper.""" + # Standard types + if dtype in (torch.float32, torch.float16, torch.bfloat16): + return float(torch.finfo(dtype).eps) + # FP8 recipes: accum/store typically BF16/FP32; use BF16 epsilon + if hasattr(torch, "float8_e4m3fn") and dtype in ( + torch.float8_e4m3fn, + getattr(torch, "float8_e5m2fn", None), + ): + return float(torch.finfo(torch.bfloat16).eps) + # Fallback + return float(torch.finfo(torch.float32).eps) + + +def relative_grad_diff(g_hat: TensorLike, g_ref: TensorLike, eps_den: float = 1e-30) -> float: + """Relative difference ||g_hat - g_ref||_F / ||g_ref||_F. + + Accepts a single tensor or an iterable of shards for each argument. + """ + # Convert to lists to avoid iterator consumption issues + gh_list = list(_as_iter(g_hat)) + gr_list = list(_as_iter(g_ref)) + + if len(gh_list) != len(gr_list): + raise ValueError(f"Shard count mismatch: {len(gh_list)} vs {len(gr_list)}") + + if not gh_list: + return 0.0 + + num_sq = torch.tensor(0.0, device=gh_list[0].device) + for a, b in zip(gh_list, gr_list): + num_sq = num_sq + (a.float() - b.float()).pow(2).sum() + num = torch.sqrt(num_sq) + den = _fro_norm(g_ref) + return float(num / (den + eps_den)) + + +def expected_rel_bound( + l: int, # noqa: E741 + *, + L: int = 32, # noqa: N803 + C: float = 1.03, # noqa: N803 + dtype: Optional[torch.dtype] = torch.bfloat16, + k: float = 4.0, +) -> float: + """Bound ~ k * (C ** (L + 1 - l)) * eps_mch, with 1-based layer index l. + + - L is hard-coded default to 32 per your request. + - C is 'close to 1'; 1.01-1.05 are reasonable defaults. + - k absorbs the hidden constant in big-O; 2-8 are common choices. + - dtype controls eps_mch; for FP8 use BF16 epsilon (see https://www.arxiv.org/pdf/2506.09280 theorem 5.3). + """ + eps_mch = machine_epsilon_for_dtype(dtype or torch.bfloat16) + depth = L + 1 - l # 1-based depth from the top (as in the theorem) + depth = max(depth, 0) + return float(k * (C**depth) * eps_mch) + + +def check_gradient( + g_hat: TensorLike, + g_ref: TensorLike, + l: int, # noqa: E741 + *, + L: int = 32, # noqa: N803 + C: float = 1.03, # noqa: N803 + dtype: Optional[torch.dtype] = None, + k: float = 4.0, +) -> Tuple[float, float, bool]: + """Compute (rel_error, bound, ok) for layer l. + + - If dtype is None, infer from g_ref (or g_hat if needed). + # See https://www.arxiv.org/pdf/2506.09280 theorem 5.3 + """ + # Infer dtype if not provided + if dtype is None: + gr_list = list(_as_iter(g_ref)) + if gr_list: + dtype = gr_list[0].dtype + else: + dtype = torch.bfloat16 # fallback + rel = relative_grad_diff(g_hat, g_ref) + bnd = expected_rel_bound(l, L=L, C=C, dtype=dtype, k=k) + return rel, bnd, (rel <= bnd) + + +def _filter_optimizer_tensors(plain_tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Return only optimizer-related tensors from a flat checkpoint tensor dict.""" + return {k: v for k, v in plain_tensors.items() if k.startswith("optimizer.") and ".exp_avg." in k} + + +def assert_grads_close(left: torch.Tensor, right: torch.Tensor): + """Assert that two gradient tensors are close using theorem 5.3 of https://www.arxiv.org/pdf/2506.09280.""" + # Implement theorem 5.3 of https://www.arxiv.org/pdf/2506.09280 + + # This is the real test: + # k=5.0 provides margin for small numerical differences in sequence parallel gradient sync + rel, bnd, ok = check_gradient( + left, right, l=0, dtype=torch.bfloat16, k=5.0 + ) # hard code to layer 0 since that's the most permissive + + # If the real test above fails, run an assert close for the useful diagnostics and raise either way. + if not ok: + rel_shuff, _, ok_shuff = check_gradient( + left, torch.roll(right, shifts=-1, dims=-1), l=0, dtype=torch.bfloat16, k=5.0 + ) + + try: + torch.testing.assert_close(left, right) + msg = ( + "AssertionError on relative norm magnitude " + f"(rel={rel}, bnd={bnd}, ok={ok}, rel_shuff={rel_shuff}, ok_shuff={ok_shuff}) " + "but torch.testing.assert_close(left, right) passes. \n" + f"Left: {left.shape}/{left.dtype} {left}\n" + f"Right: {right.shape}/{right.dtype} {right}" + ) + except AssertionError as e: + msg = ( + "AssertionError on relative norm magnitude " + f"(rel={rel}, bnd={bnd}, ok={ok}, rel_shuff={rel_shuff}, ok_shuff={ok_shuff}): {e}\n" + f"Left: {left.shape}/{left.dtype} {left}\n" + f"Right: {right.shape}/{right.dtype} {right}" + ) + raise AssertionError(msg) + + +def _assert_optimizer_tensors_equal( + left: Dict[str, torch.Tensor], + right: Dict[str, torch.Tensor], + eps=1e-4, +): + left_keys = set(left.keys()) + right_keys = set(right.keys()) + + only_left = sorted(left_keys - right_keys) + only_right = sorted(right_keys - left_keys) + assert not only_left and not only_right, ( + f"Optimizer tensor keys mismatch.\nOnly in left: {only_left}\nOnly in right: {only_right}" + ) + some_non_zero = False + assertions = [] + for key in sorted(left_keys): + lt, rt = left[key], right[key] + assert lt.shape == rt.shape and lt.dtype == rt.dtype, ( + f"Tensor meta mismatch for {key}: {lt.shape}/{lt.dtype} vs {rt.shape}/{rt.dtype}" + ) + # Reduce the rate of 0 vs near 0 rtol failures by adding a small epsilon + left_scale = torch.max(torch.abs(lt)) + right_scale = torch.max(torch.abs(rt)) + if left_scale <= eps and right_scale <= eps: + print( + f"WARNING: zero-ish scale tensors ({left_scale=} vs {right_scale=}) " + f"so they will trivially pass comparing {key=}" + ) + else: + some_non_zero = True + try: + assert_grads_close(lt, rt) + print(f"Optimizer tensors match for {key}") + except AssertionError as e: + assertions.append(AssertionError(f"AssertionError for {key}: {e}")) + assert not assertions, f"Assertion Errors found comparing keys: {assertions}" + assert some_non_zero, "No non-zero tensors found in this comparison" + + +def load_dist_checkpoint_pt( + ckpt_dir, + metadata_ckpt_dir=None, + pattern=r"optimizer", + device="cpu", + return_full_empty: bool = False, +): + """Return {full_key: tensor} for every tensor whose key matches *pattern*.""" + meta_ckpt_dir = Path(metadata_ckpt_dir or ckpt_dir) + meta_reader = FileSystemReader(str(meta_ckpt_dir)) + + # --- fast metadata pass (no tensor data yet) ----------------------------- + meta = meta_reader.read_metadata() # tiny JSON read + tmeta = meta.state_dict_metadata # key ➜ TensorMetadata + if return_full_empty: + wanted = [k for k in tmeta if hasattr(tmeta[k], "size")] + else: + wanted = [k for k in tmeta if re.search(pattern, k) and hasattr(tmeta[k], "size")] + if not wanted: + raise ValueError(f"No keys matching /{pattern}/ in {ckpt_dir}") + + # --- build "empty" placeholders ----------------------------------------- + placeholders = { + k: torch.empty(tuple(tmeta[k].size), dtype=tmeta[k].properties.dtype, device=device) for k in wanted + } + if return_full_empty: + return placeholders + # --- stream just those tensors (no process-group needed) ----------------- + data_reader = FileSystemReader(str(ckpt_dir)) + + load( + state_dict=placeholders, + storage_reader=data_reader, + no_dist=False, # switches off all collectives + ) + return placeholders # dict[str, Tensor] + + +def assert_optimizer_states_match(checkpoint_dirs): + """Compare optimizer state across provided torch_dist checkpoints. + + - Keys: ensure the set of optimizer tensor keys match across checkpoints + - Values: ensure corresponding tensors are equal (allclose) + - Structure (non-tensor common state): ensure common optimizer structures match + """ + assert len(checkpoint_dirs) > 1, "This test requires 2 or more checkpoints [ ...]." + + base_dir = checkpoint_dirs[0] + + # Compare optimizer tensors + base_plain = load_dist_checkpoint_pt(base_dir) + base_opt_tensors = _filter_optimizer_tensors(base_plain) + assert base_opt_tensors, f"No optimizer tensors found in checkpoint: {base_dir}" + assertions = [] + for other_dir in checkpoint_dirs[1:]: + try: + other_plain = load_dist_checkpoint_pt(other_dir) + other_opt_tensors = _filter_optimizer_tensors(other_plain) + assert other_opt_tensors, f"No optimizer tensors found in checkpoint: {other_dir}" + _assert_optimizer_tensors_equal(base_opt_tensors, other_opt_tensors) + print(f"Optimizer tensors match for {base_dir} and {other_dir}") + del other_plain + del other_opt_tensors + except AssertionError as e: # noqa: PERF203 + msg = f"AssertionError comparing {base_dir} to {other_dir}:\n{e}" + print(f"Optimizer tensors mismatch for {base_dir} and {other_dir}:\n{msg}") + assertions.append(AssertionError(msg)) + assert not assertions, f"AssertionErrors comparing {checkpoint_dirs}:\n{assertions}" + + +# Do this at collection time before we run any tests. +PRETEST_ENV = copy.deepcopy(os.environ) + + +def _run_train_command(cmd: str, run_dir: Path) -> str: + env = copy.deepcopy(PRETEST_ENV) + env["MASTER_PORT"] = str(find_free_network_port()) + result = subprocess.run( + shlex.split(cmd), + check=False, + capture_output=True, + text=True, + cwd=run_dir, + env=env, + ) + if result.returncode != 0: + print(f"Return code: {result.returncode}") + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + assert result.returncode == 0, f"Command failed with return code {result.returncode}\nSTDERR:\n{result.stderr}" + return result.stdout + + +def _distributed_training_cmd( + *, + path: Path, + max_steps: int, + val_check: int, + num_devices: int, + dp: int, + tp: int, + cp: int, + pp: int, + finetune_ckpt_dir: Path, + additional_args: str = "", +) -> str: + micro_batch_size = 1 if dp == 2 else 2 + return ( + f"torchrun --nproc-per-node {num_devices} --no-python train_evo2 " + f"--mock-data --result-dir {path} " + f"--hf-tokenizer-model-path {DEFAULT_HF_TOKENIZER_MODEL_PATH} " + "--model-size 7b_arc_longcontext --num-layers 4 --hybrid-override-pattern SDH* " + "--no-activation-checkpointing --optim-full-reshardable " + f"--finetune-ckpt-dir {finetune_ckpt_dir} " + f"--max-steps {max_steps} --eval-interval {val_check} --eval-iters 1 " + f"--seq-length 64 --hidden-dropout 0.0 --attention-dropout 0.0 " + f"--micro-batch-size {micro_batch_size} --global-batch-size 2 " + f"--tensor-model-parallel-size {tp} --pipeline-model-parallel-size {pp} --context-parallel-size {cp} " + "--adam-beta1 0 --adam-beta2 0 --ckpt-format torch_dist --log-interval 1 --decay-steps 1000 --warmup-steps 10 " + f"--seed 42 --dataset-seed 33 {additional_args}" + ) + + +@pytest.mark.timeout(300) +@pytest.mark.slow +@pytest.mark.parametrize( + "tp_size", + [ + pytest.param(1, id="tp_1_pretrain"), + pytest.param( + 2, + id="tp_2_pretrain", + marks=pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="TP=2 requires at least 2 GPUs for pretraining." + ), + ), + ], +) +def test_fine_tuning( + tmp_path: Path, + tp_size: int, + cp_size: int = 1, + dp_size: int = 1, + final_tp: int = 1, + dp_rank_check: bool = True, + precision_recipe: str = "bf16_mixed", + pp_size: int = 1, +): + """Test fine-tuning functionality, which should mirror stop/go but reset optimizer, data, and training state.""" + world_size = tp_size * pp_size * cp_size * dp_size + mbs = 32 + gbs = mbs * dp_size + num_gpus = torch.cuda.device_count() + if world_size > num_gpus: + pytest.skip(f"World size {world_size} is greater than the number of GPUs {num_gpus}") + if "nvfp4" in precision_recipe and not is_fp4_supported(): + pytest.skip("NVFP4 is not supported on this device") + if "mxfp8" in precision_recipe and not is_mxfp8_supported(): + pytest.skip("MXFP8 is not supported on this device") + if "fp8" in precision_recipe and not is_fp8_supported(): + pytest.skip("FP8 is not supported on this device") + if "bf16_with_fp8_delayed_scaling_mixed" == precision_recipe and is_fp8_supported(): + pytest.xfail(reason="FP8 delayed scaling is not currently working with Evo2, use another FP8 recipe.") + if "bf16_with_fp8_subchannel_scaling_mixed" == precision_recipe and is_fp8_supported(): + pytest.xfail(reason="FP8 subchannel scaling is not currently working with Evo2 on some GPUs.") + run_dir = tmp_path / f"run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}" + run_dir.mkdir(parents=True, exist_ok=True) + master_port = find_free_network_port() + dp_rank_check_str = "--debug-ddp-parity-freq 5" if dp_rank_check else "" + cmd1 = f"""torchrun --nproc-per-node {world_size} --no-python --master_port {master_port} \ + train_evo2 \ + --hf-tokenizer-model-path {DEFAULT_HF_TOKENIZER_MODEL_PATH} \ + --model-size striped_hyena_1b_nv_parallel --num-layers 4 --hybrid-override-pattern SDH* \ + --max-steps 5 --eval-interval 5 \ + --eval-iters 3 --mock-data --result-dir {run_dir} \ + --micro-batch-size {mbs} --global-batch-size {gbs} --seq-length 512 \ + --tensor-model-parallel {tp_size} \ + --pipeline-model-parallel {pp_size} \ + --context-parallel {cp_size} \ + --mixed-precision-recipe {precision_recipe} \ + --overlap-param-gather \ + --overlap-grad-reduce \ + {dp_rank_check_str} \ + --use-precision-aware-optimizer --dataset-seed 33 \ + --seed 41 --spike-no-more-embedding-init \ + --no-weight-decay-embeddings --cross-entropy-loss-fusion \ + --grad-reduce-in-fp32 \ + --decay-steps 1000 --warmup-steps 10 \ + --eod-pad-in-loss-mask \ + --log-interval 1 \ + """ + + # Split the command and run it + cmd_parts = shlex.split(cmd1) + env = copy.deepcopy(PRETEST_ENV) + if is_a6000_gpu(): + # Fix hanging issue on A6000 GPUs with multi-gpu tests + env["NCCL_P2P_DISABLE"] = "1" + result = subprocess.run(cmd_parts, check=False, capture_output=True, text=True, cwd=run_dir, env=env) + + stdout = result.stdout + stderr = result.stderr + returncode = result.returncode + + # For debugging, print the output + print(f"Return code: {returncode}") + print(f"STDOUT:\n{stdout}") + print(f"STDERR:\n{stderr}") + + # Assert the command succeeded + assert returncode == 0, f"Command failed with return code {returncode}\nSTDERR:\n{stderr}" + result_dir = run_dir / "evo2" + ckpt_dir = result_dir / "checkpoints" + tb_log_dir = result_dir / "tb_logs" + assert ckpt_dir.exists() and ckpt_dir.is_dir(), "Checkpoints directory not found" + assert tb_log_dir.exists() and tb_log_dir.is_dir(), "TensorBoard logs directory not found" + iter_5_dir = ckpt_dir / "iter_0000005" + assert iter_5_dir.exists() and iter_5_dir.is_dir(), f"No iterations 5 checkpoint found in {ckpt_dir}" + assert len(list(ckpt_dir.glob("iter_*"))) == 1, f"Expected 1 iterations, found {list(ckpt_dir.glob('iter_*'))}" + # Load tensorboard logs to verify they were written correctly + + # Find the events file(s) in tb_log_dir + event_files = list(tb_log_dir.rglob("events.out.*")) + assert len(event_files) > 0, f"No tensorboard event files found in {tb_log_dir}" + + # Load events from the event files + event_acc = EventAccumulator(str(tb_log_dir)) + event_acc.Reload() + + # 1. collect the last loss, as well as the average of the last step validation losses, as well as the last step + # Note: EventAccumulator.Scalars returns a list of ScalarEvent(wall_time, step, value) + lm_loss_events = event_acc.Scalars("lm loss") + + assert len(lm_loss_events) > 0, "No 'lm loss' events found in run 1" + last_lm_loss_step = lm_loss_events[-1].step + + assert last_lm_loss_step == 5, f"Expected run 1 to end at step 5, but got {last_lm_loss_step}" + + # 2. run the above training command a second time, this time set max_steps to 10. Verify that the run resumes from the last step. + # Do this by moving the tb_logs to a different directory from the first part so the second run makes fresh logs. + tb_log_dir_run1 = result_dir / "tb_logs_run1" + if tb_log_dir.exists(): + shutil.move(str(tb_log_dir), str(tb_log_dir_run1)) + + # Modify the command to increase max steps to 10 + # We reuse the same result_dir so it picks up the checkpoint + ft_run_dir = ( + tmp_path / f"ft_run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}" + ) + ft_run_dir.mkdir(parents=True, exist_ok=True) + ft_world_size = final_tp * pp_size * cp_size * dp_size + cmd2 = ( + cmd1.rstrip() + .replace(f"--nproc-per-node {world_size}", f"--nproc-per-node {ft_world_size}") + .replace(f"--result-dir {run_dir}", f"--result-dir {ft_run_dir}") + .replace(f"--tensor-model-parallel {tp_size}", f"--tensor-model-parallel {final_tp}") + ) + cmd2 += f" --finetune-ckpt-dir {ckpt_dir} " + cmd_parts_2 = shlex.split(cmd2) + + print("Starting Run 2 (resuming to step 10)...") + result_2 = subprocess.run(cmd_parts_2, check=False, capture_output=True, text=True, cwd=run_dir, env=env) + + print(f"Run 2 Return code: {result_2.returncode}") + if result_2.returncode != 0: + print(f"Run 2 STDERR:\n{result_2.stderr}") + + assert result_2.returncode == 0, f"Run 2 failed with return code {result_2.returncode}" + + # 3. Load the new tb logs as before, and sanity check my recommendations as well as any others that make sense. + ft_result_dir = ft_run_dir / "evo2" + ft_tb_log_dir = ft_result_dir / "tb_logs" + assert ft_tb_log_dir.exists(), "TensorBoard logs directory not found after Run 2" + + event_acc_2 = EventAccumulator(str(ft_tb_log_dir)) + event_acc_2.Reload() + + lm_loss_events_2 = event_acc_2.Scalars("lm loss") + assert len(lm_loss_events_2) > 0, "No 'lm loss' events found in run 2" + + first_step_run2 = lm_loss_events_2[0].step + first_step_run1 = lm_loss_events[0].step + last_step_run2 = lm_loss_events_2[-1].step + + # Sanity checks: + # 1. Resumption: Should start after step 5 (e.g., step 6) + assert first_step_run2 == first_step_run1, ( + f"Run 2 FT steps should match run 1, but started at {first_step_run2} vs {first_step_run1}" + ) + + # 2. Completion: Should reach step 5 like run 1 + assert last_step_run2 == 5, f"Run 2 should reach step 5, but ended at {last_step_run2}" + + # 3. Loss Continuity check (basic): The first loss of run 2 should be reasonably close to the last loss of run 1, + # or at least not exploding, though optimization steps might cause fluctuations. + first_loss_run1 = lm_loss_events[0].value + first_loss_run2 = lm_loss_events_2[0].value + last_loss_run1 = lm_loss_events[-1].value + assert first_loss_run1 > last_loss_run1, ( + f"Run 1 first loss {first_loss_run1} is less than run 1 last loss {last_loss_run1}" + ) + assert first_loss_run2 < first_loss_run1, ( + f"Run 2 first loss {first_loss_run2} is greater than run 1 first loss {first_loss_run1}" + ) + assert abs(first_loss_run2 - first_loss_run1) > abs(last_loss_run1 - first_loss_run2), ( + f"Run 2 beginning {first_loss_run2} should be closer to end of run 1 {last_loss_run1} than beginning {first_loss_run1}." + ) + assert first_loss_run2 - last_loss_run1 < 0.1, ( + f"Run 2 first loss {first_loss_run2} is not better than run 1 last loss {last_loss_run1} by no worse than 0.1" + ) + + +@pytest.fixture(scope="session") +def mbridge_checkpoint_7b_1m(tmp_path_factory) -> Path: + """Session-scoped MBridge checkpoint for the 1b-8k-bf16 model. + + This fixture converts the NeMo2 checkpoint to MBridge format once per test session, + allowing it to be shared across multiple test files (test_infer.py, test_predict.py, etc.). + + Returns: + Path to the MBridge checkpoint iteration directory (e.g., .../iter_0000001) + """ + from bionemo.core.data.load import load + from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH_512 + from bionemo.evo2.utils.checkpoint.nemo2_to_mbridge import run_nemo2_to_mbridge + + try: + nemo2_ckpt_path = load("evo2/7b-1m:1.0") + except ValueError as e: + if e.args[0].endswith("does not have an NGC URL."): + pytest.skip( + "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " + "one or more files are missing from ngc." + ) + else: + raise e + + output_dir = tmp_path_factory.mktemp("mbridge_checkpoint_7b_1m_session") + mbridge_ckpt_dir = run_nemo2_to_mbridge( + nemo2_ckpt_dir=nemo2_ckpt_path, + tokenizer_path=DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + mbridge_ckpt_dir=output_dir / "evo2_7b_1m_mbridge", + model_size="7b_arc_longcontext", + seq_length=1_048_576, + mixed_precision_recipe="bf16_mixed", + vortex_style_fp8=False, + ) + # Return the parent directory (containing latest_train_state.pt), not the iter_0000001 subdirectory + # The checkpoint loading code looks for tracker files in the parent directory + return mbridge_ckpt_dir + + +@pytest.fixture(scope="session") +def base_checkpoint(tmp_path_factory: pytest.TempPathFactory, mbridge_checkpoint_7b_1m: Path) -> Path: + """Create a base checkpoint by training one step with no parallelism.""" + if torch.cuda.device_count() < 1: + pytest.skip("Test requires at least 1 GPU") + num_steps = 1 + tmp_path = tmp_path_factory.mktemp("base_checkpoint_session") + base_path = tmp_path / "base_training" + base_path.mkdir(parents=True, exist_ok=True) + + cmd = _distributed_training_cmd( + path=base_path, + max_steps=num_steps, + val_check=num_steps, + num_devices=1, + dp=1, + tp=1, + cp=1, + pp=1, + finetune_ckpt_dir=mbridge_checkpoint_7b_1m, + ) + _run_train_command(cmd, base_path) + + ckpt_dir = base_path / "evo2" / "checkpoints" / "iter_0000001" + assert ckpt_dir.exists() and ckpt_dir.is_dir(), f"Checkpoint dir not found: {ckpt_dir}" + return ckpt_dir + + +@pytest.mark.parametrize( + "dp,cp,tp,pp", + [ + pytest.param(2, 1, 1, 1, id="data_parallel"), + pytest.param(1, 2, 1, 1, id="context_parallel"), + pytest.param(1, 1, 2, 1, id="tensor_parallel"), + pytest.param(1, 1, 1, 2, id="pipeline_parallel"), + ], +) +@pytest.mark.timeout(900) +@pytest.mark.slow +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Test requires at least 2 GPUs") +@pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space limitations") +def test_distributed_training_gradient_equivalence( + tmp_path: Path, base_checkpoint: Path, mbridge_checkpoint_7b_1m: Path, dp, cp, tp, pp +): + """Test that optimizer states match across different distributed training strategies.""" + num_steps = 1 + num_devices = dp * cp * tp * pp + assert num_devices == 2, ( + f"Test is designed for 2 GPUs but got {num_devices} for dp={dp}, cp={cp}, tp={tp}, pp={pp}" + ) + + parallel_path = tmp_path / f"parallel_dp{dp}_cp{cp}_tp{tp}_pp{pp}" + parallel_path.mkdir(parents=True, exist_ok=True) + cmd = _distributed_training_cmd( + path=parallel_path, + max_steps=num_steps, + val_check=num_steps, + num_devices=num_devices, + dp=dp, + tp=tp, + cp=cp, + pp=pp, + finetune_ckpt_dir=mbridge_checkpoint_7b_1m, # must use the same checkpoint since PP/TP will have different RNG + additional_args=" --sequence-parallel " if tp > 1 else "", + ) + _run_train_command(cmd, parallel_path) + + parallel_checkpoint = parallel_path / "evo2" / "checkpoints" / "iter_0000001" + assert parallel_checkpoint.exists() and parallel_checkpoint.is_dir(), ( + f"Checkpoint dir not found: {parallel_checkpoint}" + ) + + checkpoint_dirs = [str(base_checkpoint), str(parallel_checkpoint)] + assert_optimizer_states_match(checkpoint_dirs) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py index 700419333b..2a8f9cea14 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py @@ -16,8 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -# FIXME bring back these tests, at least the batch_generate and forward pass correctness tests. import gc import inspect import logging @@ -38,16 +36,6 @@ from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer, build_tokenizer from megatron.core import dist_checkpointing, parallel_state from megatron.core.dist_checkpointing.mapping import ShardedTensor - -# # FIXME copy these out or make them not depend on NeMo -# from bionemo.llm.utils.weight_utils import ( -# MegatronModelType, -# _key_in_filter, -# _munge_key_megatron_to_nemo2, -# _munge_sharded_tensor_key_megatron_to_nemo2, -# ) -# from bionemo.testing.megatron_parallel_state_utils import distributed_model_parallel_state -# from bionemo.testing.torch import check_fp8_support from megatron.core.tensor_parallel import random as tp_random from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import Float16Module @@ -63,6 +51,8 @@ ) from bionemo.evo2.utils.checkpoint.nemo2_to_mbridge import run_nemo2_to_mbridge +from .utils import check_fp8_support, find_free_network_port + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) # Capture all levels in the logger itself @@ -73,22 +63,6 @@ DEFAULT_NCCL_TIMEOUT = "30" # in second -def find_free_network_port(address: str = "localhost") -> int: - """Finds a free port on localhost. - - It is useful in single-node training when we don't want to connect to a real master node but - have to set the `MASTER_PORT` environment variable. - """ - import socket - - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - s.listen(1) - port = s.getsockname()[1] - s.close() - return port - - def _reset_microbatch_calculator(): """Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initilised model parallel in nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo @@ -196,21 +170,6 @@ def distributed_model_parallel_state( clean_up_distributed_and_parallel_states() -def check_fp8_support(device_id: int = 0) -> tuple[bool, str, str]: - """Check if FP8 is supported on the current GPU. - - FP8 requires compute capability 8.9+ (Ada Lovelace/Hopper architecture or newer). - """ - if not torch.cuda.is_available(): - return False, "0.0", "CUDA not available" - device_props = torch.cuda.get_device_properties(device_id) - compute_capability = f"{device_props.major}.{device_props.minor}" - device_name = device_props.name - # FP8 is supported on compute capability 8.9+ (Ada Lovelace/Hopper architecture) - is_supported = (device_props.major > 8) or (device_props.major == 8 and device_props.minor >= 9) - return is_supported, compute_capability, f"Device: {device_name}, Compute Capability: {compute_capability}" - - ############################################################################################# # Core utility functions: Below are some utility functions that allow for loading a nemo2 # trained model back into a newly initialized megatron core model. The key insight is that @@ -295,13 +254,13 @@ def determine_memory_requirement_and_skip_if_not_met(ckpt_name: str, test_name: "memory_needed_by_test": 21, }, # checked both variants in isolation { - "test_name": "test_batch_generate", + "test_name": "test_batch_generate_mbridge", "model_size": "1b", "seq_len_cap": -1, "memory_needed_by_test": 16, - }, # checked both variants in isolation + }, # checked both variants in isolation - needs ~21GB peak on L4 { - "test_name": "test_batch_generate", + "test_name": "test_batch_generate_mbridge", "model_size": "7b", "seq_len_cap": -1, "memory_needed_by_test": 43, @@ -310,26 +269,14 @@ def determine_memory_requirement_and_skip_if_not_met(ckpt_name: str, test_name: "test_name": "test_batch_generate_coding_sequences", "model_size": "1b", "seq_len_cap": -1, - "memory_needed_by_test": 6, + "memory_needed_by_test": 12, }, # checked both variants in isolation { "test_name": "test_batch_generate_coding_sequences", "model_size": "7b", "seq_len_cap": -1, - "memory_needed_by_test": 21, + "memory_needed_by_test": 28, }, # checked both variants in isolation - { - "test_name": "test_generate_speed", - "model_size": "1b", - "seq_len_cap": -1, - "memory_needed_by_test": -1, - }, # skipped for now until Anton's changes - { - "test_name": "test_generate_speed", - "model_size": "7b", - "seq_len_cap": -1, - "memory_needed_by_test": -1, - }, # skipped for now until Anton's changes ], columns=["test_name", "model_size", "seq_len_cap", "memory_needed_by_test"], ) @@ -381,174 +328,6 @@ def load_weights_sharded_inplace_nemo2_to_mcore( dist_checkpointing.load(sharded_state_dict, str(distributed_checkpoint_dir)) -# @pytest.mark.parametrize("seq_len", [8_192, 16_384]) -# def test_golden_values_top_k_logits_and_cosine_similarity(seq_len: int): -# try: -# evo2_1b_checkpoint_weights: Path = load("evo2/1b-8k:1.0") / "weights" -# gold_standard_no_fp8 = load("evo2/1b-8k-nofp8-te-goldvalue-testdata-A6000:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e -# with distributed_model_parallel_state(), torch.no_grad(): -# hyena_config = llm.Hyena1bConfig(use_te=True, seq_length=seq_len) -# tokenizer = get_nmt_tokenizer( -# "byte-level", -# ) -# raw_megatron_model = hyena_config.configure_model(tokenizer).eval().cuda() -# device = raw_megatron_model.parameters().__next__().device -# load_weights_sharded_inplace_nemo2_to_mcore(raw_megatron_model, evo2_1b_checkpoint_weights, {}, "torch_dist") -# model = Float16Module(hyena_config, raw_megatron_model) -# input_seq = "GAAATTAGCGCGTCCGGAATGATACGAGGGGAAACGAAATTTTGAATTAATGGAGAAAAAAGACGAGAAACCTTAAGCAAAAAAATTTTAGCTTCGAATATTTATTAATTTCTGAGATGTTGTTAAACGATTTTCGATTCCAAGTTGTGCGCACGAACGTTATTGCAAATAAATGCTGCTTATTCGGATGTTTCCACGATCTTTGTTGCAATGGTAGTCGAGTACCCGATAACCCAATTTCGTTACATCGGCCTATCTGTAGAATATCCAATCTATGGTTCATAAAAAATCTGATCGTTTGTTTTTAAGAAATTAAACGCGTTAAATTGAACGAATTTCGAATACCGGTCTTAGCGAAGGACCTCCCCTCTTGCTTGCGTATTGCCCCGCGAAATTTCTTTTCGGCGATGAACGATACAAAAAATTCTATCGAATGTTACTTCTATTCTCTGCCTCGTCTATGACTTGGAGATTGGTCTATGTCGTTCGTTTTCTCGCGAGTTTCCAATATGTCCGTAGTATGTGAACGCTGGTATTCGTGAAGATAAATTATTGTTTTTACAATTTCTTTCAAAAATATATAATTTTAATTTATATAAT" -# input_ids = torch.tensor(tokenizer.text_to_ids(input_seq)).int().unsqueeze(0).to(device) -# position_ids = torch.arange(len(input_seq)).unsqueeze(0).to(device) -# attention_mask = None -# outputs = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) -# gold_standard_no_fp8_tensor = torch.load(gold_standard_no_fp8).to(device=outputs.device, dtype=outputs.dtype) -# top_2_logits_golden = gold_standard_no_fp8_tensor.topk(dim=-1, sorted=True, largest=True, k=2) -# ambiguous_positions = ( -# top_2_logits_golden.values[..., 0] - top_2_logits_golden.values[..., 1] -# ).abs() < 9.9e-3 # hand tunes for observed diffs from A100 and H100 -# n_ambiguous = ambiguous_positions.sum() - -# assert n_ambiguous <= 19 - -# our_char_indices = outputs.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy() -# not_amb_positions = ~ambiguous_positions.flatten().cpu().numpy() -# # Generate our string, removing the ambiguous positions. -# our_generation_str = "".join([chr(idx) for idx in our_char_indices[not_amb_positions].tolist()]) -# # Do the same to the golden values -# gold_std_char_indices = ( -# gold_standard_no_fp8_tensor.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy() -# ) -# # Make the string -# gold_std_str = "".join([chr(idx) for idx in gold_std_char_indices[not_amb_positions].tolist()]) -# array_eq = np.array(list(our_generation_str)) == np.array(list(gold_std_str)) -# # Ensure the two strings are approximately equal. -# if array_eq.mean() < 0.95: -# array_eq = np.array(list(our_generation_str)) == np.array(list(gold_std_str)) -# mismatch_positions = np.arange(outputs.shape[1])[not_amb_positions][~array_eq] -# err_str = f"Fraction of expected mismatch positions exceeds 5%: {(~array_eq).mean()}" -# err_str += f"Mismatch positions: {mismatch_positions}" -# err_str += f"Fraction of unexpected mismatch positions: {(~array_eq).mean()}" -# top_two_logits_at_mismatch = top_2_logits_golden.values[0, mismatch_positions] -# top_2_logits_pred = outputs.topk(dim=-1, sorted=True, largest=True, k=2) -# top_two_pred_logits_at_mismatch = top_2_logits_pred.values[0, mismatch_positions] -# err_str += f"Top two logits at mismatch positions: {top_two_logits_at_mismatch}" -# err_str += f"Top two pred logits at mismatch positions: {top_two_pred_logits_at_mismatch}" -# raise AssertionError(err_str) - -# # Verify that the top-4 from the logit vectors are the same. -# # A: 65 -# # C: 67 -# # G: 71 -# # T: 84 -# # Find the corresponding ATGC and compare the two vectors with those four values. -# # Ensures that the top 4 ascii characters of the output are ACGT. -# top_4_inds = outputs.topk(dim=-1, sorted=False, largest=True, k=4) -# assert set(top_4_inds.indices.flatten().cpu().numpy().tolist()).issubset((65, 67, 71, 84)) -# output_vector = outputs[0, -1, top_4_inds.indices] - -# # Then its the top 4 indices of the gold standard tensor -# top_4_inds_golden = gold_standard_no_fp8_tensor.topk(dim=-1, sorted=False, largest=True, k=4) -# assert set(top_4_inds_golden.indices.flatten().cpu().numpy().tolist()).issubset((65, 67, 71, 84)) -# gold_standard_no_fp8_vector = gold_standard_no_fp8_tensor[0, -1, top_4_inds_golden.indices] - -# # Run cosine similarity between the two vectors. -# logit_similarity = torch.nn.functional.cosine_similarity(output_vector, gold_standard_no_fp8_vector, dim=-1) -# assert torch.mean(torch.abs(logit_similarity - torch.ones_like(logit_similarity))) < 0.03 - - -# @pytest.mark.skip(reason="test fails on main, not due to #1058") -# @pytest.mark.slow -# def test_golden_values_top_k_logits_and_cosine_similarity_7b(seq_len: int = 8_192): -# try: -# evo2_7b_checkpoint_weights: Path = load("evo2/7b-8k:1.0") / "weights" -# gold_standard_no_fp8 = load("evo2/7b-8k-nofp8-te-goldvalue-testdata:1.0") -# except ValueError as e: -# if e.args[0].endswith("does not have an NGC URL."): -# raise ValueError( -# "Please re-run test with `BIONEMO_DATA_SOURCE=pbss py.test ...`, " -# "one or more files are missing from ngc." -# ) -# else: -# raise e -# with distributed_model_parallel_state(), torch.no_grad(): -# hyena_config = llm.Hyena7bConfig(use_te=True, seq_length=seq_len) -# tokenizer = get_nmt_tokenizer( -# "byte-level", -# ) -# raw_megatron_model = hyena_config.configure_model(tokenizer).eval().cuda() -# device = raw_megatron_model.parameters().__next__().device -# load_weights_sharded_inplace_nemo2_to_mcore(raw_megatron_model, evo2_7b_checkpoint_weights, {}, "torch_dist") -# model = Float16Module(hyena_config, raw_megatron_model) -# input_seq = "GAAATTAGCGCGTCCGGAATGATACGAGGGGAAACGAAATTTTGAATTAATGGAGAAAAAAGACGAGAAACCTTAAGCAAAAAAATTTTAGCTTCGAATATTTATTAATTTCTGAGATGTTGTTAAACGATTTTCGATTCCAAGTTGTGCGCACGAACGTTATTGCAAATAAATGCTGCTTATTCGGATGTTTCCACGATCTTTGTTGCAATGGTAGTCGAGTACCCGATAACCCAATTTCGTTACATCGGCCTATCTGTAGAATATCCAATCTATGGTTCATAAAAAATCTGATCGTTTGTTTTTAAGAAATTAAACGCGTTAAATTGAACGAATTTCGAATACCGGTCTTAGCGAAGGACCTCCCCTCTTGCTTGCGTATTGCCCCGCGAAATTTCTTTTCGGCGATGAACGATACAAAAAATTCTATCGAATGTTACTTCTATTCTCTGCCTCGTCTATGACTTGGAGATTGGTCTATGTCGTTCGTTTTCTCGCGAGTTTCCAATATGTCCGTAGTATGTGAACGCTGGTATTCGTGAAGATAAATTATTGTTTTTACAATTTCTTTCAAAAATATATAATTTTAATTTATATAAT" -# input_ids = torch.tensor(tokenizer.text_to_ids(input_seq)).int().unsqueeze(0).to(device) -# position_ids = torch.arange(len(input_seq)).unsqueeze(0).to(device) -# attention_mask = None -# outputs = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask) -# gold_standard_no_fp8_tensor = torch.load(gold_standard_no_fp8).to(device=outputs.device, dtype=outputs.dtype) -# is_fp8_supported, compute_capability, device_info = check_fp8_support(device.index) - -# if is_fp8_supported and compute_capability == "9.0": -# # Most rigurous assertion for output equivalence currently works on devices that are new enough to -# # support FP8. -# logger.info( -# f"Device {device_info} ({compute_capability}) supports FP8 with 9.0 compute capability, the " -# "same configuration as the gold standard was generated with. Running most rigurous assertion." -# ) -# torch.testing.assert_close(outputs, gold_standard_no_fp8_tensor) -# else: -# logger.info( -# f"Device {device_info} ({compute_capability}) does not support FP8. Running less rigurous assertions." -# ) -# top_2_logits_golden = gold_standard_no_fp8_tensor.topk(dim=-1, sorted=True, largest=True, k=2) -# ambiguous_positions = ( -# top_2_logits_golden.values[..., 0] - top_2_logits_golden.values[..., 1] -# ).abs() < 9.9e-3 # hand tunes for observed diffs from A100 and H100 with 7b model -# n_ambiguous = ambiguous_positions.sum() - -# assert n_ambiguous <= 19 - -# our_char_indices = outputs.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy() -# not_amb_positions = ~ambiguous_positions.flatten().cpu().numpy() -# # Generate our string, removing the ambiguous positions. -# our_generation_str = "".join([chr(idx) for idx in our_char_indices[not_amb_positions].tolist()]) -# # Do the same to the golden values -# gold_std_char_indices = ( -# gold_standard_no_fp8_tensor.softmax(dim=-1).argmax(dim=-1).flatten().detach().cpu().numpy() -# ) -# # Make the string -# gold_std_str = "".join([chr(idx) for idx in gold_std_char_indices[not_amb_positions].tolist()]) - -# # Ensure the two strings are equal. -# assert all(np.array(list(our_generation_str)) == np.array(list(gold_std_str))) - -# # Verify that the top-4 from the logit vectors are the same. -# # A: 65 -# # C: 67 -# # G: 71 -# # T: 84 -# # Find the corresponding ATGC and compare the two vectors with those four values. -# # Ensures that the top 4 ascii characters of the output are ACGT. -# top_4_inds = outputs.topk(dim=-1, sorted=False, largest=True, k=4) -# assert set(top_4_inds.indices.flatten().cpu().numpy().tolist()).issubset((65, 67, 71, 84)) -# output_vector = outputs[0, -1, top_4_inds.indices] - -# # Then its the top 4 indices of the gold standard tensor -# top_4_inds_golden = gold_standard_no_fp8_tensor.topk(dim=-1, sorted=False, largest=True, k=4) -# assert set(top_4_inds_golden.indices.flatten().cpu().numpy().tolist()).issubset((65, 67, 71, 84)) -# gold_standard_no_fp8_vector = gold_standard_no_fp8_tensor[0, -1, top_4_inds_golden.indices] - -# # Run cosine similarity between the two vectors. -# logit_similarity = torch.nn.functional.cosine_similarity(output_vector, gold_standard_no_fp8_vector, dim=-1) -# assert torch.mean(torch.abs(logit_similarity - torch.ones_like(logit_similarity))) < 9.9e-3 - - @pytest.fixture def sequences(): """Fixture that returns a list of sequences from the prompts.csv file.""" @@ -559,86 +338,17 @@ def sequences(): return [row["Sequence"] for row in reader] -# @pytest.fixture -# def coding_sequences(): -# with (Path(__file__).parent / "data" / "cds_prompts.csv").open(newline="") as f: -# from csv import DictReader - -# reader = DictReader(f) -# return [row["Sequence"] for row in reader] - - -# def get_trainer(pipeline_parallel=1): -# import nemo.lightning as nl - -# fp8 = True -# full_fp8 = False -# return nl.Trainer( -# accelerator="gpu", -# devices=pipeline_parallel, -# strategy=nl.MegatronStrategy( -# tensor_model_parallel_size=1, -# pipeline_model_parallel_size=pipeline_parallel, -# context_parallel_size=1, -# pipeline_dtype=torch.bfloat16, -# ckpt_load_optimizer=False, -# ckpt_save_optimizer=False, -# ckpt_async_save=False, -# save_ckpt_format="torch_dist", -# ckpt_load_strictness="log_all", -# ), -# log_every_n_steps=1, -# limit_val_batches=10, -# num_sanity_val_steps=0, -# plugins=nl.MegatronMixedPrecision( -# precision="bf16-mixed", -# params_dtype=torch.bfloat16, -# # Only use FP8 in this plugin when using full FP8 precision and FP8. -# # Otherwise use vortex_style_fp8 in the model config. -# fp8="hybrid" if fp8 and full_fp8 else None, -# fp8_amax_history_len=16 if fp8 and full_fp8 else 1, -# fp8_amax_compute_algo="max" if fp8 and full_fp8 else "most_recent", -# ), -# ) - - -# # here: pass arg through to inference_batch_times_seqlen_threshold and inference_max_seq_length -# def get_model_and_tokenizer_raw(ckpt_dir_or_name: Path | str, seq_len_max: int = 8192, **kwargs): -# """ -# Load a model and tokenizer from a checkpoint directory or name. If you supply a Path argument then we assume that -# the path is already a checkpoint directory, otherwise we load the checkpoint from NGC or PBSS depending on -# the environment variable BIONEMO_DATA_SOURCE. -# """ -# trainer = get_trainer() -# from bionemo.core.data.load import load - -# if isinstance(ckpt_dir_or_name, Path): -# ckpt_dir: Path = ckpt_dir_or_name -# else: -# ckpt_dir: Path = load(ckpt_dir_or_name) -# from nemo.collections.llm import inference - -# inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer( -# path=ckpt_dir, -# trainer=trainer, -# params_dtype=torch.bfloat16, -# inference_batch_times_seqlen_threshold=seq_len_max, -# inference_max_seq_length=seq_len_max, -# recompute_granularity=None, -# recompute_num_layers=None, -# recompute_method=None, -# **kwargs, -# ) -# return inference_wrapped_model, mcore_tokenizer - - -# def get_model_and_tokenizer(ckpt_name, vortex_style_fp8=False, seq_len_max: int = 8192, **kwargs): -# return get_model_and_tokenizer_raw(ckpt_name, vortex_style_fp8=vortex_style_fp8, seq_len_max=seq_len_max, **kwargs) - - -# def get_model_and_tokenizer_ignore_vortex(ckpt_name, vortex_style_fp8=False, seq_len_max: int = 8192, **kwargs): -# # Capture and remove the vortex_style_fp8 argument for mamba models. -# return get_model_and_tokenizer_raw(ckpt_name, seq_len_max=seq_len_max, **kwargs) +@pytest.fixture +def coding_sequences(): + """Fixture that returns coding sequences from the cds_prompts.csv file.""" + cds_file = Path(__file__).parent / "data" / "cds_prompts.csv" + if not cds_file.exists(): + pytest.skip(f"CDS prompts file not found: {cds_file}") + with cds_file.open(newline="") as f: + from csv import DictReader + + reader = DictReader(f) + return [row["Sequence"] for row in reader] def _calc_matchrate(*, tokenizer, in_seq, logits): @@ -672,10 +382,28 @@ def _check_matchrate(*, ckpt_name, matchrate, assert_matchrate=True): "ckpt_name,expected_matchpercents,flash_decode", [ # Try flash decode with one and not the other to verify that both paths work. - ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True), - ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30], False), - ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57], False), - ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57], False), + pytest.param("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True, id="1b-8k-bf16"), + pytest.param( + "evo2/1b-8k:1.0", + [96.27, 67.93, 77.50, 80.30], + False, + id="1b-8k", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-8k:1.0", + [97.60, 89.63, 80.03, 84.57], + False, + id="7b-8k", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-1m:1.0", + [97.60, 89.63, 80.03, 84.57], + False, + id="7b-1m", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), ], ) def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float], flash_decode: bool): @@ -776,10 +504,28 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc "ckpt_name,expected_matchpercents,flash_decode", [ # Try flash decode with one and not the other to verify that both paths work. - ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True), - ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30], False), - ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57], False), - ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57], False), + pytest.param("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True, id="1b-8k-bf16"), + pytest.param( + "evo2/1b-8k:1.0", + [96.27, 67.93, 77.50, 80.30], + False, + id="1b-8k", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-8k:1.0", + [97.60, 89.63, 80.03, 84.57], + False, + id="7b-8k", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-1m:1.0", + [97.60, 89.63, 80.03, 84.57], + False, + id="7b-1m", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), ], ) def test_forward_ckpt_conversion( @@ -868,275 +614,311 @@ def test_forward_ckpt_conversion( ) -# def mid_point_split(*, seq, num_tokens: int | None = None, fraction: float = 0.5): -# mid_point = int(fraction * len(seq)) -# prompt = seq[:mid_point] -# if num_tokens is not None: -# target = seq[mid_point : mid_point + num_tokens] # Only compare to the section of sequence directly -# else: -# target = seq[mid_point:] -# return prompt, target - - -# def calculate_sequence_identity(seq1: str, seq2: str) -> float | None: -# """Calculate sequence identity between two sequences through direct comparison.""" -# if not seq1 or not seq2: -# return None - -# # Direct comparison of sequences -# min_length = min(len(seq1), len(seq2)) -# matches = sum(a == b for a, b in zip(seq1[:min_length], seq2[:min_length])) - -# return (matches / min_length) * 100 - - -# @pytest.mark.parametrize( -# "ckpt_name,model_tokenizer_provider,expected_matchpercents", -# [ -# ("evo2/1b-8k-bf16:1.0", get_model_and_tokenizer, [96.8, 29.7, 76.6, 71.6]), -# ("evo2/1b-8k:1.0", get_model_and_tokenizer, [96.8, 29.7, 76.6, 71.6]), -# ("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, [99.2, 51.0, 73.0, 82.6]), -# ("evo2/7b-8k:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]), -# ("evo2/7b-1m:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]), -# ], -# ) -# def test_batch_generate( -# sequences: list[str], ckpt_name: str, model_tokenizer_provider: Callable, expected_matchpercents: list[float] -# ): -# assert len(sequences) > 0 -# _ = determine_memory_requirement_and_skip_if_not_met(ckpt_name, test_name=inspect.currentframe().f_code.co_name) - -# is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) -# skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported -# if skip: -# # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device. -# pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})") -# if "evo2_mamba" in ckpt_name and os.environ.get("BIONEMO_DATA_SOURCE") != "pbss": -# # TODO: add evo2_mamba/7b-8k to NGC and remove this skip -# pytest.skip(f"Skipping {ckpt_name} because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`.") -# # only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support -# vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name - -# num_tokens = 500 -# seq_prompts = [mid_point_split(seq=seq, num_tokens=num_tokens) for seq in sequences] -# seq_len_max = num_tokens + max([len(sq[0]) for sq in seq_prompts]) -# inference_wrapped_model, mcore_tokenizer = model_tokenizer_provider( -# ckpt_name, -# vortex_style_fp8=vortex_style_fp8, -# seq_len_max=seq_len_max, -# ) - -# results = generate( -# model=inference_wrapped_model, -# max_batch_size=1, # vortex only supports batch size 1 -# tokenizer=mcore_tokenizer, -# prompts=[sq[0] for sq in seq_prompts], -# random_seed=42, -# inference_params=CommonInferenceParams( -# temperature=1.0, -# top_k=1, -# top_p=0.0, -# return_log_probs=False, -# num_tokens_to_generate=num_tokens, -# ), -# ) - -# match_percents = [] -# for i, (result, (prompt, target)) in enumerate(zip(results, seq_prompts)): -# gen_seq = result.generated_text -# logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {gen_seq=}") -# logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {target=}") -# match_percent = calculate_sequence_identity(target, gen_seq) -# logging.info( -# f"{ckpt_name} {torch.distributed.get_rank()=} {match_percent=} expected: {expected_matchpercents[i]}" -# ) -# match_percents.append(match_percent) - -# assert len(match_percents) == len(expected_matchpercents) -# matchperc_print = [f"{mp:.1f}%" for mp in match_percents] -# matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents] -# assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), ( -# f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}" -# ) - - -# @pytest.mark.parametrize( -# "ckpt_name,model_tokenizer_provider,expected_matchpercents", -# [ -# ("evo2/1b-8k-bf16:1.0", get_model_and_tokenizer, [86.4, 78.8, 49.7]), -# ("evo2/1b-8k:1.0", get_model_and_tokenizer, [86.4, 78.8, 49.7]), -# ("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, [86.5, 88.4, 88.2]), -# ("evo2/7b-8k:1.0", get_model_and_tokenizer, [88.8, 88.5, 82.2]), -# ("evo2/7b-1m:1.0", get_model_and_tokenizer, [88.8, 88.5, 82.2]), -# ], -# ) -# def test_batch_generate_coding_sequences( -# coding_sequences: list[str], -# ckpt_name: str, -# model_tokenizer_provider: Callable, -# expected_matchpercents: list[float], -# ): -# assert len(coding_sequences) > 0 -# determine_memory_requirement_and_skip_if_not_met(ckpt_name, test_name=inspect.currentframe().f_code.co_name) - -# is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) -# skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported -# if skip: -# # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device. -# pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})") -# if "evo2_mamba" in ckpt_name and os.environ.get("BIONEMO_DATA_SOURCE") != "pbss": -# # TODO: add evo2_mamba/7b-8k to NGC and remove this skip -# pytest.skip(f"Skipping {ckpt_name} because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`.") -# # only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support -# vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name - -# match_percents: list[float] = [] -# cds_lengths: list[int | None] = [] -# original_cds_lengths: list[int] = [len(seq) for seq in coding_sequences] -# seq_prompts = [mid_point_split(seq=seq, num_tokens=None, fraction=0.3) for seq in coding_sequences] -# num_tokens = max(len(sq[1]) for sq in seq_prompts) + 15 - -# inference_wrapped_model, mcore_tokenizer = model_tokenizer_provider( -# ckpt_name, vortex_style_fp8=vortex_style_fp8, enable_flash_decode=True, flash_decode=True -# ) - -# _ = generate( -# model=inference_wrapped_model, -# max_batch_size=1, # vortex only supports batch size 1 -# tokenizer=mcore_tokenizer, -# prompts=["AAACCC"], -# random_seed=42, -# inference_params=CommonInferenceParams( -# temperature=1.0, -# top_k=1, -# top_p=0.0, -# return_log_probs=False, -# num_tokens_to_generate=1, -# ), -# ) -# results = generate( -# model=inference_wrapped_model, -# max_batch_size=1, # vortex only supports batch size 1 -# tokenizer=mcore_tokenizer, -# prompts=[sq[0] for sq in seq_prompts], -# random_seed=42, -# inference_params=CommonInferenceParams( -# temperature=1.0, -# top_k=1, -# top_p=0.0, -# return_log_probs=False, -# num_tokens_to_generate=num_tokens, -# ), -# ) - -# for i, (result, (prompt, target)) in enumerate(zip(results, seq_prompts)): -# gen_seq = result.generated_text -# logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {gen_seq=}") -# logging.info(f"{ckpt_name} {torch.distributed.get_rank()=} {target=}") -# full_seq = prompt + gen_seq -# stop_codons = {"TAA", "TAG", "TGA"} -# assert full_seq[:3] == "ATG" # start codon -# cds_length = None -# for codon_start in range(0, len(full_seq), 3): -# codon = full_seq[codon_start : codon_start + 3] -# if codon in stop_codons: -# cds_length = codon_start + 3 -# break -# match_percent = calculate_sequence_identity(target, gen_seq) -# logging.info( -# f"{ckpt_name} {torch.distributed.get_rank()=} {match_percent=} expected: {expected_matchpercents[i]}" -# ) -# match_percents.append(match_percent) -# cds_lengths.append(cds_length) -# # 99% of the time, you have a stop codon within the first 96 codons if everything were random. - -# assert len(match_percents) == len(expected_matchpercents) -# assert len(cds_lengths) == len(original_cds_lengths) -# matchperc_print = [f"{mp:.1f}%" for mp in match_percents] -# matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents] -# # By chance you expect to have a stop codon within the first 96 codons if everything were random -# # so verify that we are putting the first stop codon after this point, as well as it being at least 90% of the -# # original sequence length. -# assert all( -# pcl is None or ((pcl - len(pmpt) > 96 * 3 or len(tgt) < 96 * 3) and pcl >= 0.9 * ocl) -# for pcl, ocl, (pmpt, tgt) in zip(cds_lengths, original_cds_lengths, seq_prompts) -# ), f"Expected at least 70% of {original_cds_lengths=}, got {cds_lengths=}" -# assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), ( -# f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}" -# ) - - -# @pytest.mark.skip( -# reason="skip the test for now, and decide what to do after getting Anton's changes sorted and merged." -# ) -# @pytest.mark.slow -# @pytest.mark.parametrize( -# "ckpt_name,model_tokenizer_provider,expected_tokens_sec", -# [ -# ("evo2/1b-8k-bf16:1.0", get_model_and_tokenizer, 41.0), -# ("evo2/1b-8k:1.0", get_model_and_tokenizer, 41.0), -# ("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, 39.73), -# ("evo2/7b-8k:1.0", get_model_and_tokenizer, 32.0), -# ("evo2/7b-1m:1.0", get_model_and_tokenizer, 32.0), -# ], -# ) -# def test_generate_speed( -# ckpt_name: str, -# model_tokenizer_provider: Callable, -# expected_tokens_sec: float, -# ): -# is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) -# determine_memory_requirement_and_skip_if_not_met(ckpt_name, test_name=inspect.currentframe().f_code.co_name) - -# skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported -# if skip: -# # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device. -# pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})") -# if "evo2_mamba" in ckpt_name and os.environ.get("BIONEMO_DATA_SOURCE") != "pbss": -# # TODO: add evo2_mamba/7b-8k to NGC and remove this skip -# pytest.skip(f"Skipping {ckpt_name} because it is not on NGC yet. Run with `BIONEMO_DATA_SOURCE=pbss`.") -# # only use vortex_style_fp8 for non-bf16 checkpoints with fp8 support -# vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name -# inference_wrapped_model, mcore_tokenizer = model_tokenizer_provider( -# ckpt_name, -# vortex_style_fp8=vortex_style_fp8, -# fp32_residual_connection=False, -# enable_flash_decode=True, -# flash_decode=True, -# ) - -# # warm up the model with a single call before timing. This should take care of compilation etc. -# _ = generate( -# model=inference_wrapped_model, -# max_batch_size=1, # vortex only supports batch size 1 -# tokenizer=mcore_tokenizer, -# prompts=["AAACCC"], -# random_seed=42, -# inference_params=CommonInferenceParams( -# temperature=1.0, -# top_k=1, -# top_p=0.0, -# return_log_probs=False, -# num_tokens_to_generate=1, -# ), -# ) -# t0 = time.perf_counter_ns() -# results = generate( -# model=inference_wrapped_model, -# max_batch_size=1, # vortex only supports batch size 1 -# tokenizer=mcore_tokenizer, -# prompts=["A"], -# random_seed=42, -# inference_params=CommonInferenceParams( -# temperature=1.0, -# top_k=1, -# top_p=0.0, -# return_log_probs=False, -# num_tokens_to_generate=300, -# ), -# ) -# dt = (time.perf_counter_ns() - t0) / 1e9 # seconds -# tokens_per_sec = (len(results[0].generated_text) + 1) / dt # +1 for the prompt -# assert tokens_per_sec > expected_tokens_sec * 0.85, ( -# f"Expected at least {expected_tokens_sec} tokens/sec, got {tokens_per_sec}" -# ) +def mid_point_split(*, seq, num_tokens: int | None = None, fraction: float = 0.5): + """Split a sequence at a midpoint for prompt/target evaluation.""" + mid_point = int(fraction * len(seq)) + prompt = seq[:mid_point] + if num_tokens is not None: + target = seq[mid_point : mid_point + num_tokens] # Only compare to the section of sequence directly + else: + target = seq[mid_point:] + return prompt, target + + +def calculate_sequence_identity(seq1: str, seq2: str) -> float | None: + """Calculate sequence identity between two sequences through direct comparison.""" + if not seq1 or not seq2: + return None + + # Direct comparison of sequences + min_length = min(len(seq1), len(seq2)) + matches = sum(a == b for a, b in zip(seq1[:min_length], seq2[:min_length])) + + return (matches / min_length) * 100 + + +@pytest.mark.timeout(900) +@pytest.mark.slow +@pytest.mark.parametrize( + "ckpt_name,expected_matchpercents,fp8", + [ + pytest.param("evo2/1b-8k-bf16:1.0", [86.4, 78.8, 49.7], False, id="1b-bf16_bf16"), + pytest.param("evo2/1b-8k-bf16:1.0", [86.4, 78.8, 49.7], True, id="1b-bf16_fp8"), + pytest.param( + "evo2/1b-8k:1.0", + [86.4, 78.8, 49.7], + True, + id="1b_fp8", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-8k:1.0", + [88.8, 88.5, 82.2], + False, + id="7b-8k_bf16", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-1m:1.0", + [88.8, 88.5, 82.2], + False, + id="7b-1m_bf16", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + ], +) +def test_batch_generate_coding_sequences( + coding_sequences: list[str], + tmp_path: Path, + ckpt_name: str, + expected_matchpercents: list[float], + fp8: bool, +): + """Test generation on coding sequences using MCore inference infrastructure. + + This test validates that the model can generate reasonable coding sequence + continuations, checking for proper stop codon placement and sequence identity. + """ + from bionemo.evo2.run.infer import generate, setup_inference_engine + + assert len(coding_sequences) > 0 + + # Check memory availability + try: + _ = determine_memory_requirement_and_skip_if_not_met( + ckpt_name, test_name="test_batch_generate_coding_sequences" + ) + except KeyError: + gb_available = torch.cuda.mem_get_info()[0] / 1024**3 + if gb_available < 16: + pytest.skip(f"Insufficient GPU memory: {gb_available:.1f}GB available, need at least 16GB") + + is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) + if fp8 and not is_fp8_supported: + pytest.skip(f"Skipping {ckpt_name} - FP8 not supported on {device_info} ({compute_capability})") + + # Use bf16 checkpoint to avoid FP8 issues with single-token generation + if "bf16" not in ckpt_name and not fp8: + pytest.skip(f"Skipping {ckpt_name} - use bf16 checkpoint or enable FP8 for this test") + + # Prepare prompts and targets + seq_prompts = [mid_point_split(seq=seq, num_tokens=None, fraction=0.3) for seq in coding_sequences] + num_tokens = max(len(sq[1]) for sq in seq_prompts) + 15 + original_cds_lengths: list[int] = [len(seq) for seq in coding_sequences] + + vortex_style_fp8 = ckpt_name == "evo2/1b-8k:1.0" and fp8 + mixed_precision_recipe = "bf16_with_fp8_current_scaling_mixed" if fp8 and not vortex_style_fp8 else "bf16_mixed" + + with distributed_model_parallel_state(), torch.no_grad(): + # Convert checkpoint to MBridge format + nemo2_ckpt_path = load(ckpt_name) + mbridge_ckpt_dir = run_nemo2_to_mbridge( + nemo2_ckpt_dir=nemo2_ckpt_path, + tokenizer_path=DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + mbridge_ckpt_dir=tmp_path / "mbridge_checkpoint", + model_size="1b" if "1b" in ckpt_name else "7b_arc_longcontext" if "7b-1m" in ckpt_name else "7b", + seq_length=8192, + mixed_precision_recipe=mixed_precision_recipe, + vortex_style_fp8=vortex_style_fp8, + ) + mbridge_ckpt_path = mbridge_ckpt_dir / "iter_0000001" + + # Extract prompts for generation + prompts = [split[0] for split in seq_prompts] + + # Setup MCore inference engine with batch size matching number of prompts + batch_size = len(prompts) // 2 + components = setup_inference_engine( + ckpt_dir=mbridge_ckpt_path, + max_seq_length=8192, + max_batch_size=batch_size, + tensor_parallel_size=1, + random_seed=42, + ) + + # Generate all sequences - engine handles iteration internally + results = generate( + components, + prompts=prompts, + max_new_tokens=num_tokens, + temperature=1.0, + top_k=1, # Greedy for determinism + ) + + # Process results + match_percents: list[float] = [] + cds_lengths: list[int | None] = [] + stop_codons = {"TAA", "TAG", "TGA"} + + for i, (result, (prompt, target)) in enumerate(zip(results, seq_prompts)): + gen_seq = result.generated_text if result else "" + logger.info(f"{ckpt_name} {gen_seq=}") + logger.info(f"{ckpt_name} {target=}") + + full_seq = prompt + gen_seq + assert full_seq[:3] == "ATG", f"Expected start codon ATG, got {full_seq[:3]}" + + # Find first stop codon + cds_length = None + for codon_start in range(0, len(full_seq), 3): + codon = full_seq[codon_start : codon_start + 3] + if codon in stop_codons: + cds_length = codon_start + 3 + break + if cds_length is None: + logger.warning(f"{ckpt_name} {gen_seq=} no stop codon found") + cds_length = len(full_seq) + match_percent: float = calculate_sequence_identity(target, gen_seq) or 0.0 + logger.info(f"{ckpt_name} {match_percent=} expected: {expected_matchpercents[i]}") + match_percents.append(match_percent) + cds_lengths.append(cds_length) + + # Verify results + assert len(match_percents) == len(expected_matchpercents) + assert len(cds_lengths) == len(original_cds_lengths) + matchperc_print = [f"{mp:.1f}%" for mp in match_percents] + matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents] + + # By chance you expect to have a stop codon within the first 96 codons if everything were random + # so verify that we are putting the first stop codon after this point, as well as it being at least 90% of the + # original sequence length. + assert all( + pcl is None or ((pcl - len(pmpt) > 96 * 3 or len(tgt) < 96 * 3) and pcl >= 0.90 * ocl) + for pcl, ocl, (pmpt, tgt) in zip(cds_lengths, original_cds_lengths, seq_prompts) + ), f"Expected at least 90% of {original_cds_lengths=}, got {cds_lengths=}" + + assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), ( + f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}" + ) + + +# ============================================================================= +# MBridge-based generation tests using HyenaInferenceContext +# ============================================================================= + + +@pytest.mark.timeout(900) +@pytest.mark.slow +@pytest.mark.parametrize( + "ckpt_name,expected_matchpercents,fp8", + [ + pytest.param("evo2/1b-8k-bf16:1.0", [96.8, 29.7, 76.6, 71.6], False, id="1b-bf16_bf16"), + pytest.param("evo2/1b-8k-bf16:1.0", [96.8, 29.7, 76.6, 71.6], True, id="1b-bf16_fp8"), + pytest.param( + "evo2/1b-8k:1.0", + [96.8, 29.7, 76.6, 71.6], + True, + id="1b_fp8", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-8k:1.0", + [97.60, 89.63, 80.03, 84.57], + True, + id="7b-8k_fp8", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + pytest.param( + "evo2/7b-1m:1.0", + [97.60, 89.63, 80.03, 84.57], + False, + id="7b-1m_bf16", + marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to disk space"), + ), + ], +) +def test_batch_generate_mbridge( + sequences: list[str], + tmp_path: Path, + ckpt_name: str, + expected_matchpercents: list[float], + fp8: bool, +): + """Test autoregressive generation using MCore inference infrastructure. + + This test validates that the model can generate reasonable continuations + of DNA sequences using the StaticInferenceEngine and TextGenerationController. + + Note: Hyena/Evo2 SSM state caching currently only supports batch size 1, + so prompts are processed sequentially. The MCore inference engine handles + this internally through legacy mode. + + Uses the same expected values as the original NeMo test_batch_generate. + """ + from bionemo.evo2.run.infer import generate, setup_inference_engine + + assert len(sequences) > 0 + + # Check memory availability (use test_batch_generate requirements as proxy) + try: + _ = determine_memory_requirement_and_skip_if_not_met(ckpt_name, test_name="test_batch_generate_mbridge") + except KeyError: + # If no entry exists, check basic memory availability + gb_available = torch.cuda.mem_get_info()[0] / 1024**3 + if gb_available < 16: + pytest.skip(f"Insufficient GPU memory: {gb_available:.1f}GB available, need at least 16GB") + + is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) + if fp8 and not is_fp8_supported: + pytest.skip(f"Skipping {ckpt_name} - FP8 not supported on {device_info} ({compute_capability})") + + num_tokens_to_generate = 500 # Match original test + vortex_style_fp8 = ckpt_name == "evo2/1b-8k:1.0" and fp8 + mixed_precision_recipe = "bf16_with_fp8_current_scaling_mixed" if fp8 and not vortex_style_fp8 else "bf16_mixed" + + with distributed_model_parallel_state(), torch.no_grad(): + # Convert checkpoint to MBridge format + nemo2_ckpt_path = load(ckpt_name) + mbridge_ckpt_dir = run_nemo2_to_mbridge( + nemo2_ckpt_dir=nemo2_ckpt_path, + tokenizer_path=DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + mbridge_ckpt_dir=tmp_path / "mbridge_checkpoint", + model_size="1b" if "1b" in ckpt_name else "7b_arc_longcontext" if "7b-1m" in ckpt_name else "7b", + seq_length=8192, + mixed_precision_recipe=mixed_precision_recipe, + vortex_style_fp8=vortex_style_fp8, + ) + mbridge_ckpt_path = mbridge_ckpt_dir / "iter_0000001" + + # Split all sequences at midpoint to get prompts and targets + seq_splits = [mid_point_split(seq=seq, num_tokens=num_tokens_to_generate, fraction=0.5) for seq in sequences] + prompts = [split[0] for split in seq_splits] + targets = [split[1] for split in seq_splits] + + # Setup MCore inference engine + # Note: max_batch_size=1 due to Hyena SSM state constraints, but engine handles iteration + components = setup_inference_engine( + ckpt_dir=mbridge_ckpt_path, + max_seq_length=8192, + max_batch_size=1, # 1 because this test takes more memory. + tensor_parallel_size=1, + random_seed=42, + ) + + # Generate all sequences - engine handles iteration internally with max_batch_size=1 + results = generate( + components, + prompts=prompts, + max_new_tokens=num_tokens_to_generate, + temperature=1.0, + top_k=1, # Greedy for determinism + ) + + # Calculate match percentages for each result + match_percents: list[float] = [] + for i, (result, target) in enumerate(zip(results, targets)): + generated_text = result.generated_text if result else "" + match_percent = calculate_sequence_identity(target, generated_text) + if match_percent is not None: + match_percents.append(match_percent) + logger.info( + f"{ckpt_name} seq[{i}] identity: {match_percent:.1f}% expected: {expected_matchpercents[i]:.1f}%" + ) + + # Use original assertion style - expect at least 90% of expected values + assert len(match_percents) == len(expected_matchpercents) + matchperc_print = [f"{mp:.1f}%" for mp in match_percents] + matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents] + assert all(mp >= 0.90 * ep for mp, ep in zip(match_percents, expected_matchpercents)), ( + f"Expected at least 90% of {matchperc_print_expected=}, got {matchperc_print=}" + ) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_prompt.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_prompt.py deleted file mode 100644 index c1bd4a4b51..0000000000 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_prompt.py +++ /dev/null @@ -1,140 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Arc Institute. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Michael Poli. All rights reserved. -# SPDX-FileCopyrightText: Copyright (c) 2024 Stanford University. All rights reserved -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# FIXME bring back these tests -# from dataclasses import dataclass - -# import pytest -# from bionemo.core.data.load import load -# from bionemo.evo2.run.infer import infer -# from megatron.core.inference.inference_request import InferenceRequest - - -# RANDOM_SEED = 42 -# MAX_NEW_TOKENS = 500 -# TEMPERATURE = 1.0 -# TOP_K = 0 -# TOP_P = 0.0 - -# # todo: figure out 1M checkpoints (or add to NGC) -# CHECKPOINT_NAMES = [ -# "evo2/1b-8k-bf16:1.0", -# # "evo2/7b-8k:1.0", -# # "evo2/7b-1m:1.0", -# ] - - -# PROMPT_1 = "GAATAGGAACAGCTCCGGTCTACAGCTCCCAGCGTGAGCGACGCAGAAGACGGTGATTTCTGCATTTCCATCTGAGGTACCGGGTTCATCTCACTAGGGAGTGCCAGACAGTGGGCGCAGGCCAGTGTGTGTGCGCACCGTGCGCGAGCCGAAGCAGGG" - -# PROMPT_2 = "GATCACAGGTCTATCACCCTATTAACCACTCACGGGAGCTCTCCATGCATTTGGTATTTTCGTCTGGGGGGTATGCACGCGATAGCATTGCGAGACGCTGGAGCCGGAGCACCCTATGTCGCAGTATCTGTCTTTGATTCCTGCCTCATCCTATTATTT" - - -# @dataclass -# class InferCofig: -# """Configuration for model inference parameters.""" - -# temperature: float = TEMPERATURE -# top_k: int = TOP_K -# top_p: float = TOP_P -# tensor_parallel_size: int = 1 -# pipeline_model_parallel_size: int = 1 -# context_parallel_size: int = 1 -# max_new_tokens: int = MAX_NEW_TOKENS -# ckpt_format: str = "torch_dist" -# seed: int = RANDOM_SEED -# flash_decode: bool = False - - -# _checkpoint_cache = {} - - -# @pytest.fixture(scope="session") -# def load_checkpoint(): -# """Factory function that returns a checkpoint loader with caching.""" - -# def _load_checkpoint(ckpt_name: str) -> str: -# if ckpt_name not in _checkpoint_cache: -# _checkpoint_cache[ckpt_name] = load(ckpt_name) -# return _checkpoint_cache[ckpt_name] - -# return _load_checkpoint - - -# def percent_equal_tokens(response1: list[InferenceRequest], response2: list[InferenceRequest]) -> float: -# """Percent of tokens that are equal between two responses.""" -# num_equal = [i == j for i, j in zip(response1[0].generated_tokens, response2[0].generated_tokens)] -# return sum(num_equal) / len(num_equal) - - -# # just a DRY wrapper for the infer function -# def run_inference(prompt: str, checkpoint_path: str, config: InferCofig) -> list[InferenceRequest]: -# """Run model inference with given parameters. - -# Args: -# prompt: Input prompt for the model -# checkpoint_path: Path to model checkpoint -# config: Inference configuration parameters - -# Returns: -# Model response -# """ -# return infer( -# prompt=prompt, -# ckpt_dir=checkpoint_path, -# temperature=config.temperature, -# top_k=config.top_k, -# top_p=config.top_p, -# max_new_tokens=config.max_new_tokens, -# tensor_parallel_size=config.tensor_parallel_size, -# pipeline_model_parallel_size=config.pipeline_model_parallel_size, -# context_parallel_size=config.context_parallel_size, -# output_file=None, -# ckpt_format=config.ckpt_format, -# seed=config.seed, -# flash_decode=config.flash_decode, -# ) - - -# @pytest.mark.parametrize("ckpt_name", CHECKPOINT_NAMES) -# def test_identical_prompts_should_be_identical(load_checkpoint, ckpt_name): -# """Test that identical prompts produce identical sequences for temperature 1.0.""" -# checkpoint_path = load_checkpoint(ckpt_name) - -# # with clean_parallel_state_context(): -# response_prompt1 = run_inference(PROMPT_1, checkpoint_path, InferCofig()) -# response_prompt2 = run_inference(PROMPT_1, checkpoint_path, InferCofig()) - -# sequence_similarity = percent_equal_tokens(response_prompt1, response_prompt2) -# print(f"sequence similarity {ckpt_name} identical prompts: {sequence_similarity}") -# assert sequence_similarity == 1.0 - - -# @pytest.mark.parametrize("ckpt_name", CHECKPOINT_NAMES) -# def test_different_prompts_too_similar(load_checkpoint, ckpt_name): -# """Test that different prompts for the same sequence are too similar. -# That is, different prompts should produce more varied sequences. -# """ -# checkpoint_path = load_checkpoint(ckpt_name) - -# similarity_threshold = 0.9 - -# # with clean_parallel_state_context(): -# response_prompt1 = run_inference(PROMPT_1, checkpoint_path, InferCofig()) -# response_prompt2 = run_inference(PROMPT_2, checkpoint_path, InferCofig()) -# sequence_similarity = percent_equal_tokens(response_prompt1, response_prompt2) -# assert sequence_similarity <= similarity_threshold diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_stop_and_go.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_stop_and_go.py index e58eda0928..c13fa2f43d 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_stop_and_go.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_stop_and_go.py @@ -26,59 +26,13 @@ from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH +from .utils import find_free_network_port, is_fp4_supported, is_fp8_supported, is_mxfp8_supported + # Do this at collection time before we run any tests. PRETEST_ENV = copy.deepcopy(os.environ) -def get_compute_capability() -> tuple[int, int]: - """Get the compute capability of the current device.""" - if not torch.cuda.is_available(): - return (0, 0) - # Returns a tuple, e.g., (9, 0) for H100 - return torch.cuda.get_device_capability() - - -# 1. FP8 Support Logic -# Supported on Ada Lovelace (8.9) and Hopper (9.0+) -def is_fp8_supported() -> bool: - """Check if FP8 is supported on the current device.""" - cc = get_compute_capability() - return cc >= (8, 9) - - -# 2. FP4 Support Logic -# Native support requires Blackwell (10.0+) -def is_fp4_supported() -> bool: - """Check if FP4 is supported on the current device.""" - cc = get_compute_capability() - return (10, 0) <= cc < (12, 0) - - -# 3. MXFP8 Support Logic -# Native support requires Blackwell (10.0+) -def is_mxfp8_supported() -> bool: - """Check if MXFP8 is supported on the current device.""" - cc = get_compute_capability() - return (10, 0) <= cc < (12, 0) - - -def find_free_network_port() -> int: - """Finds a free port on localhost. - - It is useful in single-node training when we don't want to connect to a real master node but - have to set the `MASTER_PORT` environment variable. - """ - import socket - - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - s.listen(1) - port = s.getsockname()[1] - s.close() - return port - - @pytest.mark.parametrize( "tp_size,cp_size,dp_size,dp_rank_check,precision_recipe", [ @@ -86,7 +40,6 @@ def find_free_network_port() -> int: (1, 1, 1, False, "bf16_with_fp8_current_scaling_mixed"), (1, 1, 1, False, "bf16_with_fp8_delayed_scaling_mixed"), # XFAIL (1, 1, 1, False, "bf16_with_fp8_subchannel_scaling_mixed"), - (1, 1, 1, False, "nanov2_bf16_with_fp8_current_scaling_mixed"), (1, 1, 1, False, "bf16_with_nvfp4_mixed"), # XFAIL other than blackwell+ (1, 1, 1, False, "bf16_with_mxfp8_mixed"), # XFAIL other than blackwell+ (1, 1, 2, True, "bf16_mixed"), @@ -119,9 +72,9 @@ def test_stop_and_go( if "fp8" in precision_recipe and not is_fp8_supported(): pytest.skip("FP8 is not supported on this device") if "bf16_with_fp8_delayed_scaling_mixed" == precision_recipe and is_fp8_supported(): - pytest.xfail(reason="FP8 delayed scaling is not currently working with Evo2, use another FP8 recipe.") + pytest.skip(reason="FP8 delayed scaling is not currently working with Evo2, use another FP8 recipe.") if "bf16_with_fp8_subchannel_scaling_mixed" == precision_recipe and is_fp8_supported(): - pytest.xfail(reason="FP8 subchannel scaling is not currently working with Evo2 on some GPUs.") + pytest.skip(reason="FP8 subchannel scaling is not currently working with Evo2 on some GPUs.") run_dir = tmp_path / f"run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}" run_dir.mkdir(parents=True, exist_ok=True) master_port = find_free_network_port() @@ -240,169 +193,3 @@ def test_stop_and_go( assert first_loss_run2 - last_loss_run1 < 0.1, ( f"Run 2 first loss {first_loss_run2} is not better than run 1 last loss {last_loss_run1} by no worse than 0.1" ) - - -@pytest.mark.slow -def test_fine_tuning( - tmp_path: Path, - tp_size: int = 1, - cp_size: int = 1, - dp_size: int = 1, - dp_rank_check: bool = True, - precision_recipe: str = "bf16_mixed", - pp_size: int = 1, -): - """Test fine-tuning functionality, which should mirror stop/go but reset optimizer, data, and training state.""" - world_size = tp_size * pp_size * cp_size * dp_size - mbs = 32 - gbs = mbs * dp_size - num_gpus = torch.cuda.device_count() - if world_size > num_gpus: - pytest.skip(f"World size {world_size} is greater than the number of GPUs {num_gpus}") - if "nvfp4" in precision_recipe and not is_fp4_supported(): - pytest.skip("NVFP4 is not supported on this device") - if "mxfp8" in precision_recipe and not is_mxfp8_supported(): - pytest.skip("MXFP8 is not supported on this device") - if "fp8" in precision_recipe and not is_fp8_supported(): - pytest.skip("FP8 is not supported on this device") - if "bf16_with_fp8_delayed_scaling_mixed" == precision_recipe and is_fp8_supported(): - pytest.xfail(reason="FP8 delayed scaling is not currently working with Evo2, use another FP8 recipe.") - if "bf16_with_fp8_subchannel_scaling_mixed" == precision_recipe and is_fp8_supported(): - pytest.xfail(reason="FP8 subchannel scaling is not currently working with Evo2 on some GPUs.") - run_dir = tmp_path / f"run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}" - run_dir.mkdir(parents=True, exist_ok=True) - master_port = find_free_network_port() - dp_rank_check_str = "--debug-ddp-parity-freq 5" if dp_rank_check else "" - cmd1 = f"""torchrun --nproc-per-node {world_size} --no-python --master_port {master_port} \ - train_evo2 \ - --hf-tokenizer-model-path {DEFAULT_HF_TOKENIZER_MODEL_PATH} \ - --model-size striped_hyena_1b_nv_parallel --num-layers 4 --hybrid-override-pattern SDH* \ - --max-steps 5 --eval-interval 5 \ - --eval-iters 3 --mock-data --result-dir {run_dir} \ - --micro-batch-size {mbs} --global-batch-size {gbs} --seq-length 512 \ - --tensor-model-parallel {tp_size} \ - --pipeline-model-parallel {pp_size} \ - --context-parallel {cp_size} \ - --mixed-precision-recipe {precision_recipe} \ - --overlap-param-gather \ - --overlap-grad-reduce \ - {dp_rank_check_str} \ - --use-precision-aware-optimizer --dataset-seed 33 \ - --seed 41 --spike-no-more-embedding-init \ - --no-weight-decay-embeddings --cross-entropy-loss-fusion \ - --grad-reduce-in-fp32 \ - --decay-steps 1000 --warmup-steps 10 \ - --eod-pad-in-loss-mask \ - --log-interval 1 \ - """ - - # Split the command and run it - cmd_parts = shlex.split(cmd1) - env = copy.deepcopy(PRETEST_ENV) - env["NCCL_P2P_DISABLE"] = "1" - result = subprocess.run(cmd_parts, check=False, capture_output=True, text=True, cwd=run_dir, env=env) - - stdout = result.stdout - stderr = result.stderr - returncode = result.returncode - - # For debugging, print the output - print(f"Return code: {returncode}") - print(f"STDOUT:\n{stdout}") - print(f"STDERR:\n{stderr}") - - # Assert the command succeeded - assert returncode == 0, f"Command failed with return code {returncode}\nSTDERR:\n{stderr}" - result_dir = run_dir / "evo2" - ckpt_dir = result_dir / "checkpoints" - tb_log_dir = result_dir / "tb_logs" - assert ckpt_dir.exists() and ckpt_dir.is_dir(), "Checkpoints directory not found" - assert tb_log_dir.exists() and tb_log_dir.is_dir(), "TensorBoard logs directory not found" - iter_5_dir = ckpt_dir / "iter_0000005" - assert iter_5_dir.exists() and iter_5_dir.is_dir(), f"No iterations 5 checkpoint found in {ckpt_dir}" - assert len(list(ckpt_dir.glob("iter_*"))) == 1, f"Expected 1 iterations, found {list(ckpt_dir.glob('iter_*'))}" - # Load tensorboard logs to verify they were written correctly - - # Find the events file(s) in tb_log_dir - event_files = list(tb_log_dir.rglob("events.out.*")) - assert len(event_files) > 0, f"No tensorboard event files found in {tb_log_dir}" - - # Load events from the event files - event_acc = EventAccumulator(str(tb_log_dir)) - event_acc.Reload() - - # 1. collect the last loss, as well as the average of the last step validation losses, as well as the last step - # Note: EventAccumulator.Scalars returns a list of ScalarEvent(wall_time, step, value) - lm_loss_events = event_acc.Scalars("lm loss") - - assert len(lm_loss_events) > 0, "No 'lm loss' events found in run 1" - last_lm_loss_step = lm_loss_events[-1].step - - assert last_lm_loss_step == 5, f"Expected run 1 to end at step 5, but got {last_lm_loss_step}" - - # 2. run the above training command a second time, this time set max_steps to 10. Verify that the run resumes from the last step. - # Do this by moving the tb_logs to a different directory from the first part so the second run makes fresh logs. - tb_log_dir_run1 = result_dir / "tb_logs_run1" - if tb_log_dir.exists(): - shutil.move(str(tb_log_dir), str(tb_log_dir_run1)) - - # Modify the command to increase max steps to 10 - # We reuse the same result_dir so it picks up the checkpoint - ft_run_dir = ( - tmp_path / f"ft_run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}" - ) - ft_run_dir.mkdir(parents=True, exist_ok=True) - cmd2 = cmd1.rstrip().replace(f"--result-dir {run_dir}", f"--result-dir {ft_run_dir}") - cmd2 += f" --finetune-ckpt-dir {ckpt_dir} " - cmd_parts_2 = shlex.split(cmd2) - - print("Starting Run 2 (resuming to step 10)...") - result_2 = subprocess.run(cmd_parts_2, check=False, capture_output=True, text=True, cwd=run_dir, env=env) - - print(f"Run 2 Return code: {result_2.returncode}") - if result_2.returncode != 0: - print(f"Run 2 STDERR:\n{result_2.stderr}") - - assert result_2.returncode == 0, f"Run 2 failed with return code {result_2.returncode}" - - # 3. Load the new tb logs as before, and sanity check my recommendations as well as any others that make sense. - ft_result_dir = ft_run_dir / "evo2" - ft_tb_log_dir = ft_result_dir / "tb_logs" - assert ft_tb_log_dir.exists(), "TensorBoard logs directory not found after Run 2" - - event_acc_2 = EventAccumulator(str(ft_tb_log_dir)) - event_acc_2.Reload() - - lm_loss_events_2 = event_acc_2.Scalars("lm loss") - assert len(lm_loss_events_2) > 0, "No 'lm loss' events found in run 2" - - first_step_run2 = lm_loss_events_2[0].step - first_step_run1 = lm_loss_events[0].step - last_step_run2 = lm_loss_events_2[-1].step - - # Sanity checks: - # 1. Resumption: Should start after step 5 (e.g., step 6) - assert first_step_run2 == first_step_run1, ( - f"Run 2 FT steps should match run 1, but started at {first_step_run2} vs {first_step_run1}" - ) - - # 2. Completion: Should reach step 5 like run 1 - assert last_step_run2 == 5, f"Run 2 should reach step 5, but ended at {last_step_run2}" - - # 3. Loss Continuity check (basic): The first loss of run 2 should be reasonably close to the last loss of run 1, - # or at least not exploding, though optimization steps might cause fluctuations. - first_loss_run1 = lm_loss_events[0].value - first_loss_run2 = lm_loss_events_2[0].value - last_loss_run1 = lm_loss_events[-1].value - assert first_loss_run1 > last_loss_run1, ( - f"Run 1 first loss {first_loss_run1} is less than run 1 last loss {last_loss_run1}" - ) - assert first_loss_run2 < first_loss_run1, ( - f"Run 2 first loss {first_loss_run2} is greater than run 1 first loss {first_loss_run1}" - ) - assert abs(first_loss_run2 - first_loss_run1) > abs(last_loss_run1 - first_loss_run2), ( - f"Run 2 beginning {first_loss_run2} should be closer to end of run 1 {last_loss_run1} than beginning {first_loss_run1}." - ) - assert first_loss_run2 - last_loss_run1 < 0.1, ( - f"Run 2 first loss {first_loss_run2} is not better than run 1 last loss {last_loss_run1} by no worse than 0.1" - ) diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/utils.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/utils.py new file mode 100644 index 0000000000..8793071d05 --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/utils.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test utilities for evo2 tests.""" + +import socket + +import torch + + +def find_free_network_port(address: str = "localhost") -> int: + """Find a free port on localhost for distributed testing.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((address, 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def get_compute_capability() -> tuple[int, int]: + """Get the compute capability of the current device.""" + if not torch.cuda.is_available(): + return (0, 0) + # Returns a tuple, e.g., (9, 0) for H100 + return torch.cuda.get_device_capability() + + +def is_fp8_supported() -> bool: + """Check if FP8 is supported on the current device. + + FP8 is supported on Ada Lovelace (8.9) and Hopper (9.0+). + """ + cc = get_compute_capability() + return cc >= (8, 9) + + +def is_fp4_supported() -> bool: + """Check if FP4 is supported on the current device. + + Native support requires Blackwell (10.0+). + """ + cc = get_compute_capability() + return (10, 0) <= cc < (12, 0) + + +def is_mxfp8_supported() -> bool: + """Check if MXFP8 is supported on the current device. + + Native support requires Blackwell (10.0+). + """ + cc = get_compute_capability() + return (10, 0) <= cc < (12, 0) + + +def check_fp8_support(device_id: int = 0) -> tuple[bool, str, str]: + """Check if FP8 is supported on the current GPU. + + FP8 requires compute capability 8.9+ (Ada Lovelace/Hopper architecture or newer). + + Returns: + Tuple of (is_supported, compute_capability_string, device_info_message). + """ + if not torch.cuda.is_available(): + return False, "0.0", "CUDA not available" + device_props = torch.cuda.get_device_properties(device_id) + compute_capability = f"{device_props.major}.{device_props.minor}" + device_name = device_props.name + # FP8 is supported on compute capability 8.9+ (Ada Lovelace/Hopper architecture) + is_supported = (device_props.major > 8) or (device_props.major == 8 and device_props.minor >= 9) + return is_supported, compute_capability, f"Device: {device_name}, Compute Capability: {compute_capability}" + + +def is_a6000_gpu() -> bool: + """Check if any of the visible GPUs is an A6000.""" + for i in range(torch.cuda.device_count()): + device_name = torch.cuda.get_device_name(i) + if "A6000" in device_name: + return True + return False diff --git a/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_256/tokenizer.json b/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_256/tokenizer.json index 1a90f8c370..1dfd9641ca 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_256/tokenizer.json +++ b/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_256/tokenizer.json @@ -136,7 +136,9 @@ } } }, - "decoder": null, + "decoder": { + "type": "Fuse" + }, "model": { "type": "WordLevel", "vocab": { diff --git a/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_512/tokenizer.json b/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_512/tokenizer.json index 29d00fd74e..045b220cbc 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_512/tokenizer.json +++ b/bionemo-recipes/recipes/evo2_megatron/tokenizers/nucleotide_fast_tokenizer_512/tokenizer.json @@ -109,7 +109,9 @@ } } }, - "decoder": null, + "decoder": { + "type": "Fuse" + }, "model": { "type": "WordLevel", "vocab": {