1515# specific language governing permissions and limitations
1616# under the License.
1717#
18- from typing import Callable , Optional , List , Dict , Any
18+ from typing import Any , Callable , Dict , List , Optional
19+
1920import torch
2021
22+
2123class Request :
2224 def __init__ (
2325 self ,
2426 id : int ,
2527 all_input_ids : torch .Tensor ,
2628 max_new_steps : int = 96 ,
2729 post_inference_fn : Optional [Callable ] = None ,
28- chunk_size : int = 96 , # token size, how many time steps a token has
30+ chunk_size : int = 96 , # token size, how many time steps a token has
2931 ** model_kwargs ,
30- ):
32+ ):
3133 if all_input_ids .ndim == 1 :
3234 all_input_ids = all_input_ids .unsqueeze (0 )
3335
3436 self .id = id
3537 self .all_input_ids = all_input_ids
3638 self .model_kwargs = model_kwargs
37- self .max_new_steps = max_new_steps # Number of time steps to generate
39+ self .max_new_steps = max_new_steps # Number of time steps to generate
3840 self .chunk_size = chunk_size
3941 self .post_inference_fn = post_inference_fn
4042
4143 self .batch_size = all_input_ids .size (0 )
42- self .state = ' waiting'
43- self .cur_step_idx = 0 # Current write position in the output step index
44+ self .state = " waiting"
45+ self .cur_step_idx = 0 # Current write position in the output step index
4446
4547 # Preallocate output buffer [batch_size, max_new_tokens]
4648 device = all_input_ids .device
47- self .output_tensor = torch .zeros (self .batch_size , max_new_steps , device = device ) # shape: [self.batch_size, max_new_steps]
49+ self .output_tensor = torch .zeros (
50+ self .batch_size , max_new_steps , device = device
51+ ) # shape: [self.batch_size, max_new_steps]
4852
4953 def mark_running (self ):
50- self .state = ' running'
54+ self .state = " running"
5155
5256 def mark_finished (self ):
53- self .state = ' finished'
57+ self .state = " finished"
5458
5559 def is_finished (self ) -> bool :
5660 return self .cur_step_idx >= self .max_new_steps
@@ -66,25 +70,26 @@ def write_step_output(self, step_output: torch.Tensor):
6670
6771 if end_idx > self .max_new_steps :
6872 # raise ValueError(f"write_step_output exceeds allocated output space: {end_idx} > {self.max_new_steps}")
69- self .output_tensor [:, self .cur_step_idx :] = step_output [:, :self .max_new_steps - self .cur_step_idx ]
73+ self .output_tensor [:, self .cur_step_idx :] = step_output [
74+ :, : self .max_new_steps - self .cur_step_idx
75+ ]
7076 self .cur_step_idx = self .max_new_steps
7177 else :
72- self .output_tensor [:, self .cur_step_idx : end_idx ] = step_output
78+ self .output_tensor [:, self .cur_step_idx : end_idx ] = step_output
7379 self .cur_step_idx = end_idx
7480
7581 if self .is_finished ():
7682 self .mark_finished ()
7783
7884 def get_final_output (self ) -> torch .Tensor :
79- return self .output_tensor [:, :self .cur_step_idx ]
85+ return self .output_tensor [:, : self .cur_step_idx ]
8086
8187 def run_post_inference_fn (self ) -> Optional [torch .Tensor ]:
8288 if self .post_inference_fn is not None :
8389 return self .post_inference_fn (self .get_final_output ())
8490 return self .get_final_output ()
8591
8692 def reset (self ):
87- self .state = ' waiting'
93+ self .state = " waiting"
8894 self .cur_step_idx = 0
8995 self .output_tensor .zero_ ()
90-
0 commit comments