Commit 0112e8a

fix the batch generation issue (#308)
1 parent 8550599 commit 0112e8a

File tree

1 file changed: +75 -28 lines changed


infscale/module/model_metadata.py

Lines changed: 75 additions & 28 deletions
@@ -198,27 +198,63 @@ def inner(
             seqno: int, outputs: dict[str, Tensor], attention_mask: Tensor
         ) -> dict[str, Tensor]:
             next_token_logits = outputs["logits"][:, -1, :]
+            batch_size = next_token_logits.size(0)
             device = next_token_logits.device

-            if self.do_sample:
-                # Apply temperature
-                next_token_logits = next_token_logits / self.temperature
+            state = self.generated_tokens.get(seqno)
+            if state is None:
+                state = {
+                    "tokens": [[] for _ in range(batch_size)],
+                    "finished": [False] * batch_size,
+                }
+                self.generated_tokens[seqno] = state

-                # Apply top_p (nucleus) sampling
-                sorted_logits, sorted_indices = torch.sort(
-                    next_token_logits, descending=True
-                )
-                cumulative_probs = torch.cumsum(
-                    torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+            gen_tokens = state["tokens"]
+            finished = state["finished"]
+
+            if len(gen_tokens) != batch_size:
+                raise ValueError(
+                    "Mismatched batch size between cached tokens and new logits"
                 )

-                sorted_indices_to_remove = cumulative_probs > self.top_p
-                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
-                    ..., :-1
-                ].clone()
-                sorted_indices_to_remove[..., 0] = 0
-                indices_to_remove = sorted_indices[sorted_indices_to_remove]
-                next_token_logits[:, indices_to_remove] = float("-inf")
+            if self.do_sample:
+                if self.temperature > 0.0:
+                    # Apply temperature
+                    next_token_logits = next_token_logits / self.temperature
+                else:
+                    self.do_sample = False
+                    logger.debug(f"temperature is 0.0, switching to greedy decoding")
+
+            eos_token_id = self.eos_token_id
+
+            for idx, is_finished in enumerate(finished):
+                if is_finished:
+                    next_token_logits[idx] = float("-inf")
+                    next_token_logits[idx, eos_token_id] = 0.0
+
+            if self.do_sample:
+                if 0.0 < self.top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(
+                        next_token_logits, descending=True
+                    )
+                    cumulative_probs = torch.cumsum(
+                        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
+                    )
+                    sorted_indices_to_remove = cumulative_probs > self.top_p
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                        ..., :-1
+                    ].clone()
+                    sorted_indices_to_remove[..., 0] = False
+
+                    for batch_idx in range(batch_size):
+                        if finished[batch_idx]:
+                            continue
+                        indices_to_remove = sorted_indices[batch_idx][
+                            sorted_indices_to_remove[batch_idx]
+                        ]
+                        next_token_logits[batch_idx, indices_to_remove] = float(
+                            "-inf"
+                        )

                 # Sample from the filtered distribution
                 next_token_probs = torch.nn.functional.softmax(
@@ -229,20 +265,31 @@ def inner(
             # Greedy decoding
             next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

-            if seqno not in self.generated_tokens:
-                self.generated_tokens[seqno] = [[] for _ in range(len(next_token))]
-            gen_tokens = self.generated_tokens[seqno]
+            for i in range(batch_size):
+                if finished[i]:
+                    next_token[i, 0] = eos_token_id
+                    continue
+
+                token_id = next_token[i, 0].item()
+                gen_tokens[i].append(token_id)
+
+                if token_id == eos_token_id or len(gen_tokens[i]) >= self.max_new_tokens:
+                    finished[i] = True
+                    if token_id != eos_token_id:
+                        next_token[i, 0] = eos_token_id
+
+            if all(finished):
+                max_length = max(len(tokens) for tokens in gen_tokens)
+                padded_tokens = []
+                for tokens in gen_tokens:
+                    if len(tokens) < max_length:
+                        tokens = tokens + [eos_token_id] * (max_length - len(tokens))
+                    padded_tokens.append(tokens)

-            for i, token in enumerate(next_token):
-                gen_tokens[i].append(token.item())
+                tensor = torch.tensor(padded_tokens, dtype=torch.int64, device=device)
+                if tensor.size(0) == 1:
+                    tensor = tensor[0]

-            # Check for EOS token or if max number of tokens are generated
-            if (
-                self.max_new_tokens == len(gen_tokens[0])
-                or next_token[0].item() == self.eos_token_id
-            ):
-                gen_tokens = gen_tokens if len(gen_tokens) > 1 else gen_tokens[0]
-                tensor = torch.tensor(gen_tokens, dtype=torch.int64, device=device)
                 del self.generated_tokens[seqno]

                 return {"tokens": tensor}
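For readers following the change: the old code built the top-p (nucleus) mask from indices pooled across all rows and only checked the first sequence (`gen_tokens[0]`, `next_token[0]`) for EOS or the max-token limit, so generation misbehaved whenever the batch size was greater than one. Below is a minimal, self-contained sketch of the per-row scheme the commit adopts. It assumes PyTorch; the function name `decode_step` and its signature are illustrative, not part of `infscale`.

# Standalone sketch of the batched decoding step this commit implements
# (illustrative only; `decode_step` is not a function in the repository).
import torch


def decode_step(next_token_logits, finished, eos_token_id, temperature=1.0, top_p=0.9):
    """Pick one token per row; rows already marked finished keep emitting EOS."""
    batch_size = next_token_logits.size(0)
    logits = next_token_logits.clone()

    # Pin finished rows to EOS so they stay consistent with the rest of the batch.
    for idx, done in enumerate(finished):
        if done:
            logits[idx] = float("-inf")
            logits[idx, eos_token_id] = 0.0

    if temperature > 0.0:
        logits = logits / temperature

        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            remove[..., 1:] = remove[..., :-1].clone()  # always keep the top token
            remove[..., 0] = False

            # Apply the nucleus mask row by row instead of flattening it
            # across the batch, which is what the old code effectively did.
            for b in range(batch_size):
                if finished[b]:
                    continue
                logits[b, sorted_indices[b][remove[b]]] = float("-inf")

        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    # temperature == 0.0 falls back to greedy decoding, as in the patch.
    return torch.argmax(logits, dim=-1, keepdim=True)


# Toy usage: batch of 2 over a 5-token vocabulary; row 1 is already finished,
# so it is guaranteed to come back as token 0 (the assumed EOS id here).
tokens = decode_step(torch.randn(2, 5), finished=[False, True], eos_token_id=0)
print(tokens.shape)  # torch.Size([2, 1])

Beyond this per-step logic, the patch also keeps the generated tokens and a finished flag per sequence under self.generated_tokens[seqno], and only emits the final tensor, padded to the longest sequence with EOS, once every row in the batch has finished.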
