@@ -160,47 +160,59 @@ def _poll_weight_update(self) -> None:
160160 def _run (self ) -> None :
161161 self ._init_weight_sync ()
162162
163- while not self ._shutdown_event .is_set ():
164- # Poll for weight updates between batches (non-blocking)
165- self ._poll_weight_update ()
166-
167- self .transport .wait_for_work (timeout = self .timeout )
168-
163+ try :
164+ while not self ._shutdown_event .is_set ():
165+ self ._poll_weight_update ()
166+
167+ self .transport .wait_for_work (timeout = self .timeout )
168+
169+ items , callbacks = self .transport .drain (self .max_batch_size )
170+ if not items :
171+ continue
172+
173+ # Accumulate up to min_batch_size (or until timeout expires)
174+ if len (items ) < self .min_batch_size :
175+ deadline = time .monotonic () + self .timeout
176+ while len (items ) < self .min_batch_size :
177+ remaining = deadline - time .monotonic ()
178+ if remaining <= 0 :
179+ break
180+ self .transport .wait_for_work (timeout = remaining )
181+ more_items , more_cbs = self .transport .drain (
182+ self .max_batch_size - len (items )
183+ )
184+ items .extend (more_items )
185+ callbacks .extend (more_cbs )
186+
187+ batch = self .collate_fn (items )
188+ if self .device is not None :
189+ batch = batch .to (self .device )
190+
191+ try :
192+ with self ._model_lock :
193+ results = self .model (batch ).unbind (0 )
194+ if len (results ) != len (callbacks ):
195+ raise RuntimeError (
196+ f"Model returned { len (results )} results for a "
197+ f"batch of { len (callbacks )} inputs."
198+ )
199+ for cb , res in zip (callbacks , results ):
200+ self .transport .resolve (cb , res )
201+ except Exception as exc :
202+ for cb in callbacks :
203+ self .transport .resolve_exception (cb , exc )
204+ finally :
205+ self ._drain_pending_on_shutdown ()
206+
207+ def _drain_pending_on_shutdown (self ) -> None :
208+ """Resolve all pending requests with an error during shutdown."""
209+ shutdown_exc = RuntimeError ("InferenceServer is shutting down." )
210+ while True :
169211 items , callbacks = self .transport .drain (self .max_batch_size )
170212 if not items :
171- continue
172-
173- # Accumulate up to min_batch_size (or until timeout expires)
174- if len (items ) < self .min_batch_size :
175- deadline = time .monotonic () + self .timeout
176- while len (items ) < self .min_batch_size :
177- remaining = deadline - time .monotonic ()
178- if remaining <= 0 :
179- break
180- self .transport .wait_for_work (timeout = remaining )
181- more_items , more_cbs = self .transport .drain (
182- self .max_batch_size - len (items )
183- )
184- items .extend (more_items )
185- callbacks .extend (more_cbs )
186-
187- batch = self .collate_fn (items )
188- if self .device is not None :
189- batch = batch .to (self .device )
190-
191- try :
192- with self ._model_lock :
193- results = self .model (batch ).unbind (0 )
194- if len (results ) != len (callbacks ):
195- raise RuntimeError (
196- f"Model returned { len (results )} results for a "
197- f"batch of { len (callbacks )} inputs."
198- )
199- for cb , res in zip (callbacks , results ):
200- self .transport .resolve (cb , res )
201- except Exception as exc :
202- for cb in callbacks :
203- self .transport .resolve_exception (cb , exc )
213+ break
214+ for cb in callbacks :
215+ self .transport .resolve_exception (cb , shutdown_exc )
204216
205217 # -- context manager ------------------------------------------------------
206218
0 commit comments