22
33import asyncio
44import json
5+ import logging
56import os
67import time
7- import torch
8- import torch .distributed as dist
8+ import warnings
99from collections import defaultdict
10- from tqdm import tqdm
1110from typing import List
12- import warnings
13- import logging
1411
15- from examples .inference .gpt .gpt_dynamic_inference import (
16- add_dynamic_inference_args ,
17- get_inference_context ,
18- get_inference_controller ,
19- get_model ,
20- )
21- from examples .inference .gpt .utils import (
22- Request ,
23- build_dynamic_engine_setup_prefix ,
24- build_requests ,
25- add_common_inference_args
26- )
12+ import torch
13+ import torch .distributed as dist
2714
28- from megatron . core import parallel_state
15+ from examples . inference . gpt . utils import Request , build_dynamic_engine_setup_prefix , build_requests
2916from megatron .core .inference .engines import DynamicInferenceEngine
3017from megatron .core .inference .inference_client import InferenceClient
3118from megatron .core .inference .inference_request import DynamicInferenceRequestRecord
3219from megatron .core .inference .sampling_params import SamplingParams
33- from megatron .core .utils import get_mamba_inference_state_config_from_model
34-
20+ from megatron .inference .utils import (
21+ add_inference_args ,
22+ get_dynamic_inference_engine ,
23+ get_model_for_inference ,
24+ )
3525from megatron .training import get_args , get_tokenizer , initialize_megatron
36- from megatron .training .arguments import parse_args
3726
3827# pylint: disable=line-too-long
3928
4029logging .basicConfig (level = logging .INFO , force = True )
4130
31+
4232async def main (
4333 engine : DynamicInferenceEngine ,
4434 requests : List [Request ],
@@ -51,12 +41,11 @@ async def main(
5141 "Sampling parameters are specified per request." ,
5242 DeprecationWarning ,
5343 )
54-
44+
5545 # Once you call engine.start_listening_to_data_parallel_coordinator,
5646 # the engine will start accepting requests from the data parallel coordinator
5747 # and processing them in an asyncio coroutine.
5848 # Leaving inference_coordinator_port as None will find a free port automatically.
59-
6049 dp_addr = await engine .start_listening_to_data_parallel_coordinator (
6150 inference_coordinator_port = port ,
6251 launch_inference_coordinator = True ,
@@ -69,14 +58,11 @@ async def main(
6958 # Since the client doesn't directly call engine.async_step here, we test
7059 # the suspend-resume system ~4 times.
7160 suspend_resume_interval = max (1 , len (requests ) // 4 )
72- suspend_idxs = set (range (
73- suspend_resume_interval ,
74- len (requests ) + 1 ,
75- suspend_resume_interval ,
76- ))
61+ suspend_idxs = set (
62+ range (suspend_resume_interval , len (requests ) + 1 , suspend_resume_interval )
63+ )
7764 resume_idxs = set (
78- min (len (requests ), i + suspend_resume_interval // 2 )
79- for i in suspend_idxs
65+ min (len (requests ), i + suspend_resume_interval // 2 ) for i in suspend_idxs
8066 )
8167 else :
8268 suspend_idxs = set ()
@@ -98,7 +84,10 @@ async def main(
9884 current_time = time .time_ns () / 10 ** 9
9985 if args .incoming_requests_per_step is None :
10086 # Only add requests that have arrived at the current time.
101- while num_requests_added < num_requests_total and requests [num_requests_added ].time_arrival <= current_time :
87+ while (
88+ num_requests_added < num_requests_total
89+ and requests [num_requests_added ].time_arrival <= current_time
90+ ):
10291 request = requests [num_requests_added ]
10392 # These add-request calls will queue up the request on a zmq socket and return
10493 # instantaneously. They will return an asyncio future which can be awaited for
@@ -114,10 +103,9 @@ async def main(
114103
115104 else :
116105 # Add deterministic number of requests (generally used for debugging).
117- for i in range (min (
118- args .incoming_requests_per_step ,
119- num_requests_total - num_requests_added
120- )):
106+ for i in range (
107+ min (args .incoming_requests_per_step , num_requests_total - num_requests_added )
108+ ):
121109 # Change sampling parameters to force different generation lengths.
122110 request = requests [num_requests_added ]
123111 n = request .sampling_params .num_tokens_to_generate
@@ -135,7 +123,7 @@ async def main(
135123 break
136124 # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
137125 await asyncio .sleep (0 )
138-
126+
139127 # While we wait for the requests to complete, the engine runs in the background.
140128 results : List [DynamicInferenceRequestRecord ] = await asyncio .gather (* futures )
141129
@@ -170,16 +158,19 @@ async def main(
170158 req = record .merge ()
171159 unique_prompt_map [req .prompt ].append (req )
172160 for idx , (prompt_text , reqs ) in enumerate (unique_prompt_map .items ()):
173- print (f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
174- idx ,
175- len (unique_prompt_map ),
176- prompt_text .replace ("\n " , "\\ n" ),
177- len (reqs ),
178- reqs [0 ].generated_text .replace ("\n " , "\\ n" ),
179- ))
161+ print (
162+ f"%d/%d. prompt '%s' ... [%d] output '%s'."
163+ % (
164+ idx ,
165+ len (unique_prompt_map ),
166+ prompt_text .replace ("\n " , "\\ n" ),
167+ len (reqs ),
168+ reqs [0 ].generated_text .replace ("\n " , "\\ n" ),
169+ )
170+ )
180171
182173 # Stop the engines, then stop the client.
182- # Right now, we can only call stop when all requests are done.
173+ # Right now, we can only call stop when all requests are done.
183174 # TODO: Make this explicit in the Client class.
184175 await client .stop_engines ()
185176 client .stop ()
@@ -190,11 +181,11 @@ async def main(
190181
191182
192183if __name__ == "__main__" :
193- # enable inference mode in the very beginning as some fp-8 optimizations
184+ # enable inference mode in the very beginning as some fp8 optimizations
194185 # check for it.
195186 with torch .inference_mode ():
196187 initialize_megatron (
197- extra_args_provider = add_dynamic_inference_args ,
188+ extra_args_provider = add_inference_args ,
198189 args_defaults = {'no_load_rng' : True , 'no_load_optim' : True },
199190 )
200191
@@ -213,34 +204,16 @@ async def main(
213204 ),
214205 )
215206
216- # Requests, context, conroller.
217- model = get_model ()
218- mamba_inference_state_config = get_mamba_inference_state_config_from_model (model )
207+ model = get_model_for_inference ()
208+
219209 requests = (
220210 build_requests (args , tokenizer , sampling_params ) if dist .get_rank () == 0 else None
221211 )
222212
223- context = get_inference_context (
224- None ,
225- None ,
226- calculate_max_sequence_length_from_requests = False ,
227- mamba_inference_state_config = mamba_inference_state_config ,
228- )
229-
230- controller = get_inference_controller (model , context )
231-
232- # Inference engine.
233- engine = DynamicInferenceEngine (
234- controller ,
235- context ,
236- enable_cuda_graph = args .cuda_graph_impl == "local" ,
237- random_seed = args .seed ,
238- enable_chunked_prefill = not args .disable_chunked_prefill ,
239- inference_logging_step_interval = args .inference_logging_step_interval ,
240- )
213+ engine = get_dynamic_inference_engine (model = model )
241214
242215 if dist .get_rank () == 0 :
243- setup_prefix = build_dynamic_engine_setup_prefix (args , model , context , requests )
216+ setup_prefix = build_dynamic_engine_setup_prefix (args , model , engine . context , requests )
244217 print ("~~~" )
245218 print (setup_prefix )
246219 print ("~~~" )
@@ -249,13 +222,7 @@ async def main(
249222 if os .environ .get ("NSIGHT_PREFIX" ):
250223 torch .cuda .cudart ().cudaProfilerStart ()
251224
252- asyncio .run (
253- main (
254- engine ,
255- requests ,
256- args .inference_coordinator_port ,
257- )
258- )
225+ asyncio .run (main (engine , requests , args .inference_coordinator_port ))
259226
260227 # Stop Nsight profiler.
261228 if os .environ .get ("NSIGHT_PREFIX" ):
0 commit comments