
Commit a6368bf

clean up eager quant in llm_export
1 parent 51befee commit a6368bf

2 files changed (+54, -271 lines)


examples/models/llama/README.md

Lines changed: 1 addition & 1 deletion

@@ -418,7 +418,7 @@ python -m examples.models.llama.export_llama \
 ```
 
 A few notes:
-- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
+- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized symmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
 - To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.
 
 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
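For reference, a minimal sketch (not part of this diff) of what the three comma-separated fields of `-E "torchao:4,32,true"` mean under the updated note. The helper name is hypothetical and the sketch mirrors the documented semantics, not the exact `export_llama` parsing code:

```python
# Hypothetical helper, illustration only: decode "-E torchao:<bitwidth>,<group_size>[,<symmetric>]".
def parse_embedding_quantize(spec: str):
    fields = spec.split(":", 1)[1].split(",")
    bitwidth = int(fields[0])    # e.g. 4 -> 4-bit embedding weights
    group_size = int(fields[1])  # e.g. 32; 0 means channelwise quantization
    # The third field is optional; per the README, omitting it defaults to symmetric.
    symmetric = (fields[2] == "true") if len(fields) > 2 else True
    return bitwidth, group_size, symmetric

assert parse_embedding_quantize("torchao:4,32") == (4, 32, True)
assert parse_embedding_quantize("torchao:4,32,false") == (4, 32, False)
```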

examples/models/llama/source_transformation/quantize.py

Lines changed: 53 additions & 270 deletions

@@ -18,6 +18,15 @@
 
 from sentencepiece import SentencePieceProcessor
 
+from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    MappingType,
+    quantize_,
+)
+
 
 try:
     from fairseq2.nn.embedding import (
@@ -118,15 +127,6 @@ def quantize( # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
 
-        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            Int8DynamicActivationIntxWeightConfig,
-            MappingType,
-            quantize_,
-        )
-        from torchao.utils import unwrap_tensor_subclass
-
         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
             # no way to decouple checkpoint and computation dtype.
@@ -141,7 +141,6 @@ def quantize( # noqa C901
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
-            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -150,14 +149,17 @@ def quantize( # noqa C901
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
 
-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
-        from torchao.utils import unwrap_tensor_subclass
-
-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
-        model = unwrap_tensor_subclass(model)
-
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=(
+                    PerAxis(0) if group_size == 0 else PerGroup(group_size)
+                ),
+                weight_mapping_type=MappingType.SYMMETRIC,
+            ),
+        )
         # TODO: deal with checkpoint / computation dtype decoupling.
-
         if verbose:
             print("quantized model:", model)
         return model
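As a point of reference, a minimal sketch (not part of this commit) of the torchao API the new 8da4w path relies on, applied to a toy model. It assumes a torchao build that ships `Int8DynamicActivationIntxWeightConfig`:

```python
# Illustrative only: 8-bit dynamic activations + 4-bit grouped weights on a toy model.
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

toy = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))
group_size = 32
quantize_(
    toy,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,  # "8da4w": int8 dynamic activations, int4 weights
        weight_granularity=(
            PerAxis(0) if group_size == 0 else PerGroup(group_size)  # 0 -> channelwise
        ),
        weight_mapping_type=MappingType.SYMMETRIC,
    ),
)
# The nn.Linear weights are now swapped for quantized tensors in eager mode, which is
# why the old unwrap_tensor_subclass call and the local imports are no longer needed.
```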
@@ -563,254 +565,32 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         )
 
 
-#########################################################################
-##### embedding table quantization ######
-
-
-def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
-):
-    for name, child in module.named_children():
-        # print(f"name: {name}")
-        if isinstance(child, nn.Embedding):
-            # print(f"{name, child}")
-            # print(f"weights size: {child.weight.size()}")
-            setattr(
-                module,
-                name,
-                QuantizedGroupEmbedding(
-                    device=device,
-                    vocab_size=child.weight.shape[0],
-                    embedding_dim=child.weight.shape[1],
-                    group_size=group_size,
-                    dtype=child.weight.dtype,
-                    packed=packed,
-                    bitwidth=bitwidth,
-                ),
-            )
-        else:
-            replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, group_size, packed
-            )
-
-
-class EmbeddingQuantHandler(QuantHandler):
-    def __init__(
-        self,
-        mod,
-        device="cpu",
-        *,
-        bitwidth: int = 8,
-        group_size: Optional[int] = None,
-        packed=False,
-        precision: Optional[torch.dtype] = None,
-    ):
-        if isinstance(packed, str):
-            packed = packed == "True"
-        self.mod = mod
-        self.device = device
-        self.group_size = group_size
-        self.bitwidth = bitwidth
-        self.packed = packed
-        # Dtype of the weights right before quantization.
-        self.precision = precision
-        if (bitwidth not in [2, 4]) and packed:
-            raise RuntimeError("pack only works with bitsize 2, 4")
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self, packed=False) -> Dict:
-        cur_state_dict = self.mod.state_dict()
-
-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, nn.Embedding):
-                # print("****")
-                # print(f"Embedding identified: {fqn, mod}")
-                # print(f"weights size: {mod.weight.size()}")
-                # print(f"quantize {fqn}...")
-
-                print(
-                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
-                )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
-                    ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
-                )
-
-                if packed:
-                    if self.bitwidth == 2:
-                        if weight.shape[-1] % 4 != 0:
-                            raise RuntimeError("automatic padding not implemented yet")
-                        weight_range_shifted = weight.add(2).view(torch.uint8)
-                        weight_view = weight_range_shifted.view(
-                            weight.shape[0], weight.shape[1] // 4, 4
-                        )
-                        weight_0 = weight_view[:, :, 0]
-                        weight_1 = weight_view[:, :, 1] << 2
-                        weight_2 = weight_view[:, :, 2] << 4
-                        weight_3 = weight_view[:, :, 3] << 6
-                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
-                        weight = weight_packed
-                    elif self.bitwidth == 4:
-                        if weight.shape[-1] % 2 != 0:
-                            raise RuntimeError("automatic padding not implemented yet")
-                        weight_range_shifted = weight.add(8).view(torch.uint8)
-                        weight_view = weight_range_shifted.view(
-                            weight.shape[0], weight.shape[1] // 2, 2
-                        )
-                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
-                        weight_odd = weight_view[:, :, 1]
-                        weight_packed = weight_even + weight_odd
-                        weight = weight_packed
-
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
-                # Update state dict
-                cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes group_size=rowsize unidimensional
-                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
-
-        return cur_state_dict
-
-    def convert_for_runtime(self) -> nn.Module:
-        replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.device, self.bitwidth, self.group_size, self.packed
-        )
-        return self.mod
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
-        self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict, assign=True)
-        return self.mod
-
-
-class QuantizedGroupEmbedding(torch.nn.Module):
-    def __init__(
-        self,
-        device,
-        vocab_size: int,
-        embedding_dim: int,
-        group_size: Optional[int] = None,
-        dtype=torch.half,
-        packed=False,
-        bitwidth: int = 8,
-    ) -> None:
-        super().__init__()
-        if group_size is None or group_size == 0:
-            group_size = embedding_dim
-        self.group_size = group_size
-        self.dtype = dtype
-        self.packed = packed
-        self.bitwidth = bitwidth
-        if not packed:
-            self.register_buffer(
-                "weight",
-                torch.zeros(
-                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
-                ),
-            )
-        else:  # packed
-            if bitwidth == 2:
-                self.register_buffer(
-                    "weight",
-                    torch.zeros(
-                        (vocab_size, embedding_dim // 4),
-                        dtype=torch.uint8,
-                        device=device,
-                    ),
-                )
-            elif bitwidth == 4:
-                self.register_buffer(
-                    "weight",
-                    torch.zeros(
-                        (vocab_size, embedding_dim // 2),
-                        dtype=torch.uint8,
-                        device=device,
-                    ),
-                )
-
-        groups_per_row = (embedding_dim + group_size - 1) // group_size
-        if groups_per_row > 1:
-            self.register_buffer(
-                "scales",
-                torch.ones(
-                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
-                ),
-            )
-        else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
-            )
-
-    @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if not self.packed:  # 8bit
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
-            )
-        else:  # packed
-            if self.bitwidth == 2:
-                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
-                )
+############################ Source Transform Start #######################
 
-            # Remaining case (always return to make pyre happy)
-            assert self.bitwidth == 4
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
-            )
 
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
+    use_torchao = args.embedding_quantize.startswith("torchao:")
+    if use_torchao:
+        quant_args = args.embedding_quantize.split(":")[1].split(",")
+    else:
+        quant_args = args.embedding_quantize.split(",")
 
-############################ Source Transform Start #######################
+    bitwidth = int(quant_args[0])
+    group_size = quant_args[1]
+    if group_size in ["none", "None", "0"]:
+        group_size = 0
+    group_size = int(group_size)
+    is_symmetric = bool(quant_args[2]) if len(quant_args) > 2 else True
 
+    weight_dtype = getattr(torch, f"int{bitwidth}")
+    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
+    mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC
 
-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+    if use_torchao:
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import MappingType
-
-        quant_args = args.embedding_quantize.split(":")[1].split(",")
-        if len(quant_args) == 2:
-            bitwidth, group_size = quant_args
-            is_asymmetric = True
-        else:
-            bitwidth, group_size, is_asymmetric = quant_args
-
-        if group_size in ["none", "None", "0"]:
-            group_size = 0
-
-        group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        is_asymmetric = bool(is_asymmetric)
-        weight_dtype = getattr(torch, f"int{bitwidth}")
-        granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
-        mapping_type = (
-            MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
-        )
 
         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
@@ -831,20 +611,23 @@ def _torchao_embedding_quantizer(model):
 
         return _torchao_embedding_quantizer
 
-    bitwidth, group_size = args.embedding_quantize.split(",")
-    if group_size == "none" or group_size == "None" or group_size == "0":
-        group_size = None
-    else:
-        group_size = int(group_size)
-    bitwidth = int(bitwidth)
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-    return lambda model: EmbeddingQuantHandler(
-        model,
-        bitwidth=bitwidth,
-        group_size=group_size,
-        packed=(bitwidth in [2, 4]),
-        precision=torch_dtype,
-    ).quantized_model()
+    def _quantize_embedding(model):
+        assert weight_dtype in [
+            torch.int2,
+            torch.int4,
+            torch.int8,
+        ], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+            ),
+        )
+        return model
+
+    return _quantize_embedding
 
 
 def get_quant_weight_transform(