Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated from 164d12 to b97e42
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,86 @@ def run_subprocess_safely(command: str, timeout: int = 2000) -> Dict[str, Any]:
The result of the subprocess.
"""
try:
result = subprocess.run(shlex.split(command), capture_output=True, timeout=timeout, check=True, text=True)
# Use Popen to enable real-time output while still capturing it
process = subprocess.Popen(
shlex.split(command),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
universal_newlines=True,
)

stdout_lines = []
stderr_lines = []

# Read output in real-time
import select
import sys

while True:
# Use select to check for available output (Unix/Linux/Mac only)
if hasattr(select, "select"):
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)

if process.stdout in ready:
line = process.stdout.readline()
if line:
stdout_lines.append(line)
print(line.rstrip(), file=sys.stdout, flush=True)

if process.stderr in ready:
line = process.stderr.readline()
if line:
stderr_lines.append(line)
print(line.rstrip(), file=sys.stderr, flush=True)
else:
# Fallback for Windows - read with timeout
try:
stdout_data, stderr_data = process.communicate(timeout=0.1)
if stdout_data:
stdout_lines.extend(stdout_data.splitlines(keepends=True))
print(stdout_data.rstrip(), file=sys.stdout, flush=True)
if stderr_data:
stderr_lines.extend(stderr_data.splitlines(keepends=True))
print(stderr_data.rstrip(), file=sys.stderr, flush=True)
break
except subprocess.TimeoutExpired:
pass

# Check if process has finished
if process.poll() is not None:
# Read any remaining output
remaining_stdout, remaining_stderr = process.communicate()
if remaining_stdout:
stdout_lines.extend(remaining_stdout.splitlines(keepends=True))
print(remaining_stdout.rstrip(), file=sys.stdout, flush=True)
if remaining_stderr:
stderr_lines.extend(remaining_stderr.splitlines(keepends=True))
print(remaining_stderr.rstrip(), file=sys.stderr, flush=True)
break

# Check for timeout
try:
process.wait(timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
raise

# Check return code
if process.returncode != 0:
raise subprocess.CalledProcessError(
process.returncode, command, output="".join(stdout_lines), stderr="".join(stderr_lines)
)

# Create result object similar to subprocess.run
class Result:
def __init__(self, stdout, stderr, returncode):
self.stdout = stdout
self.stderr = stderr
self.returncode = returncode

result = Result("".join(stdout_lines), "".join(stderr_lines), process.returncode)
return {"stdout": result.stdout, "stderr": result.stderr, "returncode": result.returncode}
except subprocess.TimeoutExpired as e:
logger.error(f"Command timed out. Command: {command}\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}")
Expand Down
59 changes: 40 additions & 19 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@


import argparse
import sys
import time
from typing import Literal, Optional

import nemo.lightning as nl
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from nemo.collections.llm import generate
from megatron.core.inference.inference_request import InferenceRequest
from nemo.collections.llm import inference
from nemo.utils import logging


Expand Down Expand Up @@ -82,15 +85,17 @@ def parse_args():
help="Specify checkpoint format to use. Defaults to 'torch_dist', as 'zarr' is deprecated.",
)
ap.add_argument(
"--vortex-style-fp8",
"--fp8",
type=bool,
action="store_true",
default=False,
help="Whether to use vortex style FP8. Defaults to False.",
)
ap.add_argument(
"--flash-decode",
type=bool,
default=True,
action="store_true",
default=False,
help="Whether to use flash decode. Defaults to False.",
)
return ap.parse_args()
Expand All @@ -110,8 +115,9 @@ def infer(
ckpt_format: CheckpointFormats = "torch_dist",
seed: Optional[int] = None,
vortex_style_fp8: bool = False,
flash_decode: bool = True,
):
flash_decode: bool = False,
return_log_probs: bool = False,
) -> list[InferenceRequest]:
"""Inference workflow for Evo2.

Args:
Expand All @@ -129,6 +135,7 @@ def infer(
seed (int): Random seed for generation.
vortex_style_fp8 (bool): Whether to use vortex style FP8.
flash_decode (bool): Whether to use flash decode.
return_log_probs (bool): Whether to return log probabilities.

Returns:
list[InferenceRequest]: The completed inference requests, including generated text.
Expand Down Expand Up @@ -162,31 +169,45 @@ def infer(
params_dtype=torch.bfloat16,
),
)

