
Commit ec258e0

add int4 non-gptq and bugfixes (#119)

Summary: WeightOnlyInt4Linear had a bug that made it not pad when it should have; this commit fixes that and adds a non-GPTQ int4 weight-only quantizer (Int4WeightOnlyQuantizer).
Test Plan: python test/quantization/test_quant_api.py -k "int4wo"

1 parent b0a333c · commit ec258e0
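For orientation, the new non-GPTQ quantizer added here is meant to be used roughly as follows. This is a minimal sketch, not code from the commit: it assumes a CUDA build of PyTorch with the bf16 int4 tinygemm kernels, and the toy layer sizes are chosen to pass the divisibility checks the quantizer enforces. Unlike Int4WeightOnlyGPTQQuantizer, no calibration inputs are needed; quantize() takes only the model.

```python
# Minimal usage sketch (assumptions: CUDA available, bf16 int4 tinygemm kernels,
# toy dimensions chosen so no padding is needed; not taken from the commit itself).
import torch
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

# The int4 path only handles bias-free nn.Linear layers.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False),
).to(torch.bfloat16).cuda()

# Unlike the GPTQ quantizer, no calibration inputs are required.
quantizer = Int4WeightOnlyQuantizer(groupsize=128)
model = quantizer.quantize(model).cuda()

with torch.no_grad():
    out = model(torch.randn(1, 256, dtype=torch.bfloat16, device="cuda"))
```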

File tree

5 files changed: +133 −14 lines

- test/quantization/test_quant_api.py
- torchao/quantization/GPTQ.py
- torchao/quantization/__init__.py
- torchao/quantization/quant_api.py
- torchao/quantization/quant_primitives.py

test/quantization/test_quant_api.py — 35 additions, 1 deletion

```diff
@@ -300,7 +300,6 @@ def test_gptq_quantizer_gpt_fast(self):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_int4wo(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
-        # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -357,6 +356,41 @@ def test_gptq_quantizer_int4wo(self):
             f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
         )

+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_quantizer_int4wo(self):
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        groupsize = 128
+        quantizer = Int4WeightOnlyQuantizer(
+            groupsize,
+        )
+        model = quantizer.quantize(model).cuda()
+        result = TransformerEvalWrapper(
+            model,
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
+            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_eval_wrapper(self):
         from torchao.quantization.GPTQ import TransformerEvalWrapper
```

torchao/quantization/GPTQ.py — 94 additions, 12 deletions

```diff
@@ -28,6 +28,7 @@
     groupwise_affine_quantize_tensor_from_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
     pack_tinygemm_scales_and_zeros,
+    groupwise_affine_quantize_tensor,
 )
 aten = torch.ops.aten

@@ -65,8 +66,8 @@

 __all__ = [
     "MultiInput",
-    "WeightOnlyInt4Linear",
     "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
 ] + add_ons

 if lm_eval_available:
@@ -117,7 +118,10 @@ def __init__(

     @property
     def eot_token_id(self):
-        return self._tokenizer.eos_id()
+        try:
+            return self._tokenizer.eos_id()
+        except:
+            return self._tokenizer.eos_id

     @property
     def max_length(self):
@@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
         tokens = self._tokenizer.encode(string)
         if hasattr(self._tokenizer, "bos_id"):
-            tokens = [self._tokenizer.bos_id()] + tokens
+            try:
+                tokens = [self._tokenizer.bos_id()] + tokens
+            except:
+                tokens = [self._tokenizer.bos_id] + tokens
         return tokens

     def tok_decode(self, tokens):
@@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
     def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
         pass

+def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
+    k_divisible_by_groupsize = k % groupsize == 0
+    if inner_k_tiles is not None:
+        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
+        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
+    return k_divisible_by_groupsize

 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
@@ -767,7 +780,7 @@ def __init__(
         bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
         if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
@@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
         )

-
-def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
-    k_divisible_by_groupsize = k % groupsize == 0
-    if inner_k_tiles is not None:
-        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
-        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
-    return k_divisible_by_groupsize
-
 def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None):

     for name, child in module.named_children():
@@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
         else:
             replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)

+class Int4WeightOnlyQuantizer(Quantizer):
+    def __init__(
+        self,
+        groupsize: int = 256,
+        padding_allowed: bool = True,
+        inner_k_tiles: Optional[int] = 8,
+    ) -> None:
+        super().__init__()
+        assert inner_k_tiles in [2, 4, 8]
+        assert groupsize in [32, 64, 128, 256]
+
+        self.inner_k_tiles = inner_k_tiles
+        self.groupsize: int = groupsize
+        self.padding_allowed: bool = padding_allowed
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(
+        self, model: torch.nn.Module
+    ) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.groupsize == 0
+                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
+
+                weight = mod.weight.data
+                if not _check_linear_int4_k(
+                    in_features, self.groupsize, self.inner_k_tiles
+                ):
+                    if self.padding_allowed:
+                        from .utils import find_multiple
+                        import torch.nn.functional as F
+                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
+                    else:
+                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
+                              "and that groupsize and inner_k_tiles*16 evenly divide into it")
+                        continue
+                (
+                    w_int4x8,
+                    scales_and_zeros
+                ) = groupwise_affine_quantize_tensor(
+                    weight,
+                    4,  # n_bit
+                    self.groupsize,
+                )
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
+                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
+                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_int4(
+            model,
+            self.groupsize,
+            self.inner_k_tiles,
+            self.padding_allowed,
+        )
+        return model
+
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
 class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
     def __init__(
         self,
```
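The padding bugfix above is simply an inversion: `_check_linear_int4_k` returns True when `in_features` already satisfies the groupsize and `inner_k_tiles * 16` divisibility constraints, so padding should trigger only when it returns False. The old code set `self.padding` to the check itself, which is why layers that needed padding never got it. A small illustration of the check (a sketch that mirrors the helper from the diff; the example dimensions are illustrative):

```python
# Illustrative only: mirrors _check_linear_int4_k from the diff above.
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=None):
    ok = k % groupsize == 0
    if inner_k_tiles is not None:
        return ok and k % (inner_k_tiles * 16) == 0
    return ok

k_good, k_bad = 4096, 4000   # 4096 % 128 == 0; 4000 % 128 != 0
check_good = _check_linear_int4_k(k_good, groupsize=128, inner_k_tiles=8)  # True
check_bad = _check_linear_int4_k(k_bad, groupsize=128, inner_k_tiles=8)    # False

# Before the fix: self.padding = check      -> k_bad got padding=False and was never padded.
# After the fix:  self.padding = not check  -> only k_bad, the shape that needs it, is padded.
assert (not check_good) is False and (not check_bad) is True
```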

torchao/quantization/__init__.py — 2 additions, 0 deletions

```diff
@@ -42,4 +42,6 @@
     "compute_error",
     "get_model_size_in_bytes",
     "WeightOnlyInt8QuantLinear",
+    "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
 ]
```

torchao/quantization/quant_api.py — 2 additions, 0 deletions

```diff
@@ -32,6 +32,7 @@
 from .unified import Quantizer, TwoStepQuantizer
 from .GPTQ import (
     Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
 )


@@ -45,6 +46,7 @@
     "Quantizer",
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer"
 ]

 if TORCH_VERSION_AFTER_2_3:
```
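With these export changes, the new quantizer is reachable from quant_api as well as from the GPTQ module. A quick import sketch; the commented top-level import is an assumption, since it only works if torchao/quantization/__init__.py re-exports quant_api's names (for example via a star import):

```python
# Definitely available after this commit (the test imports it this way):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

# Also exported here by this commit:
from torchao.quantization.quant_api import Int4WeightOnlyQuantizer

# Assumption: only valid if the package __init__ re-exports quant_api's __all__.
# from torchao.quantization import Int4WeightOnlyQuantizer
```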

torchao/quantization/quant_primitives.py — 0 additions, 1 deletion

```diff
@@ -383,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros):

 def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
     assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
-    assert scales_and_zeros.dtype == torch.float
     return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


```
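Dropping the `torch.float` assertion lets `scales_and_zeros` be stored in lower precision (e.g. bfloat16, as the int4 tinygemm path uses). A minimal sketch of the unpack call, assuming only what the function itself asserts (a 3-D tensor whose last dimension holds scale/zero pairs); the shape below is illustrative:

```python
import torch
from torchao.quantization.quant_primitives import unpack_tinygemm_scales_and_zeros

# Illustrative shape: (num_groups, out_features, 2); bfloat16 is now accepted
# since the float32-only assert was removed.
scales_and_zeros = torch.randn(32, 4096, 2, dtype=torch.bfloat16)
scales, zeros = unpack_tinygemm_scales_and_zeros(scales_and_zeros)  # each has shape (4096, 32, 1)
```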

0 commit comments
