
Commit 4503fe6

fix code style
1 parent 05fb207 commit 4503fe6

File tree

6 files changed: +195, -120 lines


iotdb-core/ainode/ainode/core/inference/request.py

Lines changed: 19 additions & 14 deletions
@@ -15,42 +15,46 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from typing import Callable, Optional, List, Dict, Any
+from typing import Any, Callable, Dict, List, Optional
+
 import torch
 
+
 class Request:
     def __init__(
         self,
         id: int,
         all_input_ids: torch.Tensor,
         max_new_steps: int = 96,
         post_inference_fn: Optional[Callable] = None,
-        chunk_size: int = 96, # token size, how many time steps a token has
+        chunk_size: int = 96,  # token size, how many time steps a token has
         **model_kwargs,
-        ):
+    ):
         if all_input_ids.ndim == 1:
             all_input_ids = all_input_ids.unsqueeze(0)
 
         self.id = id
         self.all_input_ids = all_input_ids
         self.model_kwargs = model_kwargs
-        self.max_new_steps = max_new_steps # Number of time steps to generate
+        self.max_new_steps = max_new_steps  # Number of time steps to generate
         self.chunk_size = chunk_size
         self.post_inference_fn = post_inference_fn
 
         self.batch_size = all_input_ids.size(0)
-        self.state = 'waiting'
-        self.cur_step_idx = 0 # Current write position in the output step index
+        self.state = "waiting"
+        self.cur_step_idx = 0  # Current write position in the output step index
 
         # Preallocate output buffer [batch_size, max_new_tokens]
         device = all_input_ids.device
-        self.output_tensor = torch.zeros(self.batch_size, max_new_steps, device=device) # shape: [self.batch_size, max_new_steps]
+        self.output_tensor = torch.zeros(
+            self.batch_size, max_new_steps, device=device
+        )  # shape: [self.batch_size, max_new_steps]
 
     def mark_running(self):
-        self.state = 'running'
+        self.state = "running"
 
     def mark_finished(self):
-        self.state = 'finished'
+        self.state = "finished"
 
     def is_finished(self) -> bool:
         return self.cur_step_idx >= self.max_new_steps
@@ -66,25 +70,26 @@ def write_step_output(self, step_output: torch.Tensor):
 
         if end_idx > self.max_new_steps:
             # raise ValueError(f"write_step_output exceeds allocated output space: {end_idx} > {self.max_new_steps}")
-            self.output_tensor[:, self.cur_step_idx:] = step_output[:, :self.max_new_steps - self.cur_step_idx]
+            self.output_tensor[:, self.cur_step_idx :] = step_output[
+                :, : self.max_new_steps - self.cur_step_idx
+            ]
             self.cur_step_idx = self.max_new_steps
         else:
-            self.output_tensor[:, self.cur_step_idx:end_idx] = step_output
+            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
             self.cur_step_idx = end_idx
 
         if self.is_finished():
             self.mark_finished()
 
     def get_final_output(self) -> torch.Tensor:
-        return self.output_tensor[:, :self.cur_step_idx]
+        return self.output_tensor[:, : self.cur_step_idx]
 
     def run_post_inference_fn(self) -> Optional[torch.Tensor]:
         if self.post_inference_fn is not None:
             return self.post_inference_fn(self.get_final_output())
         return self.get_final_output()
 
     def reset(self):
-        self.state = 'waiting'
+        self.state = "waiting"
         self.cur_step_idx = 0
         self.output_tensor.zero_()
-
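
For context, a minimal sketch of how a caller might drive the reformatted Request class. The import path is assumed from the repository layout, the random tensors stand in for a real model's per-step output, and the driver loop itself is hypothetical; the actual scheduling code is not shown in this file.

import torch

from ainode.core.inference.request import Request  # import path assumed from the repo layout

# Hypothetical driver loop; a real scheduler would own this logic.
req = Request(
    id=0,
    all_input_ids=torch.randn(2, 512),  # [batch_size, input_steps]
    max_new_steps=96,
    chunk_size=96,
    post_inference_fn=lambda out: out.cpu(),  # illustrative post-processing hook
)

req.mark_running()
while not req.is_finished():
    # Stand-in for one forward pass producing one chunk of time steps.
    step_output = torch.randn(req.batch_size, req.chunk_size)
    req.write_step_output(step_output)  # flips state to "finished" once full

result = req.run_post_inference_fn()  # shape [2, 96], moved to CPU by the hook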
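
One behavioral detail the diff preserves: when a chunk would overrun the preallocated buffer, write_step_output skips the commented-out ValueError and silently truncates the chunk instead. A small check of that semantics, assuming end_idx is cur_step_idx + step_output.size(1) as the slicing implies (end_idx itself is computed on lines outside this hunk):

import torch

from ainode.core.inference.request import Request  # import path assumed as above

req = Request(id=1, all_input_ids=torch.zeros(1, 8), max_new_steps=10, chunk_size=4)
req.write_step_output(torch.ones(1, 4))  # fills output steps 0..3
req.write_step_output(torch.ones(1, 4))  # fills output steps 4..7
req.write_step_output(torch.ones(1, 4))  # end_idx would be 12; only steps 8..9 are kept
assert req.is_finished()
assert req.get_final_output().shape == (1, 10)  # clamped to max_new_steps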