# Module header for the dynamic-inference example client script.
#
# Reconstructed from the post-image side of the diff: all diff-view residue
# (fused line numbers, +/- markers, spaces inside dotted names) removed, and
# imports regrouped per PEP 8 (stdlib / third-party / local). No import that
# appeared in the original header has been removed.

import asyncio
import json
import logging
import os
import time
import warnings
from collections import defaultdict
from typing import List

import torch
import torch.distributed as dist
from tqdm import tqdm

from examples.inference.gpt.gpt_dynamic_inference import (
    add_dynamic_inference_args,
    get_inference_context,
    get_inference_controller,
    get_model,
)
from examples.inference.gpt.utils import (
    Request,
    add_common_inference_args,
    build_dynamic_engine_setup_prefix,
    build_requests,
)
from megatron.core import parallel_state
from megatron.core.inference.engines import DynamicInferenceEngine
from megatron.core.inference.inference_client import InferenceClient
from megatron.core.inference.inference_request import DynamicInferenceRequestRecord
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.utils import get_mamba_inference_state_config_from_model
from megatron.training import get_args, get_tokenizer, initialize_megatron
from megatron.training.arguments import parse_args

# pylint: disable=line-too-long

# force=True replaces any handlers installed by imported libraries so this
# script's INFO-level config wins.
logging.basicConfig(level=logging.INFO, force=True)

