Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
85efc1f
WIP engine and context cleanup
santhnm2 Jan 14, 2026
ca8b278
Tests pass
santhnm2 Jan 15, 2026
fafb2d2
Merge with main
santhnm2 Jan 15, 2026
01a8237
Undo attention.py change
santhnm2 Jan 15, 2026
faa935b
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 15, 2026
6da70fb
Fix unit tests
santhnm2 Jan 15, 2026
2f67856
Fix typo and clean up run_inference_performance_test.py
santhnm2 Jan 20, 2026
c8d3934
Merge branch 'main' into engine_context_cleanup
santhnm2 Jan 21, 2026
b801c9c
Address reviewer feedback
santhnm2 Jan 21, 2026
cff3c1e
Fix copyright
santhnm2 Jan 21, 2026
b6d9c93
Add back deprecated args with exception
santhnm2 Jan 22, 2026
ed6f96a
Merge with main
santhnm2 Jan 22, 2026
98d5acc
Add deprecation exception for GPTInferenceWrapper
santhnm2 Jan 22, 2026
ec2fad2
Fix wandb test
santhnm2 Jan 22, 2026
2dea9b3
Address reviewer comments
santhnm2 Jan 22, 2026
e5edb9a
Bug fixes
santhnm2 Jan 22, 2026
4b6cb9c
Major refactor - introduce DynamicInferenceConfig
santhnm2 Jan 23, 2026
8e8225a
Merge with main
santhnm2 Jan 23, 2026
bf66e63
Add inference config unit test
santhnm2 Jan 23, 2026
d14f476
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 23, 2026
f1af84f
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 23, 2026
188dcf2
Address reviewer feedback
santhnm2 Jan 23, 2026
92b1a8d
Fix copyright
santhnm2 Jan 23, 2026
c76d953
Fix import
santhnm2 Jan 23, 2026
71e790c
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 23, 2026
952dd38
Fix mamba test
santhnm2 Jan 23, 2026
84c6123
Fix unit tests
santhnm2 Jan 23, 2026
c50d844
Merge branch 'main' into engine_context_cleanup
santhnm2 Jan 23, 2026
ef50892
Fix example scripts
santhnm2 Jan 23, 2026
39c92c6
Update tools/run_inference_performance_test.py
santhnm2 Jan 24, 2026
9665804
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 24, 2026
fbe9db6
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 26, 2026
bdde060
Add explicit deprecation error
santhnm2 Jan 26, 2026
9bf08a8
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 26, 2026
964302c
Remove --inference-max-batch-size
santhnm2 Jan 27, 2026
ad80dc7
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 27, 2026
97cb2d6
RL fixes
santhnm2 Jan 27, 2026
58f5be5
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 27, 2026
70e4e3a
Fix perf test
santhnm2 Jan 27, 2026
c40e18e
Merge with main
santhnm2 Jan 29, 2026
0e4e286
Merge branch 'main' into engine_context_cleanup
santhnm2 Jan 29, 2026
6d85bf2
Merge with main
santhnm2 Jan 30, 2026
7080773
Merge with main
santhnm2 Jan 30, 2026
c3e10a2
Fix static functional test
santhnm2 Jan 30, 2026
b9fd22a
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 30, 2026
a57af29
Try removing arguments that are now in TransformerConfig
santhnm2 Jan 30, 2026
21c6d2d
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 30, 2026
dbb8614
Bug fixes
santhnm2 Jan 30, 2026
5474dea
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Jan 30, 2026
ca27708
Fix formatting
santhnm2 Jan 30, 2026
80ddc9a
More bug fixes
santhnm2 Jan 30, 2026
4df1ee9
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Feb 3, 2026
19c79a2
Fix static functional tests
santhnm2 Feb 3, 2026
3ab2a1a
Fix CI
ko3n1g Feb 3, 2026
019b1f4
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Feb 3, 2026
2b0873e
Bug fixes
santhnm2 Feb 3, 2026
4c03a04
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Feb 3, 2026
5953507
Fix linting
santhnm2 Feb 4, 2026
fe9e95f
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Feb 4, 2026
60c03cf
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Feb 4, 2026
aff0fa2
Mark EP test as broken
santhnm2 Feb 4, 2026
3d00cc5
Merge remote-tracking branch 'upstream/main' into engine_context_cleanup
santhnm2 Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 77 additions & 262 deletions examples/inference/gpt/gpt_dynamic_inference.py

Large diffs are not rendered by default.

117 changes: 42 additions & 75 deletions examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,33 @@

import asyncio
import json
import logging
import os
import time
import torch
import torch.distributed as dist
import warnings
from collections import defaultdict
from tqdm import tqdm
from typing import List
import warnings
import logging

from examples.inference.gpt.gpt_dynamic_inference import (
add_dynamic_inference_args,
get_inference_context,
get_inference_controller,
get_model,
)
from examples.inference.gpt.utils import (
Request,
build_dynamic_engine_setup_prefix,
build_requests,
add_common_inference_args
)
import torch
import torch.distributed as dist

