[BUG] Evo2 infer_evo2 - RuntimeError: If key is supplied, seqlens_k must also be passed #1015

@not-a-feature

Description

BioNeMo Framework Version

v2.6.3

Bug Description

Hi there,

thanks for releasing v2.6.3.
Unfortunately, with this BioNeMo version the infer_evo2 command fails, so I can't generate new sequences.
The problem seems to be in FlashAttention's KV-cache path (flash_attn_with_kvcache):

File "/usr/local/lib/python3.12/dist-packages/flash_attn/flash_attn_interface.py", line 1589, in flash_attn_with_kvcache
[rank0]:     out, softmax_lse = flash_attn_gpu.fwd_kvcache(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: If key is supplied, seqlens_k must also be passed in

By contrast, predict_evo2 seems to work fine.
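For context: in flash-attn 2.x, `flash_attn_with_kvcache` writes the new `k`/`v` into `k_cache`/`v_cache` at the offsets given by `cache_seqlens` (the kernel's `seqlens_k`), so the kernel rejects a call that passes `k` without `cache_seqlens`. The error therefore suggests that Megatron's `flash_decode` reached the kernel without a cache length; this is an inference from the message, not a confirmed root cause. One way to check which arguments the flash-attn build in the container accepts is to print the function's signature. A minimal sketch; `show_signature` is a hypothetical helper, and it assumes `python3` is the interpreter that runs `infer_evo2`:

```shell
#!/usr/bin/env bash
# Sketch (hypothetical helper): print the call signature of MODULE.FUNCTION
# using Python's inspect module, e.g. to see which keyword arguments the
# installed flash-attn accepts.
show_signature() {
  python3 - "$1" "$2" <<'EOF'
import importlib
import inspect
import sys

# argv[1] is the module path, argv[2] the function name.
module = importlib.import_module(sys.argv[1])
func = getattr(module, sys.argv[2])
print(inspect.signature(func))
EOF
}
```

For example, after defining the function in a shell inside the container (e.g. `singularity shell --nv $EVO2_CONTAINER`), run `show_signature flash_attn.flash_attn_interface flash_attn_with_kvcache` and look for `cache_seqlens` in the output.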

Steps to Reproduce

EVO2_CONTAINER_VERSION=2.6.3
EVO2_CONTAINER="bionemo-framework_${EVO2_CONTAINER_VERSION}.sif"
MODEL_DIR="nemo2_evo2_7b_8k"

# Pull the new container
singularity pull docker://nvcr.io/nvidia/clara/bionemo-framework:"$EVO2_CONTAINER_VERSION"

# Download and convert the Evo2 model.
singularity exec --nv "$EVO2_CONTAINER" evo2_convert_to_nemo2 --model-path "hf://arcinstitute/savanna_evo2_7b_base" --model-size 7b --output-dir "$MODEL_DIR"

# Generate 10 new tokens from the prompt (this step fails).
singularity exec --nv "$EVO2_CONTAINER" infer_evo2 --prompt AGTAGTAGTATGATAGTAGT --ckpt-dir "$MODEL_DIR" --max-new-tokens 10
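To re-run the reproduction and check its outcome without reading the whole log, the captured output can be searched for the error string. A minimal sketch; `has_kvcache_error` is a hypothetical helper, not part of BioNeMo or Singularity:

```shell
#!/usr/bin/env bash
# Sketch (hypothetical helper): exit with status 0 if the given log file
# contains the flash-attn KV-cache error reported in this issue.
has_kvcache_error() {
  # -q: no output, only the exit status; -F: match the message literally.
  grep -qF "If key is supplied, seqlens_k must also be passed in" "$1"
}
```

For example, append `2>&1 | tee infer_evo2.log` to the `infer_evo2` command above, then run `has_kvcache_error infer_evo2.log && echo reproduced`. Using `-F` treats the message as a fixed string rather than a regular expression.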

Error Messages and Logs

singularity exec --nv --bind /home/pfeifer/ppu738/Evo2-human-dev:/home/pfeifer/ppu738/Evo2-human-dev --bind /home/pfeifer/ppu738/Evo2-human-dev/scratch_local:/scratch_local bionemo-framework_2.6.3.sif infer_evo2 --prompt AGTAGTAGTATGATAGTAGT --ckpt-dir results/evo2/checkpoints/epoch\=0-step\=0-consumed_samples\=0.0/ --max-new-tokens 10
15:4: not a valid test operator:
15:4: not a valid test operator: 12.9
21:4: not a valid test operator: (
21:4: not a valid test operator: 535.129.03
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Enabling Flash Decode for in-framework inference
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has data parallel group : [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has combined group of data parallel and context parallel : [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] All data parallel group ranks with context parallel combined: [[0]]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Ranks 0 has data parallel rank: 0
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has context parallel group: [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] All context parallel group ranks: [[0]]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Ranks 0 has context parallel rank: 0
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has model parallel group: [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] All model parallel group ranks: [[0]]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has tensor model parallel group: [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] All tensor model parallel group ranks: [[0]]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has tensor model parallel rank: 0
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has pipeline model parallel group: [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has embedding group: [0]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] All pipeline model parallel group ranks: [[0]]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has pipeline model parallel rank 0
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] All embedding group ranks: [[0]]
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Rank 0 has embedding rank: 0
INFO:pytorch_lightning.utilities.rank_zero:----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[NeMo I 2025-08-01 10:14:25 nemo_logging:393] Padded vocab_size: 512, original vocab_size: 512, dummy tokens: 0.
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
WARNING:megatron.core.tensor_parallel.random:CPU RNG state changed within GPU RNG context
[NeMo I 2025-08-01 10:14:26 nemo_logging:393]  > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 6481649408
[NeMo I 2025-08-01 10:14:26 nemo_logging:393] Doing selective restore from RestoreConfig(path='results/evo2/checkpoints/epoch=0-step=0-consumed_samples=0.0/', load_model_state=True, load_optim_state=False, load_artifacts=True)
[NeMo I 2025-08-01 10:14:26 nemo_logging:393] Using <megatron.core.dist_checkpointing.strategies.fully_parallel.FullyParallelLoadStrategyWrapper object at 0x155127b69eb0> dist-ckpt load strategy.
[NeMo I 2025-08-01 10:14:31 nemo_logging:393] Global Checkpoint Load : Rank : 0 : Start time : 1754036066.601s : Time spent in load_checkpoint: 4.890s
[NeMo I 2025-08-01 10:14:31 nemo_logging:393] Restoring model weights from RestoreConfig(path='results/evo2/checkpoints/epoch=0-step=0-consumed_samples=0.0/', load_model_state=True, load_optim_state=False, load_artifacts=True)
[NeMo I 2025-08-01 10:14:31 nemo_logging:393] Finished restoring from RestoreConfig(path='results/evo2/checkpoints/epoch=0-step=0-consumed_samples=0.0/', load_model_state=True, load_optim_state=False, load_artifacts=True), cleaning up.
static requests:   0%|          | 0/1 [00:00<?, ?it/s]
WARNING:DotProductAttention:flash-attn v3 may provide important feature support or performance improvement. Please install flash-attn v3 by
(1) git clone https://github.com/Dao-AILab/flash-attention.git
(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/local/bin/infer_evo2", line 10, in <module>
[rank0]:     sys.exit(main())
[rank0]:              ^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/infer.py", line 198, in main
[rank0]:     infer(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/bionemo/evo2/run/infer.py", line 167, in infer
[rank0]:     results = generate(
[rank0]:               ^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/api.py", line 1162, in generate
[rank0]:     results_on_this_dp_rank = inference.generate(
[rank0]:                               ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/inference/base.py", line 296, in generate
[rank0]:     results = mcore_engine.generate(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/inference/engines/static_engine.py", line 191, in generate
[rank0]:     self.run_engine()
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/inference/engines/static_engine.py", line 225, in run_engine
[rank0]:     self.text_generation_controller.generate_all_output_tokens_static_batch(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/inference/text_generation_controllers/text_generation_controller.py", line 765, in generate_all_output_tokens_static_batch
[rank0]:     logits = self.inference_wrapped_model.run_one_forward_step(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 375, in run_one_forward_step
[rank0]:     return self.forward_pass_without_pipeline_parallel(inference_input)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 202, in forward_pass_without_pipeline_parallel
[rank0]:     logits = self._forward(inference_input)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 155, in _forward
[rank0]:     return self.model(
[rank0]:            ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/gpt/model/megatron/hyena/hyena_model.py", line 265, in forward
[rank0]:     hidden_states = self.decoder(
[rank0]:                     ^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py", line 302, in forward
[rank0]:     hidden_states = layer(
[rank0]:                     ^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 875, in __call__
[rank0]:     return super(MegatronModule, self).__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 441, in forward
[rank0]:     hidden_states, context = self._forward_attention(*args, **kwargs)
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 501, in _forward_attention
[rank0]:     attention_output_with_bias = self.self_attention(
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/attention.py", line 629, in forward
[rank0]:     output = self.flash_decode(
[rank0]:              ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/attention.py", line 424, in flash_decode
[rank0]:     out = flash_attn_with_kvcache(
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/flash_attn/flash_attn_interface.py", line 1589, in flash_attn_with_kvcache
[rank0]:     out, softmax_lse = flash_attn_gpu.fwd_kvcache(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: If key is supplied, seqlens_k must also be passed in
static requests:   0%|                                                                                                                | 0/1 [00:06<?, ?it/s]

Docker Image

nvcr.io/nvidia/clara/bionemo-framework:2.6.3-arm

System Information

GPU Details:

  • GPU Model: 8x NVIDIA H100 80GB HBM3
  • GPU Memory: 8x 80GB
  • CUDA Version: 12.2
  • CUDA Driver: 535.129.03
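The GPU details above could be complemented by the versions of the Python packages that appear in the traceback. A minimal sketch, meant to be run inside the container, assuming `python3` is the interpreter that runs `infer_evo2` and that the distribution names below (`torch`, `flash-attn`, `megatron-core`, `nemo-toolkit`) match what is installed; `pkg_version` is a hypothetical helper:

```shell
#!/usr/bin/env bash
# Sketch (hypothetical helper): print the installed version of a Python
# distribution, or "not installed" if importlib.metadata cannot find it.
pkg_version() {
  python3 - "$1" <<'EOF'
import sys
from importlib import metadata

try:
    print(metadata.version(sys.argv[1]))
except metadata.PackageNotFoundError:
    print("not installed")
EOF
}

# Distribution names assumed to correspond to the packages in the traceback.
for pkg in torch flash-attn megatron-core nemo-toolkit; do
  printf '%s: %s\n' "$pkg" "$(pkg_version "$pkg")"
done
```

GPU model, memory and driver version can be collected with `nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv`.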

Additional Context

No response

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
