Commit c29a9f6

up
1 parent a7382c9 commit c29a9f6

2 files changed: +231 −3 lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 217 additions & 0 deletions
@@ -510,6 +510,223 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.precision,
         )

+#########################################################################
+##### embedding table quantization ######
+
+
+def replace_embedding_weight_only_grouped_int8_per_channel(
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
+):
+    for name, child in module.named_children():
+        # print(f"name: {name}")
+        if isinstance(child, nn.Embedding):
+            # print(f"{name, child}")
+            # print(f"weights size: {child.weight.size()}")
+            setattr(
+                module,
+                name,
+                QuantizedGroupEmbedding(
+                    device=device,
+                    vocab_size=child.weight.shape[0],
+                    embedding_dim=child.weight.shape[1],
+                    group_size=group_size,
+                    dtype=child.weight.dtype,
+                    packed=packed,
+                    bitwidth=bitwidth,
+                ),
+            )
+        else:
+            replace_embedding_weight_only_grouped_int8_per_channel(
+                child, device, bitwidth, group_size, packed
+            )
+
+
+class EmbeddingQuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+        precision: Optional[torch.dtype] = None,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
+        self.mod = mod
+        self.device = device
+        self.group_size = group_size
+        self.bitwidth = bitwidth
+        self.packed = packed
+        # Dtype of the weights right before quantization.
+        self.precision = precision
+        if (bitwidth not in [2, 4]) and packed:
+            raise RuntimeError("pack only works with bitsize 2, 4")
+
+    @torch.no_grad()
+    def create_quantized_state_dict(self, packed=False) -> Dict:
+        cur_state_dict = self.mod.state_dict()
+
+        if self.bitwidth == 2:
+            range_min = -2
+            range_max = 1
+        elif self.bitwidth == 4:
+            range_min = -8
+            range_max = 7
+        elif self.bitwidth == 8:
+            range_min = -128
+            range_max = 127
+        else:
+            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
+
+        for fqn, mod in self.mod.named_modules():
+            if isinstance(mod, nn.Embedding):
+                # print("****")
+                # print(f"Embedding identified: {fqn, mod}")
+                # print(f"weights size: {mod.weight.size()}")
+                # print(f"quantize {fqn}...")
+
+                print(
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
+                )
+                weight, scales, _ = dynamically_quantize_per_channel(
+                    (
+                        mod.weight.to(dtype=self.precision)
+                        if self.precision
+                        else mod.weight
+                    ),
+                    range_min,
+                    range_max,
+                    torch.int8,
+                    self.group_size,
+                    scales_dtype=mod.weight.dtype,
+                )
+
+                if packed:
+                    if self.bitwidth == 2:
+                        if weight.shape[-1] % 4 != 0:
+                            raise RuntimeError("automatic padding not implemented yet")
+                        weight_range_shifted = weight.add(2).view(torch.uint8)
+                        weight_view = weight_range_shifted.view(
+                            weight.shape[0], weight.shape[1] // 4, 4
+                        )
+                        weight_0 = weight_view[:, :, 0]
+                        weight_1 = weight_view[:, :, 1] << 2
+                        weight_2 = weight_view[:, :, 2] << 4
+                        weight_3 = weight_view[:, :, 3] << 6
+                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
+                        weight = weight_packed
+                    elif self.bitwidth == 4:
+                        if weight.shape[-1] % 2 != 0:
+                            raise RuntimeError("automatic padding not implemented yet")
+                        weight_range_shifted = weight.add(8).view(torch.uint8)
+                        weight_view = weight_range_shifted.view(
+                            weight.shape[0], weight.shape[1] // 2, 2
+                        )
+                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                        weight_odd = weight_view[:, :, 1]
+                        weight_packed = weight_even + weight_odd
+                        weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
+                # Update state dict
+                cur_state_dict[f"{fqn}.weight"] = weight
+                # squeeze makes group_size=rowsize unidimensional
+                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
+
+        return cur_state_dict
+
+    def convert_for_runtime(self) -> nn.Module:
+        replace_embedding_weight_only_grouped_int8_per_channel(
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
+        )
+        return self.mod
+
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
+        return self.mod
+
+
+class QuantizedGroupEmbedding(torch.nn.Module):
+    def __init__(
+        self,
+        device,
+        vocab_size: int,
+        embedding_dim: int,
+        group_size: Optional[int] = None,
+        dtype=torch.half,
+        packed=False,
+        bitwidth: int = 8,
+    ) -> None:
+        super().__init__()
+        if group_size is None or group_size == 0:
+            group_size = embedding_dim
+        self.group_size = group_size
+        self.dtype = dtype
+        self.packed = packed
+        self.bitwidth = bitwidth
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            if bitwidth == 2:
+                self.register_buffer(
+                    "weight",
+                    torch.zeros(
+                        (vocab_size, embedding_dim // 4),
+                        dtype=torch.uint8,
+                        device=device,
+                    ),
+                )
+            elif bitwidth == 4:
+                self.register_buffer(
+                    "weight",
+                    torch.zeros(
+                        (vocab_size, embedding_dim // 2),
+                        dtype=torch.uint8,
+                        device=device,
+                    ),
+                )
+
+        groups_per_row = (embedding_dim + group_size - 1) // group_size
+        if groups_per_row > 1:
+            self.register_buffer(
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
+            )
+        else:
+            self.register_buffer(
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
+            )
+
+    @torch.no_grad()
+    def forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if not self.packed:  # 8bit
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
+            )
+        else:  # packed
+            if self.bitwidth == 2:
+                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
+                )
+
+            # Remaining case (always return to make pyre happy)
+            assert self.bitwidth == 4
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
+            )
+

 ############################ Source Transform Start #######################

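For readers skimming the diff, here is a minimal usage sketch (not part of this commit) of the new handler: it wraps a toy module holding a single nn.Embedding, runs EmbeddingQuantHandler with the same bitwidth=8 / group_size=32 settings that export_llava.py below uses, and ends up with the embedding swapped for QuantizedGroupEmbedding. The TinyModel class and its tensor sizes are hypothetical, chosen only for illustration.

import torch
import torch.nn as nn

# Import path matches the file modified by this commit; EmbeddingQuantHandler
# and QuantizedGroupEmbedding are the classes added above.
from executorch.examples.models.llama.source_transformation.quantize import (
    EmbeddingQuantHandler,
)


class TinyModel(nn.Module):  # hypothetical stand-in for a language model
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(1000, 64)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(tokens)


model = TinyModel()
# quantized_model() builds the int8 state dict, swaps nn.Embedding for
# QuantizedGroupEmbedding, and loads the quantized weights with assign=True.
model = EmbeddingQuantHandler(
    model, device="cpu", bitwidth=8, group_size=32, packed=False
).quantized_model()
print(type(model.tok_embeddings).__name__)  # QuantizedGroupEmbedding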
examples/models/llava/export_llava.py

Lines changed: 14 additions & 3 deletions
@@ -26,6 +26,7 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
+    EmbeddingQuantHandler,
 )
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
@@ -183,9 +184,19 @@ def forward(self, images):


 def export_token_embedding(llava, prompt):
-    quantized_token_embed = get_quant_embedding_transform("8,32")(
-        llava.model_.language_model.model
-    )
+    # quantized_token_embed = get_quant_embedding_transform("8,32")(
+    #     llava.model_.language_model.model
+    # )
+    def quant_embedding(model):
+        return EmbeddingQuantHandler(
+            model,
+            bitwidth=8,
+            group_size=32,
+            packed=False,
+        ).quantized_model()
+
+    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
+
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():
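As a side note on the packed branches added in create_quantized_state_dict above, the sketch below (plain PyTorch, independent of the commit) spells out the 4-bit arithmetic: two signed int4 values in [-8, 7] are shifted into [0, 15] and stored as the high and low nibble of one uint8, which is why the packed weight buffer has shape (vocab_size, embedding_dim // 2). The unpacking half and the sample values are only there to check the round trip and are made up for illustration.

import torch

w = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)  # one row of int4 values

# Pack: shift [-8, 7] -> [0, 15], then place the even element in the high
# nibble and the odd element in the low nibble (mirrors the bitwidth == 4 branch).
shifted = w.add(8).view(torch.uint8)
pairs = shifted.view(w.shape[0], w.shape[1] // 2, 2)
packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]  # tensor([[15, 135]], dtype=torch.uint8)

# Unpack (illustration only): split the nibbles and undo the +8 shift.
high = (packed >> 4).to(torch.int8) - 8
low = (packed & 0xF).to(torch.int8) - 8
restored = torch.stack([high, low], dim=-1).reshape(w.shape)
assert torch.equal(restored, w)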
