@@ -39,11 +39,12 @@ def __init__(
3939            from  torchao .utils  import  find_multiple 
4040
4141            self .origin_in_features  =  in_features 
42-             in_features  =  find_multiple (in_features , (1024 ,))
42+             # pyre-ignore[6]: Incompatible parameter type 
43+             in_features  =  find_multiple (in_features , 1024 )
4344
45+         self .use_bias  =  bias 
4446        self .in_features  =  in_features 
4547        self .out_features  =  out_features 
46-         assert  not  bias , "require bias=False" 
4748        self .device  =  device 
4849        self .groupsize  =  groupsize 
4950        self .inner_k_tiles  =  inner_k_tiles 
@@ -80,20 +81,28 @@ def __init__(
8081                device = device ,
8182            ),
8283        )
84+         if  bias :
85+             self .register_buffer (
86+                 "bias" ,
87+                 torch .empty ((out_features ,), dtype = torch .float32 , device = device ),
88+             )
8389
8490    def  forward (self , input : torch .Tensor ) ->  torch .Tensor :
8591        if  self .padding :
8692            input  =  F .pad (input , pad = (0 , self .in_features  -  self .origin_in_features ))
8793        # The forward method is replaced. In the original implementation, the forward 
8894        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom 
8995        # operator is called instead. 
90-         return  torch .ops .et_vk .linear_weight_int4 (
96+         r   =  torch .ops .et_vk .linear_weight_int4 (
9197            input ,
9298            self .weight ,
9399            self .groupsize ,
94100            self .scales_and_zeros ,
95101            self .inner_k_tiles ,
96102        )
103+         if  self .use_bias :
104+             return  r  +  self .bias 
105+         return  r 
97106
98107
99108# This function is copied from torchao.quantization.GPTQ._replace_linear_int4 
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
128137                new_linear  =  linear_class (
129138                    child .in_features ,
130139                    child .out_features ,
131-                     bias = False ,
140+                     bias = child . bias   is   not   None ,
132141                    device = child .weight .device ,
133142                    groupsize = groupsize ,
134143                    inner_k_tiles = inner_k_tiles ,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
138147                if  copy_weights  and  child .weight .device  !=  torch .device ("meta" ):
139148                    # pyre-fixme[16]: `Module` has no attribute `weight`. 
140149                    new_linear .weight  =  child .weight 
150+                     if  child .bias  is  not   None :
151+                         # pyre-fixme[16]: `Module` has no attribute `bias`. 
152+                         new_linear .bias  =  child .bias 
141153                setattr (module , name , new_linear )
142154        else :
143155            _vk_replace_linear_int4 (
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
189201                mod .out_features  <  self .feature_limit 
190202                and  mod .in_features  <  self .feature_limit 
191203            ):
192-                 assert  not  mod .bias 
193204                out_features  =  mod .out_features 
194205                in_features  =  mod .in_features 
195206                logging .info (f"linear: { fqn }  , in={ in_features }  , out={ out_features }  " )
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
210221                        logging .warn (
211222                            f"warning: { fqn }   is padded to satisfy in_features % 1024 == 0" 
212223                        )
213-                         padded_in_features  =  find_multiple (in_features , (1024 ,))
224+                         # pyre-ignore[6]: Incompatible parameter type 
225+                         padded_in_features  =  find_multiple (in_features , 1024 )
214226                        weight  =  F .pad (
215227                            weight , pad = (0 , padded_in_features  -  in_features )
216228                        )
0 commit comments