Skip to content

Commit 847a992

Browse files
committed
update embedding group quantizer
1 parent b42a94e commit 847a992

File tree

1 file changed

+31
-10
lines changed
  • examples/models/llama2/source_transformation

1 file changed

+31
-10
lines changed

examples/models/llama2/source_transformation/quantize.py

Lines changed: 31 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
494494
group_size=group_size,
495495
dtype=child.weight.dtype,
496496
packed=packed,
497+
bitwidth=bitwidth,
497498
),
498499
)
499500
else:
@@ -614,13 +615,15 @@ def __init__(
614615
group_size: Optional[int] = None,
615616
dtype=torch.half,
616617
packed=False,
618+
bitwidth: int = 8,
617619
) -> None:
618620
super().__init__()
619621
if group_size is None or group_size == 0:
620622
group_size = embedding_dim
621623
self.group_size = group_size
622624
self.dtype = dtype
623625
self.packed = packed
626+
self.bitwidth = bitwidth
624627
if not packed:
625628
self.register_buffer(
626629
"weight",
@@ -629,12 +632,25 @@ def __init__(
629632
),
630633
)
631634
else: # packed
632-
self.register_buffer(
633-
"weight",
634-
torch.empty(
635-
(vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
636-
),
637-
)
635+
if bitwidth == 2:
636+
self.register_buffer(
637+
"weight",
638+
torch.empty(
639+
(vocab_size, embedding_dim // 4),
640+
dtype=torch.uint8,
641+
device=device,
642+
),
643+
)
644+
elif bitwidth == 4:
645+
self.register_buffer(
646+
"weight",
647+
torch.empty(
648+
(vocab_size, embedding_dim // 2),
649+
dtype=torch.uint8,
650+
device=device,
651+
),
652+
)
653+
638654
groups_per_row = (embedding_dim + group_size - 1) // group_size
639655
if groups_per_row > 1:
640656
self.register_buffer(
@@ -654,10 +670,15 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
654670
return torch.ops.quantized_decomposed.embedding_byte.dtype(
655671
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
656672
)
657-
else: # 4bit packed
658-
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
659-
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
660-
)
673+
else: # packed
674+
if self.bitwidth == 2:
675+
return torch.ops.quantized_decomposed.embedding_2bit.dtype(
676+
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
677+
)
678+
elif self.bitwidth == 4:
679+
return torch.ops.quantized_decomposed.embedding_4bit.dtype(
680+
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
681+
)
661682

662683

663684
############################ Source Transform Start #######################

0 commit comments

Comments
 (0)