
Commit d8eaf0d

simplify vllm preprocessing input ids

1 parent 0472f44

1 file changed (+4 −6 lines)


applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 4 additions & 6 deletions

@@ -212,13 +212,11 @@ def __init__(
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         micro_batch_size = input_ids.size(0)
         response_start_idx = input_ids.size(1)
+        first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
         micro_batch_input_ids = input_ids.tolist()
-        micro_batch_input_ids_no_padding = []
-        for i in range(micro_batch_size):
-            for j in range(input_ids.size(1)):
-                if micro_batch_input_ids[i][j] != self.tokenizer.pad_token_id:
-                    micro_batch_input_ids_no_padding.append(micro_batch_input_ids[i][j:])
-                    break
+        micro_batch_input_ids_no_padding = [
+            micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
+        ]
         outputs = self.llm.generate(
             prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
         )
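
The vectorized replacement relies on argmax over a boolean mask returning the index of the first True entry in each row, i.e. the position of the first non-padding token in a left-padded prompt. Below is a minimal standalone sketch of the trick, assuming left-padded input_ids and a hypothetical pad_token_id of 0:

    import torch

    pad_token_id = 0  # hypothetical pad id for illustration
    input_ids = torch.tensor([
        [0, 0, 5, 6, 7],  # two pad tokens, then the prompt
        [0, 9, 9, 9, 9],  # one pad token
        [3, 4, 5, 6, 7],  # no padding
    ])

    # argmax over the boolean mask returns the index of the first True
    # per row, i.e. the first non-padding token.
    first_non_padding_token_idx = (input_ids != pad_token_id).int().argmax(dim=1)

    micro_batch_input_ids = input_ids.tolist()
    micro_batch_input_ids_no_padding = [
        micro_batch_input_ids[i][first_non_padding_token_idx[i] :]
        for i in range(input_ids.size(0))
    ]
    print(micro_batch_input_ids_no_padding)
    # [[5, 6, 7], [9, 9, 9, 9], [3, 4, 5, 6, 7]]

One behavioral difference worth noting: the removed loop silently dropped any row consisting entirely of padding (its break never fired, so nothing was appended), whereas argmax on an all-False mask returns 0, so the list-comprehension version keeps such a row in full.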