# transformers generate method has more options than NeMo/Megatron.
results = generate(
inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
path=ckpt_dir,
prompts=[prompt],
trainer=trainer,
params_dtype=torch.bfloat16,
inference_batch_times_seqlen_threshold=8192, # TODO
inference_max_seq_length=8192, # TODO
recompute_granularity=None,
recompute_num_layers=None,
recompute_method=None,
vortex_style_fp8=vortex_style_fp8,
flash_decode=flash_decode,
enable_flash_decode=flash_decode,
)
t0 = time.perf_counter_ns()
# TODO: fix return type in NeMo inference.generate (it is a list[InferenceRequest] not a dict)
results: list[InferenceRequest] = inference.generate(
model=inference_wrapped_model,
max_batch_size=1, # vortex only supports batch size 1
tokenizer=mcore_tokenizer,
prompts=[prompt],
random_seed=seed,
inference_params=CommonInferenceParams(
temperature,
top_k,
top_p,
return_log_probs=False,
temperature=temperature,
top_k=top_k,
top_p=top_p,
return_log_probs=return_log_probs,
num_tokens_to_generate=max_new_tokens,
),
text_only=True,
random_seed=seed if seed is not None else None,
vortex_style_fp8=vortex_style_fp8,
flash_decode=flash_decode,
)
dt = (time.perf_counter_ns() - t0) / 1e9 # seconds
tokens_per_sec = (len(results[0].generated_text) + 1) / dt # +1 for the prompt

print(f"Inference time: {dt} seconds, {tokens_per_sec} tokens/sec", file=sys.stderr)
if torch.distributed.get_rank() == 0:
if output_file is None:
logging.info(results)
else:
with open(output_file, "w") as f:
f.write(f"{results}\n")
f.write(f"{results[0]}\n")

return results

Expand All @@ -208,7 +229,7 @@ def main():
output_file=args.output_file,
ckpt_format=args.ckpt_format,
seed=args.seed,
vortex_style_fp8=args.vortex_style_fp8,
vortex_style_fp8=args.fp8, # Vortex only applied FP8 to some layers.
flash_decode=args.flash_decode,
)

Expand Down
47 changes: 47 additions & 0 deletions sub-packages/bionemo-evo2/tests/bionemo/evo2/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# conftest.py
import gc

import pytest
import torch


def pytest_sessionstart(session):
    """Reset CUDA peak-memory counters and report baseline GPU usage before any test runs."""
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
    initial_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"Starting test session. Initial GPU memory: {initial_gb:.3f} GB")


def pytest_sessionfinish(session, exitstatus):
    """Report peak and final GPU memory usage once the test session completes."""
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    peak_memory = torch.cuda.max_memory_allocated()
    final_memory = torch.cuda.memory_allocated()
    print("\nTest session complete:")
    print(f"  Peak GPU memory: {peak_memory / gib:.3f} GB")
    print(f"  Final GPU memory: {final_memory / gib:.3f} GB")


@pytest.fixture(autouse=True)
def cleanup_after_test():
    """Clean up GPU memory after each test (autouse: applies to every test in scope).

    Yields:
        None: control passes to the test body; cleanup runs on teardown.

    Teardown collects Python garbage *before* emptying the CUDA cache:
    ``torch.cuda.empty_cache()`` can only release cached blocks that are no
    longer referenced, so dropping dead tensor references first maximizes the
    memory actually returned to the allocator.
    """
    yield
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
11 changes: 10 additions & 1 deletion sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@
# limitations under the License.


import pytest
import torch

from bionemo.core.data.load import load
from bionemo.evo2.run.infer import infer
from bionemo.testing.megatron_parallel_state_utils import clean_parallel_state_context
from bionemo.testing.torch import check_fp8_support


RANDOM_SEED = 42


def test_run_infer():
@pytest.mark.parametrize("fast", [True, False])
def test_run_infer(fast: bool):
# Create PTL trainer.
tensor_parallel_size = 1
pipeline_model_parallel_size = 1
Expand Down Expand Up @@ -56,6 +61,8 @@ def test_run_infer():
else:
raise e

is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device())

with clean_parallel_state_context():
infer(
prompt=default_prompt,
Expand All @@ -67,4 +74,6 @@ def test_run_infer():
tensor_parallel_size=tensor_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
vortex_style_fp8=is_fp8_supported,
flash_decode=fast,
)
Loading