3242async def main (
3343 engine : DynamicInferenceEngine ,
3444 requests : List [Request ],
@@ -41,11 +51,12 @@ async def main(
4151 "Sampling parameters are specified per request." ,
4252 DeprecationWarning ,
4353 )
44-
54+
4555 # once you call engine.start_listening_to_data_parallel_coordinator,
4656 # the engine will start accepting requests from the data parallel coordinator.
4757 # and processing them in an asyncio coroutine.
4858 # leaving inference_coordinator_port as None will find a free port automatically.
59+
4960 dp_addr = await engine .start_listening_to_data_parallel_coordinator (
5061 inference_coordinator_port = port ,
5162 launch_inference_coordinator = True ,
@@ -58,11 +69,14 @@ async def main(
5869 # Since the client doesn't directly call engine.async_step here, we test
5970 # the suspend-resume system ~4 times.
6071 suspend_resume_interval = max (1 , len (requests ) // 4 )
61- suspend_idxs = set (
62- range (suspend_resume_interval , len (requests ) + 1 , suspend_resume_interval )
63- )
72+ suspend_idxs = set (range (
73+ suspend_resume_interval ,
74+ len (requests ) + 1 ,
75+ suspend_resume_interval ,
76+ ))
6477 resume_idxs = set (
65- min (len (requests ), i + suspend_resume_interval // 2 ) for i in suspend_idxs
78+ min (len (requests ), i + suspend_resume_interval // 2 )
79+ for i in suspend_idxs
6680 )
6781 else :
6882 suspend_idxs = set ()
@@ -84,10 +98,7 @@ async def main(
8498 current_time = time .time_ns () / 10 ** 9
8599 if args .incoming_requests_per_step is None :
86100 # Only add requests that have arrived at the current time.
87- while (
88- num_requests_added < num_requests_total
89- and requests [num_requests_added ].time_arrival <= current_time
90- ):
101+ while num_requests_added < num_requests_total and requests [num_requests_added ].time_arrival <= current_time :
91102 request = requests [num_requests_added ]
92103 # These add-request calls will queue up the request on a zmq socket and return
93104 # instantaneously. They will return an asyncio future which can be awaited for
@@ -103,9 +114,10 @@ async def main(
103114
104115 else :
105116 # Add deterministic number of requests (generally used for debugging).
106- for i in range (
107- min (args .incoming_requests_per_step , num_requests_total - num_requests_added )
108- ):
117+ for i in range (min (
118+ args .incoming_requests_per_step ,
119+ num_requests_total - num_requests_added
120+ )):
109121 # Change sampling parameters to force different generation lengths.
110122 request = requests [num_requests_added ]
111123 n = request .sampling_params .num_tokens_to_generate
@@ -123,7 +135,7 @@ async def main(
123135 break
124136 # Relinquish control since there are no more requests to add at the moment. This allows the engine to run.
125137 await asyncio .sleep (0 )
126-
138+
127139 # While we wait for the requests to complete, the engine runs in the background.
128140 results : List [DynamicInferenceRequestRecord ] = await asyncio .gather (* futures )
129141
@@ -161,19 +173,16 @@ async def main(
161173 req = record .merge ()
162174 unique_prompt_map [req .prompt ].append (req )
163175 for idx , (prompt_text , reqs ) in enumerate (unique_prompt_map .items ()):
164- print (
165- f"%d/%d. prompt '%s' ... [%d] output '%s'."
166- % (
167- idx ,
168- len (unique_prompt_map ),
169- prompt_text .replace ("\n " , "\\ n" ),
170- len (reqs ),
171- reqs [0 ].generated_text .replace ("\n " , "\\ n" ),
172- )
173- )
176+ print (f"%d/%d. prompt '%s' ... [%d] output '%s'." % (
177+ idx ,
178+ len (unique_prompt_map ),
179+ prompt_text .replace ("\n " , "\\ n" ),
180+ len (reqs ),
181+ reqs [0 ].generated_text .replace ("\n " , "\\ n" ),
182+ ))
174183
175184 # kill the engines and suspend the client
176- # Right now, we can only call stop when all requests are done.
185+ # Right now, we can only call stop when all requests are done.
177186 # Todo: Make this explicit in the Client class....
178187 await client .stop_engines ()
179188 client .stop ()
@@ -184,11 +193,11 @@ async def main(
184193
185194
186195if __name__ == "__main__" :
187- # enable inference mode in the very beginning as some fp8 optimizations
196+ # enable inference mode in the very beginning as some fp-8 optimizations
188197 # check for it.
189198 with torch .inference_mode ():
190199 initialize_megatron (
191- extra_args_provider = add_inference_args ,
200+ extra_args_provider = add_dynamic_inference_args ,
192201 args_defaults = {'no_load_rng' : True , 'no_load_optim' : True },
193202 )
194203
@@ -207,16 +216,34 @@ async def main(
207216 ),
208217 )
209218
210- model = get_model_for_inference ()
211-
219+ # Requests, context, conroller.
220+ model = get_model ()
221+ mamba_inference_state_config = get_mamba_inference_state_config_from_model (model )
212222 requests = (
213223 build_requests (args , tokenizer , sampling_params ) if dist .get_rank () == 0 else None
214224 )
215225
216- engine = get_dynamic_inference_engine (model = model )
226+ context = get_inference_context (
227+ None ,
228+ None ,
229+ calculate_max_sequence_length_from_requests = False ,
230+ mamba_inference_state_config = mamba_inference_state_config ,
231+ )
232+
233+ controller = get_inference_controller (model , context )
234+
235+ # Inference engine.
236+ engine = DynamicInferenceEngine (
237+ controller ,
238+ context ,
239+ enable_cuda_graph = args .cuda_graph_impl == "local" ,
240+ random_seed = args .seed ,
241+ enable_chunked_prefill = not args .disable_chunked_prefill ,
242+ inference_logging_step_interval = args .inference_logging_step_interval ,
243+ )
217244
218245 if dist .get_rank () == 0 :
219- setup_prefix = build_dynamic_engine_setup_prefix (args , model , engine . context , requests )
246+ setup_prefix = build_dynamic_engine_setup_prefix (args , model , context , requests )
220247 print ("~~~" )
221248 print (setup_prefix )
222249 print ("~~~" )
@@ -225,7 +252,13 @@ async def main(
225252 if os .environ .get ("NSIGHT_PREFIX" ):
226253 torch .cuda .cudart ().cudaProfilerStart ()
227254
228- asyncio .run (main (engine , requests , args .inference_coordinator_port ))
255+ asyncio .run (
256+ main (
257+ engine ,
258+ requests ,
259+ args .inference_coordinator_port ,
260+ )
261+ )
229262
230263 # Stop Nsight profiler.
231264 if os .environ .get ("NSIGHT_PREFIX" ):
0 commit comments