Skip to content

Commit 2050da3

Browse files
committed
Revert "Miscellaneous inference cleanup (Replay of !2955) (NVIDIA#3232)"
This reverts commit 43db8c1.
1 parent 60a25aa commit 2050da3

File tree

67 files changed

+2168
-1232
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+2168
-1232
lines changed

examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 263 additions & 78 deletions
Large diffs are not rendered by default.

examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py

Lines changed: 75 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,43 @@
22

33
import asyncio
44
import json
5-
import logging
65
import os
76
import time
8-
import warnings
7+
import torch
8+
import torch.distributed as dist
99
from collections import defaultdict
10+
from tqdm import tqdm
1011
from typing import List
12+
import warnings
13+
import logging
1114

12-
import torch
13-
import torch.distributed as dist
15+
from examples.inference.gpt.gpt_dynamic_inference import (
16+
add_dynamic_inference_args,
17+
get_inference_context,
18+
get_inference_controller,
19+
get_model,
20+
)
21+
from examples.inference.gpt.utils import (
22+
Request,
23+
build_dynamic_engine_setup_prefix,
24+
build_requests,
25+
add_common_inference_args
26+
)
1427

15-
from examples.inference.gpt.utils import Request, build_dynamic_engine_setup_prefix, build_requests
28+
from megatron.core import parallel_state
1629
from megatron.core.inference.engines import DynamicInferenceEngine
1730
from megatron.core.inference.inference_client import InferenceClient
1831
from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
1932
from megatron.core.inference.sampling_params import SamplingParams
20-
from megatron.inference.utils import (
21-
add_inference_args,
22-
get_dynamic_inference_engine,
23-
get_model_for_inference,
24-
)
33+
from megatron.core.utils import get_mamba_inference_state_config_from_model
34+
2535
from megatron.training import get_args, get_tokenizer, initialize_megatron
36+
from megatron.training.arguments import parse_args
2637

2738
# pylint: disable=line-too-long
2839

2940
logging.basicConfig(level=logging.INFO, force=True)
3041

