Skip to content

Commit 962d141

Browse files
authored
Hyena Inference Updates to support Flash Decode (#1000)
### Description - Support for Flash Decode - Do not yet support cudagraph. Depends on NeMo PR: NVIDIA-NeMo/NeMo#14315 --------- Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent dcdeaa7 commit 962d141

File tree

7 files changed

+343
-56
lines changed

7 files changed

+343
-56
lines changed

3rdparty/NeMo

Submodule NeMo updated from 164d12b to b97e42b

sub-packages/bionemo-core/src/bionemo/core/utils/subprocess_utils.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,86 @@ def run_subprocess_safely(command: str, timeout: int = 2000) -> Dict[str, Any]:
3434
The result of the subprocess.
3535
"""
3636
try:
37-
result = subprocess.run(shlex.split(command), capture_output=True, timeout=timeout, check=True, text=True)
37+
# Use Popen to enable real-time output while still capturing it
38+
process = subprocess.Popen(
39+
shlex.split(command),
40+
stdout=subprocess.PIPE,
41+
stderr=subprocess.PIPE,
42+
text=True,
43+
bufsize=1,
44+
universal_newlines=True,
45+
)
46+
47+
stdout_lines = []
48+
stderr_lines = []
49+
50+
# Read output in real-time
51+
import select
52+
import sys
53+
54+
while True:
55+
# Use select to check for available output (Unix/Linux/Mac only)
56+
if hasattr(select, "select"):
57+
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
58+
59+
if process.stdout in ready:
60+
line = process.stdout.readline()
61+
if line:
62+
stdout_lines.append(line)
63+
print(line.rstrip(), file=sys.stdout, flush=True)
64+
65+
if process.stderr in ready:
66+
line = process.stderr.readline()
67+
if line:
68+
stderr_lines.append(line)
69+
print(line.rstrip(), file=sys.stderr, flush=True)
70+
else:
71+
# Fallback for Windows - read with timeout
72+
try:
73+
stdout_data, stderr_data = process.communicate(timeout=0.1)
74+
if stdout_data:
75+
stdout_lines.extend(stdout_data.splitlines(keepends=True))
76+
print(stdout_data.rstrip(), file=sys.stdout, flush=True)
77+
if stderr_data:
78+
stderr_lines.extend(stderr_data.splitlines(keepends=True))
79+
print(stderr_data.rstrip(), file=sys.stderr, flush=True)
80+
break
81+
except subprocess.TimeoutExpired:
82+
pass
83+
84+
# Check if process has finished
85+
if process.poll() is not None:
86+
# Read any remaining output
87+
remaining_stdout, remaining_stderr = process.communicate()
88+
if remaining_stdout:
89+
stdout_lines.extend(remaining_stdout.splitlines(keepends=True))
90+
print(remaining_stdout.rstrip(), file=sys.stdout, flush=True)
91+
if remaining_stderr:
92+
stderr_lines.extend(remaining_stderr.splitlines(keepends=True))
93+
print(remaining_stderr.rstrip(), file=sys.stderr, flush=True)
94+
break
95+
96+
# Check for timeout
97+
try:
98+
process.wait(timeout=timeout)
99+
except subprocess.TimeoutExpired:
100+
process.kill()
101+
raise
102+
103+
# Check return code
104+
if process.returncode != 0:
105+
raise subprocess.CalledProcessError(
106+
process.returncode, command, output="".join(stdout_lines), stderr="".join(stderr_lines)
107+
)
108+
109+
# Create result object similar to subprocess.run
110+
class Result:
111+
def __init__(self, stdout, stderr, returncode):
112+
self.stdout = stdout
113+
self.stderr = stderr
114+
self.returncode = returncode
115+
116+
result = Result("".join(stdout_lines), "".join(stderr_lines), process.returncode)
38117
return {"stdout": result.stdout, "stderr": result.stderr, "returncode": result.returncode}
39118
except subprocess.TimeoutExpired as e:
40119
logger.error(f"Command timed out. Command: {command}\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}")

sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818

1919

2020
import argparse
21+
import sys
22+
import time
2123
from typing import Literal, Optional
2224

2325
import nemo.lightning as nl
2426
import torch
2527
from megatron.core.inference.common_inference_params import CommonInferenceParams
26-
from nemo.collections.llm import generate
28+
from megatron.core.inference.inference_request import InferenceRequest
29+
from nemo.collections.llm import inference
2730
from nemo.utils import logging
2831

2932

@@ -82,15 +85,17 @@ def parse_args():
8285
help="Specify checkpoint format to use. Defaults to 'torch_dist', as 'zarr' is deprecated.",
8386
)
8487
ap.add_argument(
85-
"--vortex-style-fp8",
88+
"--fp8",
8689
type=bool,
90+
action="store_true",
8791
default=False,
8892
help="Whether to use vortex style FP8. Defaults to False.",
8993
)
9094
ap.add_argument(
9195
"--flash-decode",
9296
type=bool,
93-
default=True,
97+
action="store_true",
98+
default=False,
9499
help="Whether to use flash decode. Defaults to True.",
95100
)
96101
return ap.parse_args()
@@ -110,8 +115,9 @@ def infer(
110115
ckpt_format: CheckpointFormats = "torch_dist",
111116
seed: Optional[int] = None,
112117
vortex_style_fp8: bool = False,
113-
flash_decode: bool = True,
114-
):
118+
flash_decode: bool = False,
119+
return_log_probs: bool = False,
120+
) -> list[InferenceRequest]:
115121
"""Inference workflow for Evo2.
116122
117123
Args:
@@ -129,6 +135,7 @@ def infer(
129135
seed (int): Random seed for generation.
130136
vortex_style_fp8 (bool): Whether to use vortex style FP8.
131137
flash_decode (bool): Whether to use flash decode.
138+
return_log_probs (bool): Whether to return log probabilities.
132139
133140
Returns:
134141
None
@@ -162,31 +169,45 @@ def infer(
162169
params_dtype=torch.bfloat16,
163170
),
164171
)
165-
166-
# transformers generate method has more options than NeMo/Megatron.
167-
results = generate(
172+
inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer(
168173
path=ckpt_dir,
169-
prompts=[prompt],
170174
trainer=trainer,
175+
params_dtype=torch.bfloat16,
176+
inference_batch_times_seqlen_threshold=8192, # TODO
177+
inference_max_seq_length=8192, # TODO
178+
recompute_granularity=None,
179+
recompute_num_layers=None,
180+
recompute_method=None,
181+
vortex_style_fp8=vortex_style_fp8,
182+
flash_decode=flash_decode,
183+
enable_flash_decode=flash_decode,
184+
)
185+
t0 = time.perf_counter_ns()
186+
# TODO: fix return type in NeMo inference.generate (it is a list[InferenceRequest] not a dict)
187+
results: list[InferenceRequest] = inference.generate(
188+
model=inference_wrapped_model,
189+
max_batch_size=1, # vortex only supports batch size 1
190+
tokenizer=mcore_tokenizer,
191+
prompts=[prompt],
192+
random_seed=seed,
171193
inference_params=CommonInferenceParams(
172-
temperature,
173-
top_k,
174-
top_p,
175-
return_log_probs=False,
194+
temperature=temperature,
195+
top_k=top_k,
196+
top_p=top_p,
197+
return_log_probs=return_log_probs,
176198
num_tokens_to_generate=max_new_tokens,
177199
),
178-
text_only=True,
179-
random_seed=seed if seed is not None else None,
180-
vortex_style_fp8=vortex_style_fp8,
181-
flash_decode=flash_decode,
182200
)
201+
dt = (time.perf_counter_ns() - t0) / 1e9 # seconds
202+
tokens_per_sec = (len(results[0].generated_text) + 1) / dt # +1 for the prompt
183203

204+
print(f"Inference time: {dt} seconds, {tokens_per_sec} tokens/sec", file=sys.stderr)
184205
if torch.distributed.get_rank() == 0:
185206
if output_file is None:
186207
logging.info(results)
187208
else:
188209
with open(output_file, "w") as f:
189-
f.write(f"{results}\n")
210+
f.write(f"{results[0]}\n")
190211

191212
return results
192213

@@ -208,7 +229,7 @@ def main():
208229
output_file=args.output_file,
209230
ckpt_format=args.ckpt_format,
210231
seed=args.seed,
211-
vortex_style_fp8=args.vortex_style_fp8,
232+
vortex_style_fp8=args.fp8, # Vortex only applied FP8 to some layers.
212233
flash_decode=args.flash_decode,
213234
)
214235

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
# conftest.py
18+
import gc
19+
20+
import pytest
21+
import torch
22+
23+
24+
def pytest_sessionstart(session):
25+
"""Called at the start of the test session."""
26+
if torch.cuda.is_available():
27+
torch.cuda.reset_peak_memory_stats()
28+
print(f"Starting test session. Initial GPU memory: {torch.cuda.memory_allocated() / 1024**3:.3f} GB")
29+
30+
31+
def pytest_sessionfinish(session, exitstatus):
32+
"""Called at the end of the test session."""
33+
if torch.cuda.is_available():
34+
peak_memory = torch.cuda.max_memory_allocated()
35+
final_memory = torch.cuda.memory_allocated()
36+
print("\nTest session complete:")
37+
print(f" Peak GPU memory: {peak_memory / 1024**3:.3f} GB")
38+
print(f" Final GPU memory: {final_memory / 1024**3:.3f} GB")
39+
40+
41+
@pytest.fixture(autouse=True)
42+
def cleanup_after_test():
43+
"""Clean up GPU memory after each test."""
44+
yield
45+
if torch.cuda.is_available():
46+
torch.cuda.empty_cache()
47+
gc.collect()

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_infer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,20 @@
1717
# limitations under the License.
1818

1919

20+
import pytest
21+
import torch
22+
2023
from bionemo.core.data.load import load
2124
from bionemo.evo2.run.infer import infer
2225
from bionemo.testing.megatron_parallel_state_utils import clean_parallel_state_context
26+
from bionemo.testing.torch import check_fp8_support
2327

2428

2529
RANDOM_SEED = 42
2630

2731

28-
def test_run_infer():
32+
@pytest.mark.parametrize("fast", [True, False])
33+
def test_run_infer(fast: bool):
2934
# Create PTL trainer.
3035
tensor_parallel_size = 1
3136
pipeline_model_parallel_size = 1
@@ -56,6 +61,8 @@ def test_run_infer():
5661
else:
5762
raise e
5863

64+
is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device())
65+
5966
with clean_parallel_state_context():
6067
infer(
6168
prompt=default_prompt,
@@ -67,4 +74,6 @@ def test_run_infer():
6774
tensor_parallel_size=tensor_parallel_size,
6875
pipeline_model_parallel_size=pipeline_model_parallel_size,
6976
context_parallel_size=context_parallel_size,
77+
vortex_style_fp8=is_fp8_supported,
78+
flash_decode=fast,
7079
)

0 commit comments

Comments (0)