|
4 | 4 | import rpyc |
5 | 5 | import torch |
6 | 6 | import socket |
7 | | -import time |
8 | 7 | from datetime import timedelta |
9 | 8 | from typing import Dict, List, Tuple, Callable, Optional |
10 | 9 | from transformers.configuration_utils import PretrainedConfig |
@@ -251,81 +250,6 @@ def _post_handle( |
251 | 250 | is_chuncked_mode: bool, |
252 | 251 | do_filter_finished_reqs: bool, |
253 | 252 | extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, |
254 | | - ) -> List[int]: |
255 | | - """ |
256 | | - extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 |
257 | | - 约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。 |
258 | | - """ |
259 | | - if not hasattr(self, "_post_handle_impl"): |
260 | | - try: |
261 | | - finished_req_ids = self._fast_post_handle( |
262 | | - run_reqs, |
263 | | - next_token_ids, |
264 | | - next_token_logprobs, |
265 | | - is_chuncked_mode, |
266 | | - do_filter_finished_reqs, |
267 | | - extra_post_req_handle_func, |
268 | | - ) |
269 | | - self._post_handle_impl = self._fast_post_handle |
270 | | - self.logger.info("use _fast_post_handle") |
271 | | - return finished_req_ids |
272 | | - except: |
273 | | - finished_req_ids = self._python_post_handle( |
274 | | - run_reqs, |
275 | | - next_token_ids, |
276 | | - next_token_logprobs, |
277 | | - is_chuncked_mode, |
278 | | - do_filter_finished_reqs, |
279 | | - extra_post_req_handle_func, |
280 | | - ) |
281 | | - self.logger.info("use _python_post_handle") |
282 | | - self._post_handle_impl = self._python_post_handle |
283 | | - return finished_req_ids |
284 | | - else: |
285 | | - return self._post_handle_impl( |
286 | | - run_reqs, |
287 | | - next_token_ids, |
288 | | - next_token_logprobs, |
289 | | - is_chuncked_mode, |
290 | | - do_filter_finished_reqs, |
291 | | - extra_post_req_handle_func, |
292 | | - ) |
293 | | - |
294 | | - def _fast_post_handle( |
295 | | - self, |
296 | | - run_reqs: List[InferReq], |
297 | | - next_token_ids, |
298 | | - next_token_logprobs, |
299 | | - is_chuncked_mode: bool, |
300 | | - do_filter_finished_reqs: bool, |
301 | | - extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, |
302 | | - ): |
303 | | - from . import cython_fast_impl |
304 | | - |
305 | | - start = time.time() |
306 | | - finished_req_ids = cython_fast_impl.fast_post_handle( |
307 | | - self, |
308 | | - run_reqs, |
309 | | - next_token_ids, |
310 | | - next_token_logprobs, |
311 | | - is_chuncked_mode, |
312 | | - do_filter_finished_reqs, |
313 | | - extra_post_req_handle_func, |
314 | | - ) |
315 | | - cost_time = time.time() - start |
316 | | - if self.is_master_in_dp and cost_time > 0.001: |
317 | | - self.logger.info(f"post handle cost time {cost_time} s, batch_size: {len(run_reqs)}") |
318 | | - return finished_req_ids |
319 | | - |
320 | | - # 一些可以复用的通用功能函数 |
321 | | - def _python_post_handle( |
322 | | - self, |
323 | | - run_reqs: List[InferReq], |
324 | | - next_token_ids, |
325 | | - next_token_logprobs, |
326 | | - is_chuncked_mode: bool, |
327 | | - do_filter_finished_reqs: bool, |
328 | | - extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None, |
329 | 253 | ) -> List[int]: |
330 | 254 | """ |
331 | 255 | extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于 |
|
0 commit comments