Commit 7aaceee

Fix error when collating queries with different continuation lengths (fixes #2984) (#2987)

* Fix error due to grouping queries with different continuation lengths: make the Collator choose the query with the longest continuation as the candidate for generation
* Use max for key selection
* Added comments explaining variable continuation length (identical ctx + cont[:-1])

Co-authored-by: Baber <[email protected]>
1 parent 357d4ea commit 7aaceee

2 files changed (+13, -4 lines)

lm_eval/models/huggingface.py

Lines changed: 7 additions & 2 deletions
@@ -1136,7 +1136,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            if self.backend == "causal":
                total_length = len(context_enc) + len(continuation_enc)
                if total_length > self.max_length + 1:
-                    eval_logger.warn(
+                    eval_logger.warning(
                        f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
                        f"exceeds model's maximum length ({self.max_length}). "
                        f"Truncating {total_length - self.max_length + 1} tokens from the left."

@@ -1247,7 +1247,12 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                cont_toks = torch.tensor(
                    cont_toks, dtype=torch.long, device=self.device
                ).unsqueeze(0)  # [1, seq]
-                max_equal = (greedy_tokens == cont_toks).all()
+                # Use trailing slice [-cont_toks.shape[1]:] to handle variable cont_len (but identical ctx + cont[:-1]),
+                # i.e. continuations can be sliced at different points. The Collator ensures greedy_tokens is long
+                # enough by choosing the key with the longest continuation when group_by="contexts".
+                max_equal = (
+                    greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
+                ).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
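With group_by="contexts", requests whose context_enc + continuation_enc[:-1] concatenate to the same token sequence share a single forward pass; since only the split point between context and continuation differs, a shorter continuation is a suffix of the group's longest one. The sketch below is standalone and illustrative (made-up token values, not harness internals) and shows why the trailing slice keeps the comparison aligned for every group member:

import torch

# Greedy (argmax) tokens produced for the group's candidate request, i.e. the member
# with the longest continuation (4 tokens in this made-up example).
greedy_tokens = torch.tensor([[11, 12, 13, 14]])

# Two members of the same group: identical ctx + cont[:-1], different split points,
# so the shorter continuation is a suffix of the longer one.
group_cont_toks = [
    torch.tensor([[13, 14]]),          # 2-token continuation
    torch.tensor([[11, 12, 13, 14]]),  # 4-token continuation (the candidate itself)
]

for cont_toks in group_cont_toks:
    # Comparing the full greedy_tokens against a shorter cont_toks would not broadcast;
    # slicing to the last cont_toks.shape[1] positions lines both tensors up.
    max_equal = (greedy_tokens[:, -cont_toks.shape[1]:] == cont_toks).all()
    print(cont_toks.shape[1], bool(max_equal))  # prints "2 True" then "4 True"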

lm_eval/models/utils.py

Lines changed: 6 additions & 2 deletions
@@ -428,9 +428,13 @@ def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterat
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
        elif self._group_by == "contexts":
-            # Get one sample from each key
+            # Get one sample from each key.
+            # Select the longest continuation per group to ensure sufficient context logits.
            values = self._reorder(
-                [value[0] for value in self._arr_with_indices.values()]
+                [
+                    max(value, key=lambda x: len(x[1][-1]))
+                    for value in self._arr_with_indices.values()
+                ]
            )
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch
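The key function len(x[1][-1]) implies that each grouped value carries its encoded continuation as the last element of its second item; under that assumption (the exact tuple layout below is illustrative, not copied from the harness), the change from value[0] to max(...) behaves like this:

# Illustrative sketch of the candidate selection; the group/tuple layout is assumed.
groups = {
    # Both requests share ctx + cont[:-1] == [1, 2, 3, 4]; only the split point differs.
    "ctx-A": [
        (0, (("doc-A", "..."), [1, 2, 3, 4], [5])),   # 1-token continuation
        (1, (("doc-A", "..."), [1, 2], [3, 4, 5])),   # 3-token continuation
    ],
    "ctx-B": [
        (2, (("doc-B", "..."), [7, 8], [9])),
    ],
}

# Old behaviour: take the first request from each group (value[0]).
# New behaviour: take the request whose encoded continuation (x[1][-1]) is longest,
# so the single forward pass run for the group yields enough greedy_tokens/logits
# to score every shorter member via the trailing slice above.
candidates = [max(value, key=lambda x: len(x[1][-1])) for value in groups.values()]
print([idx for idx, _ in candidates])  # [1, 2]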
