Skip to content

Commit 046dc98

Browse files
authored
Some fixes in the recently added quantizer (#111)
1 parent a7963ae commit 046dc98

File tree

2 files changed, with 8 additions and 4 deletions.

test/quantization/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from torch.nn import functional as F
1313

1414
def prepare_inputs_for_model(inps):
15+
# this is because input from lm-eval is 2d
16+
if input.dim() != 2:
17+
raise ValueError(f"Expected input to be of dim 2, but got {input.dim()}")
18+
1519
inps = inps.squeeze(0)
1620
# setup inputs in correct format
1721
max_new_tokens = 1

torchao/quantization/GPTQ.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,9 +1225,9 @@ def __init__(
12251225
calibration_limit,
12261226
calibration_seq_length,
12271227
pad_calibration_inputs,
1228-
inner_k_tiles=8,
1229-
padding_allowed=True,
1230-
precision=torch.float32,
1231-
_is_gpt_fast=True,
1228+
inner_k_tiles=inner_k_tiles,
1229+
padding_allowed=padding_allowed,
1230+
precision=precision,
1231+
_is_gpt_fast=_is_gpt_fast,
12321232
_use_cuda=_use_cuda,
12331233
)

0 commit comments

Comments (0)