Skip to content

Commit 959c825

Browse files
committed
Fix grouping func
1 parent 56c0679 commit 959c825

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

lmms_eval/models/chat/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _collate(x):
184184
# we group requests by their generation_kwargs,
185185
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
186186
# in the same batch.
187-
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
187+
re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True)
188188
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
189189
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
190190
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")

lmms_eval/models/chat/llava_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _collate(x):
4949
# we group requests by their generation_kwargs,
5050
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
5151
# in the same batch.
52-
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
52+
re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True)
5353
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
5454
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
5555
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")

lmms_eval/models/chat/qwen2_5_vl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _collate(x):
3434
# we group requests by their generation_kwargs,
3535
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
3636
# in the same batch.
37-
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=False)
37+
re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True)
3838
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
3939
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
4040
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")

0 commit comments

Comments
 (0)