Skip to content

Commit 40ff314

Browse files
committed
improving GPTQ defauls
Summary: previously wikitext was default task for GPTQ which with other defaults, wouldn't collect any examples. also improved error message. Test Plan: python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode 8da4w-gptq --calibration_limit 5 python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode 8da4w-gptq --calibration_limit 5 --calibration_tasks hellaswag Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 9275ece Pull Request resolved: #104
1 parent 2037199 commit 40ff314

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

quantize.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibrati
263263
limit=calibration_limit,
264264
)
265265
inputs = input_recorder.get_recorded_inputs()
266+
assert inputs is not None, (
267+
f"No inputs were collected, use a task other than {calibration_tasks}, "+
268+
f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
269+
f"{calibration_seq_length})"
270+
)
266271
print(f"Obtained {len(inputs[0].values)} calibration samples")
267272
return inputs
268273

@@ -597,7 +602,7 @@ def quantize(
597602
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
598603
parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
599604
parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
600-
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['hellaswag'], help='tasks to do gptq calibration on, if doing gptq')
605+
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
601606
parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
602607
parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
603608
parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')

0 commit comments

Comments
 (0)