from megatron.core import parallel_state
from examples.inference.gpt.utils import Request, build_dynamic_engine_setup_prefix, build_requests
from megatron.core.inference.engines import DynamicInferenceEngine
from megatron.core.inference.inference_client import InferenceClient
from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.utils import get_mamba_inference_state_config_from_model

from megatron.inference.utils import (
add_inference_args,
get_dynamic_inference_engine,
get_model_for_inference,
)
from megatron.training import get_args, get_tokenizer, initialize_megatron
from megatron.training.arguments import parse_args

# pylint: disable=line-too-long

logging.basicConfig(level=logging.INFO, force=True)


async def main(
engine: DynamicInferenceEngine,
requests: List[Request],
Expand All @@ -51,12 +41,11 @@ async def main(
"Sampling parameters are specified per request.",
DeprecationWarning,
)

# Once you call engine.start_listening_to_data_parallel_coordinator,
# the engine will start accepting requests from the data parallel coordinator
# and processing them in an asyncio coroutine.
# Leaving inference_coordinator_port as None will find a free port automatically.

dp_addr = await engine.start_listening_to_data_parallel_coordinator(
inference_coordinator_port=port,
launch_inference_coordinator=True,
Expand All @@ -69,14 +58,11 @@ async def main(
# Since the client doesn't directly call engine.async_step here, we test
# the suspend-resume system ~4 times.
suspend_resume_interval = max(1, len(requests) // 4)
suspend_idxs = set(range(
suspend_resume_interval,
len(requests) + 1,
suspend_resume_interval,
))
suspend_idxs = set(
range(suspend_resume_interval, len(requests) + 1, suspend_resume_interval)
)
resume_idxs = set(
min(len(requests), i + suspend_resume_interval // 2)
for i in suspend_idxs
min(len(requests), i + suspend_resume_interval // 2) for i in suspend_idxs
)
else:
suspend_idxs = set()
Expand All @@ -98,7 +84,10 @@ async def main(
current_time = time.time_ns() / 10**9
if args.incoming_requests_per_step is None:
# Only add requests that have arrived at the current time.
while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
while (
num_requests_added < num_requests_total
and requests[num_requests_added].time_arrival <= current_time
):
request = requests[num_requests_added]
# These add-request calls will queue up the request on a zmq socket and return
# instantaneously. They will return an asyncio future which can be awaited for
Expand All @@ -114,10 +103,9 @@ async def main(

else:
# Add deterministic number of requests (generally used for debugging).
for i in range(min(
args.incoming_requests_per_step,
num_requests_total - num_requests_added
)):
for i in range(
min(args.incoming_requests_per_step, num_requests_total - num_requests_added)
):
# Change sampling parameters to force different generation lengths.
request = requests[num_requests_added]
n = request.sampling_params.num_tokens_to_generate
Expand All @@ -135,7 +123,7 @@ async def main(
break
# Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
await asyncio.sleep(0)

# While we wait for the requests to complete, the engine runs in the background.
results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)

Expand Down Expand Up @@ -170,16 +158,19 @@ async def main(
req = record.merge()
unique_prompt_map[req.prompt].append(req)
for idx, (prompt_text, reqs) in enumerate(unique_prompt_map.items()):
print(f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
idx,
len(unique_prompt_map),
prompt_text.replace("\n", "\\n"),
len(reqs),
reqs[0].generated_text.replace("\n", "\\n"),
))
print(
f"%d/%d. prompt '%s' ... [%d] output '%s'."
% (
idx,
len(unique_prompt_map),
prompt_text.replace("\n", "\\n"),
len(reqs),
reqs[0].generated_text.replace("\n", "\\n"),
)
)

# kill the engines and suspend the client
# Right now, we can only call stop when all requests are done.
# Right now, we can only call stop when all requests are done.
# TODO: Make this explicit in the Client class.
await client.stop_engines()
client.stop()
Expand All @@ -190,11 +181,11 @@ async def main(


if __name__ == "__main__":
# enable inference mode in the very beginning as some fp-8 optimizations
# enable inference mode in the very beginning as some fp8 optimizations
# check for it.
with torch.inference_mode():
initialize_megatron(
extra_args_provider=add_dynamic_inference_args,
extra_args_provider=add_inference_args,
args_defaults={'no_load_rng': True, 'no_load_optim': True},
)

Expand All @@ -213,34 +204,16 @@ async def main(
),
)

# Requests, context, controller.
model = get_model()
mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
model = get_model_for_inference()

requests = (
build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
)

context = get_inference_context(
None,
None,
calculate_max_sequence_length_from_requests=False,
mamba_inference_state_config=mamba_inference_state_config,
)

controller = get_inference_controller(model, context)

# Inference engine.
engine = DynamicInferenceEngine(
controller,
context,
enable_cuda_graph=args.cuda_graph_impl == "local",
random_seed=args.seed,
enable_chunked_prefill=not args.disable_chunked_prefill,
inference_logging_step_interval=args.inference_logging_step_interval,
)
engine = get_dynamic_inference_engine(model=model)

if dist.get_rank() == 0:
setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests)
setup_prefix = build_dynamic_engine_setup_prefix(args, model, engine.context, requests)
print("~~~")
print(setup_prefix)
print("~~~")
Expand All @@ -249,13 +222,7 @@ async def main(
if os.environ.get("NSIGHT_PREFIX"):
torch.cuda.cudart().cudaProfilerStart()

asyncio.run(
main(
engine,
requests,
args.inference_coordinator_port,
)
)
asyncio.run(main(engine, requests, args.inference_coordinator_port))

# Stop Nsight profiler.
if os.environ.get("NSIGHT_PREFIX"):
Expand Down
71 changes: 13 additions & 58 deletions examples/inference/gpt/gpt_static_inference.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,43 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import os
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from model_provider import model_provider
from gpt_builders import gpt_builder
from mamba_builders import mamba_builder
import torch
import sys
import time
import warnings
from functools import partial
from argparse import Namespace

import torch
import tqdm

from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.engines import StaticInferenceEngine
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
GPTInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
InferenceWrapperConfig,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
TextGenerationController,
)
from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer
from megatron.core.transformer.module import MegatronModule
from pretrain_gpt import model_provider as gpt_model_provider
from pretrain_mamba import model_provider as mamba_model_provider

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)

import asyncio
import json
from typing import Any, AsyncIterator, List
from typing import List

from examples.inference.gpt.utils import add_common_inference_args, build_requests
from megatron.core import mpu
from megatron.training import get_args, get_model, get_tokenizer, print_rank_0
from megatron.training.checkpointing import load_checkpoint
from examples.inference.gpt.utils import build_requests
from megatron.inference.utils import add_inference_args, get_model_for_inference
from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron.training.initialize import initialize_megatron


def add_static_inference_args(parser):
"""Static inference arguments."""

add_common_inference_args(parser)
add_inference_args(parser)

group = parser.add_argument_group(title='Static inference')
group.add_argument(
Expand Down Expand Up @@ -83,30 +68,16 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> StaticInfere
tokenizer = get_tokenizer()
else:
tokenizer = build_tokenizer(args)
inference_wrapper_config = InferenceWrapperConfig(
hidden_size=args.hidden_size,
inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
fp32_residual_connection=args.fp32_residual_connection,
params_dtype=args.params_dtype,
padded_vocab_size=args.padded_vocab_size,
inference_max_requests=args.inference_max_batch_size,
inference_max_seq_length=args.inference_max_seq_length,
nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill,
fp8=args.fp8,
moe_pad_experts_for_cuda_graph_inference = args.moe_pad_experts_for_cuda_graph_inference
)

inference_context = StaticInferenceContext.from_config(inference_wrapper_config)

inference_wrapped_model = GPTInferenceWrapper(
model, inference_wrapper_config, inference_context
inference_context = StaticInferenceContext(
args.inference_max_requests, args.inference_max_seq_length
)
inference_wrapped_model = GPTInferenceWrapper(model, inference_context)
text_generation_controller = TextGenerationController(
inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
)
engine_kwargs = {
"text_generation_controller" : text_generation_controller,
"legacy" : args.use_legacy_static_engine,
"text_generation_controller": text_generation_controller,
"legacy": args.use_legacy_static_engine,
}
if not args.use_legacy_static_engine:
engine_kwargs["buffer_size_gb"] = args.inference_dynamic_batching_buffer_size_gb
Expand Down Expand Up @@ -165,22 +136,7 @@ def main():

args = get_args()

if args.max_batch_size is not None:
warnings.warn(
f"`--max-batch-size` has been deprecated in favor of `--inference-max-requests`."
)
args.inference_max_batch_size = max(args.max_batch_size, args.inference_max_batch_size)

# Set up model and load checkpoint
if args.model_provider == "gpt":
model_builder = gpt_builder
elif args.model_provider == "mamba":
model_builder = mamba_builder
else:
raise ValueError(f"Invalid model provider {args.model_provider}")
model = get_model(partial(model_provider, model_builder), wrap_with_ddp=False)
load_checkpoint(model, None, None, strict=False)
model = model[0]
model = get_model_for_inference()

inference_engine = get_inference_engine(args, model)

Expand Down Expand Up @@ -276,7 +232,7 @@ def main():
)
),
len(requests),
args.inference_max_batch_size,
args.inference_max_requests,
stats["allocated_bytes.all.peak"] / (1024**3),
stats["reserved_bytes.all.peak"] / (1024**3),
latency,
Expand All @@ -293,6 +249,5 @@ def main():
torch.distributed.destroy_process_group()



if __name__ == "__main__":
main()
Loading
Loading