 import torch
 
 from ainode.core.inference.strategy.abstract_strategy import AbstractStrategy
+from ainode.core.log import Logger
+
+logger = Logger()
 
 
 class InferenceRequestState:
@@ -32,7 +35,7 @@ class InferenceRequestState:
 class InferenceRequest:
     def __init__(
         self,
-        req_id: int,
+        req_id: str,
         inputs: torch.Tensor,
         strategy: AbstractStrategy,
         max_new_tokens: int = 96,
@@ -41,7 +44,7 @@ def __init__(
         if inputs.ndim == 1:
             inputs = inputs.unsqueeze(0)
 
-        self.id = req_id
+        self.req_id = req_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
         self.strategy = strategy
@@ -59,9 +62,6 @@ def __init__(
             self.batch_size, max_new_tokens, device=device
         )  # shape: [self.batch_size, max_new_tokens]
 
-        self._lock = threading.Lock()
-        self._condition = threading.Condition(self._lock)
-
     def mark_running(self):
         self.state = InferenceRequestState.RUNNING
 
@@ -75,34 +75,45 @@ def is_finished(self) -> bool:
         )
 
     def write_step_output(self, step_output: torch.Tensor):
-        with self._lock:
-            if step_output.ndim == 1:
-                step_output = step_output.unsqueeze(0)
+        if step_output.ndim == 1:
+            step_output = step_output.unsqueeze(0)
 
-            batch_size, step_size = step_output.shape
-            end_idx = self.cur_step_idx + step_size
+        batch_size, step_size = step_output.shape
+        end_idx = self.cur_step_idx + step_size
 
-            if end_idx > self.max_new_tokens:
-                self.output_tensor[:, self.cur_step_idx :] = step_output[
-                    :, : self.max_new_tokens - self.cur_step_idx
-                ]
-                self.cur_step_idx = self.max_new_tokens
-            else:
-                self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
-                self.cur_step_idx = end_idx
+        if end_idx > self.max_new_tokens:
+            self.output_tensor[:, self.cur_step_idx :] = step_output[
+                :, : self.max_new_tokens - self.cur_step_idx
+            ]
+            self.cur_step_idx = self.max_new_tokens
+        else:
+            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+            self.cur_step_idx = end_idx
 
-            if self.is_finished():
-                self.mark_finished()
+        if self.is_finished():
+            self.mark_finished()
     def get_final_output(self) -> torch.Tensor:
-        with self._lock:
-            return self.output_tensor[:, : self.cur_step_idx]
+        return self.output_tensor[:, : self.cur_step_idx]
+
+
+class InferenceRequestProxy:
+    """
+    Wrap a raw request so that a caller can block until its result is produced in another process.
+    """
+
+    def __init__(self, req_id: str):
+        self.req_id = req_id
+        self.result = None
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
 
-    def notify_completion(self):
+    def set_result(self, result: Any):
         with self._lock:
+            self.result = result
             self._condition.notify_all()
 
     def wait_for_completion(self) -> Any:
         with self._lock:
-            while self.state != InferenceRequestState.FINISHED:
-                self._condition.wait()
+            self._condition.wait_for(lambda: self.result is not None)  # predicate avoids a hang if set_result ran first
+            return self.result
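
To make the buffer arithmetic in write_step_output concrete, here is a self-contained toy run of the same logic; the sizes and values are invented for illustration:

import torch

max_new_tokens = 4
output_tensor = torch.zeros(1, max_new_tokens)
cur_step_idx = 0

# Two decoding steps: the second overflows the buffer and is truncated,
# mirroring the end_idx > max_new_tokens branch in the diff.
for step_output in (torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0, 5.0]])):
    end_idx = cur_step_idx + step_output.shape[1]
    if end_idx > max_new_tokens:
        output_tensor[:, cur_step_idx:] = step_output[:, : max_new_tokens - cur_step_idx]
        cur_step_idx = max_new_tokens
    else:
        output_tensor[:, cur_step_idx:end_idx] = step_output
        cur_step_idx = end_idx

print(output_tensor[:, :cur_step_idx])  # tensor([[1., 2., 3., 4.]]) -- the 5.0 was dropped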
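
And a minimal sketch of the producer/consumer handshake that InferenceRequestProxy enables. The stand-in class below mirrors the one added in this diff so the example runs outside the ainode package; the worker-thread wiring is an assumption for illustration, not part of the change:

import threading
from typing import Any

class InferenceRequestProxy:  # stand-in copy of the class added above
    def __init__(self, req_id: str):
        self.req_id = req_id
        self.result = None
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)

    def set_result(self, result: Any):
        with self._lock:
            self.result = result
            self._condition.notify_all()

    def wait_for_completion(self) -> Any:
        with self._lock:
            self._condition.wait_for(lambda: self.result is not None)
            return self.result

proxy = InferenceRequestProxy("req-0")

# Hypothetical inference worker: in the real system this would run in another
# process and push the finished output back through the proxy.
worker = threading.Thread(target=lambda: proxy.set_result("done"))
worker.start()
print(proxy.wait_for_completion())  # blocks until set_result fires -> "done"
worker.join()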