
Commit 6ffd10c

Fix (brevitas_examples/llm): correct batch size for lm_eval (#1430)
1 parent 39e0547

File tree

1 file changed (+3, −1)

src/brevitas_examples/llm/main.py

Lines changed: 3 additions & 1 deletion
@@ -654,9 +654,11 @@ def quantize_llm(args, extra_args=None):
         from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
             model(**calibration_loader[0])
+            batch_size = 'auto' if args.few_shot_override_batch_size is None else args.few_shot_override_batch_size
 
             wrapped_model = HFLM(
-                pretrained=model, add_bos_token=True)  # need to wrap for LLM eval
+                pretrained=model, add_bos_token=True,
+                batch_size=batch_size)  # need to wrap for LLM eval
             few_shot_eval_results = evaluator.simple_evaluate(
                 model=wrapped_model,
                 model_args=None,
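For context, the change lets lm_eval pick its own batch size ('auto') unless the user supplies an override via args.few_shot_override_batch_size; previously HFLM fell back to its default batch size of 1. Below is a minimal sketch of the same wrapping pattern outside Brevitas, assuming lm-evaluation-harness and transformers are installed; the model checkpoint, the task name, and the `override` variable (a stand-in for args.few_shot_override_batch_size) are illustrative, not taken from the Brevitas code.

# Illustrative sketch of the pattern this commit uses: wrap a Hugging Face
# model in lm_eval's HFLM and let lm_eval choose the batch size unless an
# explicit override is given.
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any causal LM works here.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

override = None  # stand-in for args.few_shot_override_batch_size
batch_size = 'auto' if override is None else override

# HFLM accepts a preloaded model via `pretrained`; batch_size may be an
# int or 'auto', in which case lm_eval probes for a workable batch size.
wrapped_model = HFLM(pretrained=model, add_bos_token=True, batch_size=batch_size)

results = evaluator.simple_evaluate(
    model=wrapped_model,
    tasks=["lambada_openai"],  # illustrative task
)
print(results["results"])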
