@@ -494,6 +494,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                     group_size=group_size,
                     dtype=child.weight.dtype,
                     packed=packed,
+                    bitwidth=bitwidth,
                 ),
             )
         else:
@@ -614,13 +615,15 @@ def __init__(
         group_size: Optional[int] = None,
         dtype=torch.half,
         packed=False,
+        bitwidth: int = 8,
     ) -> None:
         super().__init__()
         if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
         self.packed = packed
+        self.bitwidth = bitwidth
         if not packed:
             self.register_buffer(
                 "weight",
@@ -629,12 +632,25 @@ def __init__(
                 ),
             )
         else:  # packed
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
-                ),
-            )
+            if bitwidth == 2:
+                self.register_buffer(
+                    "weight",
+                    torch.empty(
+                        (vocab_size, embedding_dim // 4),
+                        dtype=torch.uint8,
+                        device=device,
+                    ),
+                )
+            elif bitwidth == 4:
+                self.register_buffer(
+                    "weight",
+                    torch.empty(
+                        (vocab_size, embedding_dim // 2),
+                        dtype=torch.uint8,
+                        device=device,
+                    ),
+                )
+
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
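Note on the packed buffer shapes in this hunk: with sub-byte quantization, 8 // bitwidth values share one uint8, so a 2-bit row needs embedding_dim // 4 bytes and a 4-bit row needs embedding_dim // 2. A minimal sketch of that sizing plus an illustrative 2-bit packer follows; the helper names and the bit ordering are assumptions for illustration, not the layout the quantized_decomposed kernels require.

import torch

def packed_weight_shape(vocab_size: int, embedding_dim: int, bitwidth: int):
    # 8 // bitwidth quantized values share one uint8, so each row shrinks
    # from embedding_dim values to embedding_dim // (8 // bitwidth) bytes.
    values_per_byte = 8 // bitwidth
    assert embedding_dim % values_per_byte == 0, "rows must pack evenly"
    return (vocab_size, embedding_dim // values_per_byte)

def pack_2bit_rows(q: torch.Tensor) -> torch.Tensor:
    # Pack four 2-bit codes (uint8 values in [0, 3]) into each output byte.
    # The low-to-high bit order here is only illustrative.
    q = q.to(torch.uint8).reshape(q.shape[0], -1, 4)
    return q[..., 0] | (q[..., 1] << 2) | (q[..., 2] << 4) | (q[..., 3] << 6)

# packed_weight_shape(32000, 4096, 2) -> (32000, 1024)  # embedding_dim // 4
# packed_weight_shape(32000, 4096, 4) -> (32000, 2048)  # embedding_dim // 2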
@@ -654,10 +670,15 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
             return torch.ops.quantized_decomposed.embedding_byte.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
-        else:  # 4bit packed
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-            )
+        else:  # packed
+            if self.bitwidth == 2:
+                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                )
+            elif self.bitwidth == 4:
+                return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+                )
 
 
 ############################ Source Transform Start #######################
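For reference, a hypothetical end-to-end use of the class modified above, showing how bitwidth selects the packed lowering in forward(). The class name is not visible in these hunks (QuantizedGroupEmbedding is assumed here), and the weight/scales buffers are assumed to have already been filled by the quantization pass; this is a sketch, not the upstream API.

import torch

# Hypothetical usage sketch; parameter names follow the __init__ shown above.
emb = QuantizedGroupEmbedding(
    device="cpu",
    vocab_size=32000,
    embedding_dim=4096,
    group_size=32,
    dtype=torch.half,
    packed=True,
    bitwidth=2,  # routes forward() to quantized_decomposed.embedding_2bit
)
indices = torch.tensor([[0, 5, 42]], dtype=torch.long)
out = emb(indices)  # fp16 output of shape (1, 3, 4096) once buffers are populated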