@@ -39,11 +39,12 @@ def __init__(
3939 from torchao .utils import find_multiple
4040
4141 self .origin_in_features = in_features
42- in_features = find_multiple (in_features , (1024 ,))
42+ # pyre-ignore[6]: Incompatible parameter type
43+ in_features = find_multiple (in_features , 1024 )
4344
45+ self .use_bias = bias
4446 self .in_features = in_features
4547 self .out_features = out_features
46- assert not bias , "require bias=False"
4748 self .device = device
4849 self .groupsize = groupsize
4950 self .inner_k_tiles = inner_k_tiles
@@ -80,20 +81,28 @@ def __init__(
8081 device = device ,
8182 ),
8283 )
84+ if bias :
85+ self .register_buffer (
86+ "bias" ,
87+ torch .empty ((out_features ,), dtype = torch .float32 , device = device ),
88+ )
8389
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Apply the int4-quantized linear transform via the Vulkan custom op.

    When the layer was constructed with an ``in_features`` that needed
    padding to a multiple of 1024, the last dimension of ``input`` is
    zero-padded from ``origin_in_features`` up to ``in_features`` first.
    The original torchao forward (``linear_forward_int4``) is replaced
    here by the ``et_vk`` custom operator; the bias buffer is added
    afterwards only when the layer was created with ``bias=True``.
    """
    x = input
    if self.padding:
        # Grow the trailing dim to the padded width the packed int4
        # weight was built for.
        x = F.pad(x, pad=(0, self.in_features - self.origin_in_features))
    out = torch.ops.et_vk.linear_weight_int4(
        x,
        self.weight,
        self.groupsize,
        self.scales_and_zeros,
        self.inner_k_tiles,
    )
    # NOTE(review): `bias` is registered as a float32 buffer — assumes the
    # custom op's output dtype promotes correctly against it; confirm.
    return out + self.bias if self.use_bias else out
97106
98107
98107 99108# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
@@ -128,7 +137,7 @@ def _vk_replace_linear_int4(
128137 new_linear = linear_class (
129138 child .in_features ,
130139 child .out_features ,
131- bias = False ,
140+ bias = child . bias is not None ,
132141 device = child .weight .device ,
133142 groupsize = groupsize ,
134143 inner_k_tiles = inner_k_tiles ,
@@ -138,6 +147,9 @@ def _vk_replace_linear_int4(
138147 if copy_weights and child .weight .device != torch .device ("meta" ):
139148 # pyre-fixme[16]: `Module` has no attribute `weight`.
140149 new_linear .weight = child .weight
150+ if child .bias is not None :
151+ # pyre-fixme[16]: `Module` has no attribute `bias`.
152+ new_linear .bias = child .bias
141153 setattr (module , name , new_linear )
142154 else :
143155 _vk_replace_linear_int4 (
@@ -189,7 +201,6 @@ def _create_quantized_state_dict(
189201 mod .out_features < self .feature_limit
190202 and mod .in_features < self .feature_limit
191203 ):
192- assert not mod .bias
193204 out_features = mod .out_features
194205 in_features = mod .in_features
195206 logging .info (f"linear: { fqn } , in={ in_features } , out={ out_features } " )
@@ -210,7 +221,8 @@ def _create_quantized_state_dict(
210221 logging .warn (
211222 f"warning: { fqn } is padded to satisfy in_features % 1024 == 0"
212223 )
213- padded_in_features = find_multiple (in_features , (1024 ,))
224+ # pyre-ignore[6]: Incompatible parameter type
225+ padded_in_features = find_multiple (in_features , 1024 )
214226 weight = F .pad (
215227 weight , pad = (0 , padded_in_features - in_features )
216228 )