31-
3242
async def main(
3343
engine: DynamicInferenceEngine,
3444
requests: List[Request],
@@ -41,11 +51,12 @@ async def main(
4151
"Sampling parameters are specified per request.",
4252
DeprecationWarning,
4353
)
44-
54+
4555
# once you call engine.start_listening_to_data_parallel_coordinator,
4656
# the engine will start accepting requests from the data parallel coordinator.
4757
# and processing them in an asyncio coroutine.
4858
# leaving inference_coordinator_port as None will find a free port automatically.
59+
4960
dp_addr = await engine.start_listening_to_data_parallel_coordinator(
5061
inference_coordinator_port=port,
5162
launch_inference_coordinator=True,
@@ -58,11 +69,14 @@ async def main(
5869
# Since the client doesn't directly call engine.async_step here, we test
5970
# the suspend-resume system ~4 times.
6071
suspend_resume_interval = max(1, len(requests) // 4)
61-
suspend_idxs = set(
62-
range(suspend_resume_interval, len(requests) + 1, suspend_resume_interval)
63-
)
72+
suspend_idxs = set(range(
73+
suspend_resume_interval,
74+
len(requests) + 1,
75+
suspend_resume_interval,
76+
))
6477
resume_idxs = set(
65-
min(len(requests), i + suspend_resume_interval // 2) for i in suspend_idxs
78+
min(len(requests), i + suspend_resume_interval // 2)
79+
for i in suspend_idxs
6680
)
6781
else:
6882
suspend_idxs = set()
@@ -84,10 +98,7 @@ async def main(
8498
current_time = time.time_ns() / 10**9
8599
if args.incoming_requests_per_step is None:
86100
# Only add requests that have arrived at the current time.
87-
while (
88-
num_requests_added < num_requests_total
89-
and requests[num_requests_added].time_arrival <= current_time
90-
):
101+
while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
91102
request = requests[num_requests_added]
92103
# These add-request calls will queue up the request on a zmq socket and return
93104
# instantaneously. They will return an asyncio future which can be awaited for
@@ -103,9 +114,10 @@ async def main(
103114

104115
else:
105116
# Add deterministic number of requests (generally used for debugging).
106-
for i in range(
107-
min(args.incoming_requests_per_step, num_requests_total - num_requests_added)
108-
):
117+
for i in range(min(
118+
args.incoming_requests_per_step,
119+
num_requests_total - num_requests_added
120+
)):
109121
# Change sampling parameters to force different generation lengths.
110122
request = requests[num_requests_added]
111123
n = request.sampling_params.num_tokens_to_generate
@@ -123,7 +135,7 @@ async def main(
123135
break
124136
# Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
125137
await asyncio.sleep(0)
126-
138+
127139
# While we wait for the requests to complete, the engine runs in the background.
128140
results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)
129141

@@ -161,19 +173,16 @@ async def main(
161173
req = record.merge()
162174
unique_prompt_map[req.prompt].append(req)
163175
for idx, (prompt_text, reqs) in enumerate(unique_prompt_map.items()):
164-
print(
165-
f"%d/%d. prompt '%s' ... [%d] output '%s'."
166-
% (
167-
idx,
168-
len(unique_prompt_map),
169-
prompt_text.replace("\n", "\\n"),
170-
len(reqs),
171-
reqs[0].generated_text.replace("\n", "\\n"),
172-
)
173-
)
176+
print(f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
177+
idx,
178+
len(unique_prompt_map),
179+
prompt_text.replace("\n", "\\n"),
180+
len(reqs),
181+
reqs[0].generated_text.replace("\n", "\\n"),
182+
))
174183

175184
# kill the engines and suspend the client
176-
# Right now, we can only call stop when all requests are done.
185+
# Right now, we can only call stop when all requests are done.
177186
# Todo: Make this explicit in the Client class....
178187
await client.stop_engines()
179188
client.stop()
@@ -184,11 +193,11 @@ async def main(
184193

185194

186195
if __name__ == "__main__":
187-
# enable inference mode in the very beginning as some fp8 optimizations
196+
# enable inference mode in the very beginning as some fp-8 optimizations
188197
# check for it.
189198
with torch.inference_mode():
190199
initialize_megatron(
191-
extra_args_provider=add_inference_args,
200+
extra_args_provider=add_dynamic_inference_args,
192201
args_defaults={'no_load_rng': True, 'no_load_optim': True},
193202
)
194203

@@ -207,16 +216,34 @@ async def main(
207216
),
208217
)
209218

210-
model = get_model_for_inference()
211-
219+
# Requests, context, controller.
220+
model = get_model()
221+
mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
212222
requests = (
213223
build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
214224
)
215225

216-
engine = get_dynamic_inference_engine(model=model)
226+
context = get_inference_context(
227+
None,
228+
None,
229+
calculate_max_sequence_length_from_requests=False,
230+
mamba_inference_state_config=mamba_inference_state_config,
231+
)
232+
233+
controller = get_inference_controller(model, context)
234+
235+
# Inference engine.
236+
engine = DynamicInferenceEngine(
237+
controller,
238+
context,
239+
enable_cuda_graph=args.cuda_graph_impl == "local",
240+
random_seed=args.seed,
241+
enable_chunked_prefill=not args.disable_chunked_prefill,
242+
inference_logging_step_interval=args.inference_logging_step_interval,
243+
)
217244

218245
if dist.get_rank() == 0:
219-
setup_prefix = build_dynamic_engine_setup_prefix(args, model, engine.context, requests)
246+
setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests)
220247
print("~~~")
221248
print(setup_prefix)
222249
print("~~~")
@@ -225,7 +252,13 @@ async def main(
225252
if os.environ.get("NSIGHT_PREFIX"):
226253
torch.cuda.cudart().cudaProfilerStart()
227254

228-
asyncio.run(main(engine, requests, args.inference_coordinator_port))
255+
asyncio.run(
256+
main(
257+
engine,
258+
requests,
259+
args.inference_coordinator_port,
260+
)
261+
)
229262

230263
# Stop Nsight profiler.
231264
if os.environ.get("NSIGHT_PREFIX"):

examples/inference/gpt/gpt_static_inference.py

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,58 @@
11
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
22

33
import os
4+
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
5+
InferenceWrapperConfig,
6+
)
7+
from model_provider import model_provider
8+
from gpt_builders import gpt_builder
9+
from mamba_builders import mamba_builder
10+
import torch
411
import sys
512
import time
13+
import warnings
14+
from functools import partial
615
from argparse import Namespace
716

817
import torch
18+
import tqdm
919

1020
from megatron.core.inference.contexts import StaticInferenceContext
1121
from megatron.core.inference.engines import StaticInferenceEngine
1222
from megatron.core.inference.inference_request import InferenceRequest
1323
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
1424
GPTInferenceWrapper,
1525
)
26+
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
27+
InferenceWrapperConfig,
28+
)
1629
from megatron.core.inference.sampling_params import SamplingParams
1730
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
1831
TextGenerationController,
1932
)
2033
from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer
2134
from megatron.core.transformer.module import MegatronModule
35+
from pretrain_gpt import model_provider as gpt_model_provider
36+
from pretrain_mamba import model_provider as mamba_model_provider
2237

2338
sys.path.append(
2439
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
2540
)
2641

2742
import asyncio
2843
import json
29-
from typing import List
44+
from typing import Any, AsyncIterator, List
3045

31-
from examples.inference.gpt.utils import build_requests
32-
from megatron.inference.utils import add_inference_args, get_model_for_inference
33-
from megatron.training import get_args, get_tokenizer, print_rank_0
46+
from examples.inference.gpt.utils import add_common_inference_args, build_requests
47+
from megatron.core import mpu
48+
from megatron.training import get_args, get_model, get_tokenizer, print_rank_0
49+
from megatron.training.checkpointing import load_checkpoint
3450
from megatron.training.initialize import initialize_megatron
3551

36-
3752
def add_static_inference_args(parser):
3853
"""Static inference arguments."""
3954

40-
add_inference_args(parser)
55+
add_common_inference_args(parser)
4156

4257
group = parser.add_argument_group(title='Static inference')
4358
group.add_argument(
@@ -64,17 +79,34 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> StaticInfere
6479
Returns:
6580
AbstractBackend: The chosen backend
6681
"""
67-
tokenizer = build_tokenizer(args)
68-
inference_context = StaticInferenceContext(
69-
args.inference_max_requests, args.inference_max_seq_length
82+
if args.legacy_tokenizer:
83+
tokenizer = get_tokenizer()
84+
else:
85+
tokenizer = build_tokenizer(args)
86+
inference_wrapper_config = InferenceWrapperConfig(
87+
hidden_size=args.hidden_size,
88+
inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
89+
fp32_residual_connection=args.fp32_residual_connection,
90+
params_dtype=args.params_dtype,
91+
padded_vocab_size=args.padded_vocab_size,
92+
inference_max_requests=args.inference_max_batch_size,
93+
inference_max_seq_length=args.inference_max_seq_length,
94+
nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill,
95+
fp8=args.fp8,
96+
moe_pad_experts_for_cuda_graph_inference = args.moe_pad_experts_for_cuda_graph_inference
97+
)
98+
99+
inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
100+
101+
inference_wrapped_model = GPTInferenceWrapper(
102+
model, inference_wrapper_config, inference_context
70103
)
71-
inference_wrapped_model = GPTInferenceWrapper(model, inference_context)
72104
text_generation_controller = TextGenerationController(
73105
inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
74106
)
75107
engine_kwargs = {
76-
"text_generation_controller": text_generation_controller,
77-
"legacy": args.use_legacy_static_engine,
108+
"text_generation_controller" : text_generation_controller,
109+
"legacy" : args.use_legacy_static_engine,
78110
}
79111
if not args.use_legacy_static_engine:
80112
engine_kwargs["buffer_size_gb"] = args.inference_dynamic_batching_buffer_size_gb
@@ -133,7 +165,22 @@ def main():
133165

134166
args = get_args()
135167

136-
model = get_model_for_inference()
168+
if args.max_batch_size is not None:
169+
warnings.warn(
170+
f"`--max-batch-size` has been deprecated in favor of `--inference-max-requests`."
171+
)
172+
args.inference_max_batch_size = max(args.max_batch_size, args.inference_max_batch_size)
173+
174+
# Set up model and load checkpoint
175+
if args.model_provider == "gpt":
176+
model_builder = gpt_builder
177+
elif args.model_provider == "mamba":
178+
model_builder = mamba_builder
179+
else:
180+
raise ValueError(f"Invalid model provider {args.model_provider}")
181+
model = get_model(partial(model_provider, model_builder), wrap_with_ddp=False)
182+
load_checkpoint(model, None, None, strict=False)
183+
model = model[0]
137184

138185
inference_engine = get_inference_engine(args, model)
139186

@@ -228,7 +275,7 @@ def main():
228275
)
229276
),
230277
len(requests),
231-
args.inference_max_requests,
278+
args.inference_max_batch_size,
232279
stats["allocated_bytes.all.peak"] / (1024**3),
233280
stats["reserved_bytes.all.peak"] / (1024**3),
234281
latency,
@@ -245,5 +292,6 @@ def main():
245292
torch.distributed.destroy_process_group()
246293

247294

295+
248296
if __name__ == "__main__":
249297
main()

0 commit comments

Comments
 (0)