 # under the License.
 #

+import gc
 import random
 import threading
 import time

 from ainode.core.config import AINodeDescriptor
 from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.scheduler.basic_scheduler import BasicScheduler
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager

@@ -61,70 +63,92 @@ def __init__(
         self._model_manager = None
         self.device = None

-        # TODO: A scheduler is necessary for better handling following queues
         self._threads = []
         self._waiting_queue = request_queue  # Requests that are waiting to be processed
         self._running_queue = mp.Queue()  # Requests that are currently being processed
         self._finished_queue = result_queue  # Requests that are finished
+        self._scheduler = BasicScheduler(
+            self._waiting_queue, self._running_queue, self._finished_queue, self.pool_id
+        )
         self._stop_event = mp.Event()

         # Fix inference seed
         random.seed(self.FIX_SEED)
         torch.manual_seed(self.FIX_SEED)
         np.random.seed(self.FIX_SEED)

-    def memory_is_available(self, request):
-        # need test with several rounds of dummy data
-        pass
+    def _warm_up_and_estimate_memory(self):
+        # TODO: Test per-token memory usage; add CPU support in the future
+        torch.cuda.empty_cache()
+        gc.collect()
+        dummy_input = torch.zeros(
+            (1, self.config.input_token_len), dtype=torch.float32
+        ).to(self.device)
+
+        # force CUDA synchronization to avoid any asynchronous memory allocation issues
+        torch.cuda.reset_peak_memory_stats(self.device)
+        torch.cuda.synchronize(self.device)
+        memory_before_warmup = torch.cuda.memory_allocated(self.device)
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Memory allocated before warm-up: {memory_before_warmup:.2f} bytes"
+        )

-    def _activate_requests(self):
-        if self._waiting_queue.empty():
-            return
-        request: InferenceRequest = self._waiting_queue.get()
-        # TODO: Check memory size before activating requests
-        request.inputs = request.inference_pipeline.preprocess_inputs(request.inputs)
-        request.mark_running()
-        logger.debug(
-            f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+        # warm-up: generate a single token to measure peak memory
+        with torch.no_grad():
+            self.model.generate(dummy_input, max_new_tokens=1)
+        torch.cuda.synchronize(self.device)
+        peak_memory_1_token = torch.cuda.max_memory_allocated(self.device)
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Baseline memory usage for 1 token: {peak_memory_1_token:.2f} bytes"
+        )
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Difference: {peak_memory_1_token - memory_before_warmup:.2f} bytes"
         )
-        self._running_queue.put(request)
+
+    def _activate_requests(self):
+        requests = self._scheduler.schedule_activate()
+        for request in requests:
+            request.inputs = request.inference_pipeline.preprocess_inputs(
+                request.inputs
+            )
+            request.mark_running()
+            self._running_queue.put(request)
+            logger.debug(
+                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+            )

     def _requests_activate_loop(self):
         while not self._stop_event.is_set():
             time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
             self._activate_requests()

     def _step(self):
-        if self._running_queue.empty():
-            return
+        requests = self._scheduler.schedule_step()
         # TODO: We need a batcher to accelerate the concurrent inference
-        # TODO: Check memory size before executing requests
-        request: InferenceRequest = self._running_queue.get()
-        inputs = request.inputs.to(self.device)
-        output = self.model.generate(
-            inputs,
-            max_new_tokens=request.max_new_tokens,
-            num_samples=10,
-            revin=True,
-        )
-        request.output_tensor = request.output_tensor.to(
-            self.device
-        )  # Ensure output tensor is on the same device
-        request.write_step_output(output[0].mean(dim=0))
-        request.inference_pipeline.post_decode()
-        if request.is_finished():
-            request.inference_pipeline.post_inference()
-            logger.debug(
-                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
-            )
-            # ensure the output tensor is on CPU before sending to result queue
-            request.output_tensor = request.output_tensor.cpu()
-            self._finished_queue.put(request)
-        else:
-            logger.debug(
-                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+        for request in requests:
+            request.inputs = request.inputs.to(self.device)
+            output = self.model.generate(
+                request.inputs,
+                max_new_tokens=request.max_new_tokens,
+                num_samples=10,
+                revin=True,
             )
-            self._waiting_queue.put(request)
+            request.output_tensor = request.output_tensor.to(self.device)
+            request.write_step_output(output[0].mean(dim=0))
+            request.inference_pipeline.post_decode()
+            if request.is_finished():
+                request.inference_pipeline.post_inference()
+                logger.debug(
+                    f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+                )
+                # ensure the output tensor is on CPU before sending to result queue
+                request.output_tensor = request.output_tensor.cpu()
+                self._finished_queue.put(request)
+            else:
+                logger.debug(
+                    f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+                )
+                self._waiting_queue.put(request)

     def _requests_execute_loop(self):
         while not self._stop_event.is_set():
@@ -134,8 +158,11 @@ def _requests_execute_loop(self):
     def run(self):
         self._model_manager = ModelManager()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._scheduler.device = self.device
         self.model = self._model_manager.load_model(self.model_id, {}).to(self.device)

+        # self._warm_up_and_estimate_memory()
+
         activate_daemon = threading.Thread(
             target=self._requests_activate_loop, daemon=True
         )
@@ -151,3 +178,15 @@ def run(self):

     def stop(self):
         self._stop_event.set()
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Stopping and releasing resources."
+        )
+        try:
+            del self.model
+            if "cuda" in str(self.device):
+                torch.cuda.empty_cache()
+            gc.collect()
+        except Exception as e:
+            logger.warning(
+                f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Failed to clean up: {e}"
+            )
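
The commit only pins down the scheduler's surface area: BasicScheduler is constructed with the waiting, running, and finished queues plus the pool id, gets a device attribute assigned in run(), and exposes schedule_activate() and schedule_step(), each returning a batch of requests for the worker to process. The sketch below illustrates that assumed interface; the class name and the drain-the-queue method bodies are placeholders for illustration only, not the actual BasicScheduler in ainode.core.inference.scheduler.basic_scheduler.

    # Hypothetical sketch of the scheduler interface the worker relies on.
    # The queue-draining policy here is an assumption; the real BasicScheduler
    # may batch, reorder, or gate requests (e.g. on estimated memory) differently.
    from typing import Any, List


    class SchedulerSketch:
        def __init__(self, waiting_queue, running_queue, finished_queue, pool_id):
            self._waiting_queue = waiting_queue
            self._running_queue = running_queue
            self._finished_queue = finished_queue
            self.pool_id = pool_id
            self.device = None  # assigned by the worker in run()

        def schedule_activate(self) -> List[Any]:
            # Hand back everything currently waiting; the worker preprocesses each
            # request, marks it running, and puts it on the running queue itself.
            requests = []
            while not self._waiting_queue.empty():
                requests.append(self._waiting_queue.get())
            return requests

        def schedule_step(self) -> List[Any]:
            # Hand back the running requests to execute in this step; unfinished
            # requests are re-queued to the waiting queue by the worker afterwards.
            requests = []
            while not self._running_queue.empty():
                requests.append(self._running_queue.get())
            return requests

Hiding the activation and step policies behind these two calls is what lets the worker drop its ad-hoc empty()/get() checks and the old memory-size TODOs in _activate_requests and _step.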