
Commit 39d8e77

santhnm2 authored and ko3n1g committed
Miscellaneous inference cleanup (Replay of !2955) (NVIDIA#3232)
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
Co-authored-by: oliver könig <okoenig@nvidia.com>
1 parent bb7ff3f commit 39d8e77

File tree

68 files changed (+1372, -2087 lines)


examples/inference/gpt/gpt_dynamic_inference.py

Lines changed: 77 additions & 262 deletions
Large diffs are not rendered by default.

examples/inference/gpt/gpt_dynamic_inference_with_coordinator.py

Lines changed: 42 additions & 75 deletions
@@ -2,43 +2,33 @@
 
 import asyncio
 import json
+import logging
 import os
 import time
-import torch
-import torch.distributed as dist
+import warnings
 from collections import defaultdict
-from tqdm import tqdm
 from typing import List
-import warnings
-import logging
 
-from examples.inference.gpt.gpt_dynamic_inference import (
-    add_dynamic_inference_args,
-    get_inference_context,
-    get_inference_controller,
-    get_model,
-)
-from examples.inference.gpt.utils import (
-    Request,
-    build_dynamic_engine_setup_prefix,
-    build_requests,
-    add_common_inference_args
-)
+import torch
+import torch.distributed as dist
 
-from megatron.core import parallel_state
+from examples.inference.gpt.utils import Request, build_dynamic_engine_setup_prefix, build_requests
 from megatron.core.inference.engines import DynamicInferenceEngine
 from megatron.core.inference.inference_client import InferenceClient
 from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
 from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.utils import get_mamba_inference_state_config_from_model
-
+from megatron.inference.utils import (
+    add_inference_args,
+    get_dynamic_inference_engine,
+    get_model_for_inference,
+)
 from megatron.training import get_args, get_tokenizer, initialize_megatron
-from megatron.training.arguments import parse_args
 
 # pylint: disable=line-too-long
 
 logging.basicConfig(level=logging.INFO, force=True)
 
+
 async def main(
     engine: DynamicInferenceEngine,
     requests: List[Request],
@@ -51,12 +41,11 @@ async def main(
         "Sampling parameters are specified per request.",
         DeprecationWarning,
     )
-
+
     # once you call engine.start_listening_to_data_parallel_coordinator,
     # the engine will start accepting requests from the data parallel coordinator.
     # and processing them in an asyncio coroutine.
     # leaving inference_coordinator_port as None will find a free port automatically.
-
     dp_addr = await engine.start_listening_to_data_parallel_coordinator(
         inference_coordinator_port=port,
         launch_inference_coordinator=True,
@@ -69,14 +58,11 @@ async def main(
         # Since the client doesn't directly call engine.async_step here, we test
         # the suspend-resume system ~4 times.
         suspend_resume_interval = max(1, len(requests) // 4)
-        suspend_idxs = set(range(
-            suspend_resume_interval,
-            len(requests) + 1,
-            suspend_resume_interval,
-        ))
+        suspend_idxs = set(
+            range(suspend_resume_interval, len(requests) + 1, suspend_resume_interval)
+        )
         resume_idxs = set(
-            min(len(requests), i + suspend_resume_interval // 2)
-            for i in suspend_idxs
+            min(len(requests), i + suspend_resume_interval // 2) for i in suspend_idxs
         )
     else:
         suspend_idxs = set()
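
As a quick worked check of the suspend/resume index arithmetic in the hunk above (the request count of 20 is a hypothetical value chosen for illustration, not taken from this commit):

    requests_count = 20  # hypothetical number of requests, for illustration only
    suspend_resume_interval = max(1, requests_count // 4)  # -> 5
    suspend_idxs = set(range(suspend_resume_interval, requests_count + 1, suspend_resume_interval))
    # -> {5, 10, 15, 20}
    resume_idxs = set(min(requests_count, i + suspend_resume_interval // 2) for i in suspend_idxs)
    # -> {7, 12, 17, 20}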
@@ -98,7 +84,10 @@ async def main(
         current_time = time.time_ns() / 10**9
         if args.incoming_requests_per_step is None:
             # Only add requests that have arrived at the current time.
-            while num_requests_added < num_requests_total and requests[num_requests_added].time_arrival <= current_time:
+            while (
+                num_requests_added < num_requests_total
+                and requests[num_requests_added].time_arrival <= current_time
+            ):
                 request = requests[num_requests_added]
                 # These add-request calls will queue up the request on a zmq socket and return
                 # instantaneously. They will return an asyncio future which can be awaited for
@@ -114,10 +103,9 @@ async def main(
 
         else:
             # Add deterministic number of requests (generally used for debugging).
-            for i in range(min(
-                args.incoming_requests_per_step,
-                num_requests_total - num_requests_added
-            )):
+            for i in range(
+                min(args.incoming_requests_per_step, num_requests_total - num_requests_added)
+            ):
                 # Change sampling parameters to force different generation lengths.
                 request = requests[num_requests_added]
                 n = request.sampling_params.num_tokens_to_generate
@@ -135,7 +123,7 @@ async def main(
             break
         # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
         await asyncio.sleep(0)
-
+
     # While we wait for the requests to complete, the engine runs in the background.
     results: List[DynamicInferenceRequestRecord] = await asyncio.gather(*futures)
 
@@ -170,16 +158,19 @@ async def main(
         req = record.merge()
         unique_prompt_map[req.prompt].append(req)
     for idx, (prompt_text, reqs) in enumerate(unique_prompt_map.items()):
-        print(f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
-            idx,
-            len(unique_prompt_map),
-            prompt_text.replace("\n", "\\n"),
-            len(reqs),
-            reqs[0].generated_text.replace("\n", "\\n"),
-        ))
+        print(
+            f"%d/%d. prompt '%s' ... [%d] output '%s'."
+            % (
+                idx,
+                len(unique_prompt_map),
+                prompt_text.replace("\n", "\\n"),
+                len(reqs),
+                reqs[0].generated_text.replace("\n", "\\n"),
+            )
+        )
 
     # kill the engines and suspend the client
-    # Right now, we can only call stop when all requests are done. 
+    # Right now, we can only call stop when all requests are done.
     # Todo: Make this explicit in the Client class....
     await client.stop_engines()
     client.stop()
@@ -190,11 +181,11 @@ async def main(
 
 
 if __name__ == "__main__":
-    # enable inference mode in the very beginning as some fp-8 optimizations
+    # enable inference mode in the very beginning as some fp8 optimizations
     # check for it.
     with torch.inference_mode():
         initialize_megatron(
-            extra_args_provider=add_dynamic_inference_args,
+            extra_args_provider=add_inference_args,
             args_defaults={'no_load_rng': True, 'no_load_optim': True},
         )
 
@@ -213,34 +204,16 @@ async def main(
             ),
         )
 
-        # Requests, context, conroller.
-        model = get_model()
-        mamba_inference_state_config = get_mamba_inference_state_config_from_model(model)
+        model = get_model_for_inference()
+
         requests = (
             build_requests(args, tokenizer, sampling_params) if dist.get_rank() == 0 else None
         )
 
-        context = get_inference_context(
-            None,
-            None,
-            calculate_max_sequence_length_from_requests=False,
-            mamba_inference_state_config=mamba_inference_state_config,
-        )
-
-        controller = get_inference_controller(model, context)
-
-        # Inference engine.
-        engine = DynamicInferenceEngine(
-            controller,
-            context,
-            enable_cuda_graph=args.cuda_graph_impl == "local",
-            random_seed=args.seed,
-            enable_chunked_prefill=not args.disable_chunked_prefill,
-            inference_logging_step_interval=args.inference_logging_step_interval,
-        )
+        engine = get_dynamic_inference_engine(model=model)
 
         if dist.get_rank() == 0:
-            setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests)
+            setup_prefix = build_dynamic_engine_setup_prefix(args, model, engine.context, requests)
             print("~~~")
             print(setup_prefix)
             print("~~~")
@@ -249,13 +222,7 @@ async def main(
         if os.environ.get("NSIGHT_PREFIX"):
             torch.cuda.cudart().cudaProfilerStart()
 
-        asyncio.run(
-            main(
-                engine,
-                requests,
-                args.inference_coordinator_port,
-            )
-        )
+        asyncio.run(main(engine, requests, args.inference_coordinator_port))
 
         # Stop Nsight profiler.
         if os.environ.get("NSIGHT_PREFIX"):
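
Taken together, the hunks above replace this example's hand-rolled context/controller/engine wiring with two shared helpers. A minimal sketch of the updated entry point, assembled only from the added lines in this diff; the helper signatures in megatron.inference.utils are assumed from those lines and may accept additional arguments:

    import torch

    from megatron.inference.utils import (
        add_inference_args,
        get_dynamic_inference_engine,
        get_model_for_inference,
    )
    from megatron.training import get_args, initialize_megatron

    with torch.inference_mode():  # entered early because some fp8 optimizations check for it
        initialize_megatron(
            extra_args_provider=add_inference_args,  # shared arg provider replaces add_dynamic_inference_args
            args_defaults={'no_load_rng': True, 'no_load_optim': True},
        )
        args = get_args()

        model = get_model_for_inference()                   # replaces get_model() + Mamba state config lookup
        engine = get_dynamic_inference_engine(model=model)  # replaces manual context/controller/engine setup

        # The engine now owns its context, so callers use engine.context
        # instead of a separately constructed context object.
        context = engine.context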

examples/inference/gpt/gpt_static_inference.py

Lines changed: 13 additions & 58 deletions
@@ -1,58 +1,43 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 
 import os
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
-    InferenceWrapperConfig,
-)
-from model_provider import model_provider
-from gpt_builders import gpt_builder
-from mamba_builders import mamba_builder
-import torch
 import sys
 import time
-import warnings
-from functools import partial
 from argparse import Namespace
 
 import torch
-import tqdm
 
 from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.engines import StaticInferenceEngine
 from megatron.core.inference.inference_request import InferenceRequest
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
     GPTInferenceWrapper,
 )
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
-    InferenceWrapperConfig,
-)
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.inference.text_generation_controllers.text_generation_controller import (
     TextGenerationController,
 )
 from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer
 from megatron.core.transformer.module import MegatronModule
-from pretrain_gpt import model_provider as gpt_model_provider
-from pretrain_mamba import model_provider as mamba_model_provider
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
 )
 
 import asyncio
 import json
-from typing import Any, AsyncIterator, List
+from typing import List
 
-from examples.inference.gpt.utils import add_common_inference_args, build_requests
-from megatron.core import mpu
-from megatron.training import get_args, get_model, get_tokenizer, print_rank_0
-from megatron.training.checkpointing import load_checkpoint
+from examples.inference.gpt.utils import build_requests
+from megatron.inference.utils import add_inference_args, get_model_for_inference
+from megatron.training import get_args, get_tokenizer, print_rank_0
 from megatron.training.initialize import initialize_megatron
 
+
 def add_static_inference_args(parser):
     """Static inference arguments."""
 
-    add_common_inference_args(parser)
+    add_inference_args(parser)
 
     group = parser.add_argument_group(title='Static inference')
     group.add_argument(
@@ -83,30 +68,16 @@ def get_inference_engine(args: Namespace, model: MegatronModule) -> StaticInfere
         tokenizer = get_tokenizer()
     else:
         tokenizer = build_tokenizer(args)
-    inference_wrapper_config = InferenceWrapperConfig(
-        hidden_size=args.hidden_size,
-        inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
-        fp32_residual_connection=args.fp32_residual_connection,
-        params_dtype=args.params_dtype,
-        padded_vocab_size=args.padded_vocab_size,
-        inference_max_requests=args.inference_max_batch_size,
-        inference_max_seq_length=args.inference_max_seq_length,
-        nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill,
-        fp8=args.fp8,
-        moe_pad_experts_for_cuda_graph_inference = args.moe_pad_experts_for_cuda_graph_inference
-    )
-
-    inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
-
-    inference_wrapped_model = GPTInferenceWrapper(
-        model, inference_wrapper_config, inference_context
+    inference_context = StaticInferenceContext(
+        args.inference_max_requests, args.inference_max_seq_length
     )
+    inference_wrapped_model = GPTInferenceWrapper(model, inference_context)
     text_generation_controller = TextGenerationController(
         inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
     )
     engine_kwargs = {
-        "text_generation_controller" : text_generation_controller,
-        "legacy" : args.use_legacy_static_engine,
+        "text_generation_controller": text_generation_controller,
+        "legacy": args.use_legacy_static_engine,
     }
     if not args.use_legacy_static_engine:
         engine_kwargs["buffer_size_gb"] = args.inference_dynamic_batching_buffer_size_gb
@@ -165,22 +136,7 @@ def main():
 
     args = get_args()
 
-    if args.max_batch_size is not None:
-        warnings.warn(
-            f"`--max-batch-size` has been deprecated in favor of `--inference-max-requests`."
-        )
-        args.inference_max_batch_size = max(args.max_batch_size, args.inference_max_batch_size)
-
-    # Set up model and load checkpoint
-    if args.model_provider == "gpt":
-        model_builder = gpt_builder
-    elif args.model_provider == "mamba":
-        model_builder = mamba_builder
-    else:
-        raise ValueError(f"Invalid model provider {args.model_provider}")
-    model = get_model(partial(model_provider, model_builder), wrap_with_ddp=False)
-    load_checkpoint(model, None, None, strict=False)
-    model = model[0]
+    model = get_model_for_inference()
 
     inference_engine = get_inference_engine(args, model)
 
@@ -276,7 +232,7 @@ def main():
                 )
             ),
             len(requests),
-            args.inference_max_batch_size,
+            args.inference_max_requests,
            stats["allocated_bytes.all.peak"] / (1024**3),
            stats["reserved_bytes.all.peak"] / (1024**3),
            latency,
@@ -293,6 +249,5 @@ def main():
     torch.distributed.destroy_process_group()
 
 
-
 if __name__ == "__main__":
     main()
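
For the static example, the net effect of the hunks above is a much shorter engine setup. A condensed sketch using only calls that appear in the added or retained lines of this diff; the final StaticInferenceEngine(**engine_kwargs) construction is an assumption, since it sits below the shown hunks:

    from megatron.core.inference.contexts import StaticInferenceContext
    from megatron.core.inference.engines import StaticInferenceEngine
    from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
        GPTInferenceWrapper,
    )
    from megatron.core.inference.text_generation_controllers.text_generation_controller import (
        TextGenerationController,
    )
    from megatron.inference.utils import get_model_for_inference
    from megatron.training import get_args, get_tokenizer


    def build_static_engine() -> StaticInferenceEngine:
        args = get_args()
        model = get_model_for_inference()  # replaces per-provider builders + load_checkpoint + model[0]
        tokenizer = get_tokenizer()

        # The context is built directly from the two sizing args; the removed
        # InferenceWrapperConfig and its many fields are no longer needed here.
        inference_context = StaticInferenceContext(
            args.inference_max_requests, args.inference_max_seq_length
        )
        inference_wrapped_model = GPTInferenceWrapper(model, inference_context)
        text_generation_controller = TextGenerationController(
            inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
        )

        engine_kwargs = {
            "text_generation_controller": text_generation_controller,
            "legacy": args.use_legacy_static_engine,
        }
        if not args.use_legacy_static_engine:
            engine_kwargs["buffer_size_gb"] = args.inference_dynamic_batching_buffer_size_gb
        return StaticInferenceEngine(**engine_kwargs)  # assumed: construction happens below the shown hunks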
