@@ -191,116 +191,94 @@ def find_quantized_linear_patterns(
 ##


-def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
+def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
     """
     Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed
-    weight tensor by packing 2 4-bit values in one unsigned 8-bit value.
+    weight tensor by transposing the weight tensor, then packing 2 4-bit values in one
+    8-bit value.

-    An input weight tensor of shape (M, K) will produce a packed weight tensor of shape
-    (M, K / 2).
-
-    The packing implemented here is the same as the packing produced by
-    backends/vulkan/_passes/int4_weight_only_quantizer.py
+    An input weight tensor of shape (N, K) will produce a packed weight tensor of shape
+    (K, N / 2).
     """

     # Assert we got a properly quantized tensor.
-    min, max = inp.min().item(), inp.max().item()
+    min_val, max_val = weight_tensor.min().item(), weight_tensor.max().item()
     assert (
-        max <= 7 and min >= -8
-    ), f"pack_4bit_weight_tensor: [min,max] out of [-8, 7] range, got [{min}, {max}]"
+        max_val <= 7 and min_val >= -8
+    ), f"pack_4bit_weight_tensor: [min_val,max_val] out of [-8, 7] range, got [{min_val}, {max_val}]"

     # Assuming we have a 2d tensor
-    if inp.ndim != 2:
-        inp = inp.squeeze()
+    if weight_tensor.ndim != 2:
+        weight_tensor = weight_tensor.squeeze()
     assert (
-        inp.ndim == 2
-    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {inp.ndim}"
+        weight_tensor.ndim == 2
+    ), f"pack_4bit_weight_tensor: expecting input tensor to be 2d, got {weight_tensor.ndim}"

-    # pad ic
-    if inp.shape[-1] % 2 != 0:
-        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
+    # Need to pad innermost dim to be a multiple of 8, since the minimum load granularity
+    # is int32 (4 bytes), which contains 8 4-bit values.
+    if weight_tensor.shape[-1] % 8 != 0:
+        num_pad = 8 - (weight_tensor.shape[-1] % 8)
+        weight_tensor = F.pad(input=weight_tensor, pad=(0, num_pad))

     # Shape after padding
-    oc, ic = inp.shape
-    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
+    _, in_channels = weight_tensor.shape
+    assert in_channels % 8 == 0, "convert_to_qc4w: expecting ic to be divisible by 8"

-    # Adjust inp tensor for zp
-    inp = inp.to(dtype=torch.uint8) + 8
+    # Adjust weight_tensor tensor for zp
+    weight_tensor = weight_tensor.to(dtype=torch.uint8) + 8
     # Pack each 4-bit value into a single 8-bit value
-    return inp[::, ::2] << 4 | inp[::, 1::2]
-
-
-def make_combined_scales_and_zeros_tensor(
-    scales: torch.Tensor, zeros: torch.Tensor
-) -> torch.Tensor:
-    """
-    Given a scales and zeros tensor, create a combined tensor by stacking them into a
-    single tensor.
-
-    The scales and zeros tensors are expected to be 2D tensors of shape
-    (OUTPUT_CHANNELS, NUM_GROUPS). The combined tensor will have the shape
-    (NUM_GROUPS, OUTPUT_CHANNELS, 2).
-
-    This is the scales and zeros format produced by
-    backends/vulkan/_passes/int4_weight_only_quantizer.py, which in turn is the scales
-    and zeros format expected by the _weight_int4pack_mm op in ATen.
-    """
-    scales_reshaped = scales.transpose(0, 1).unsqueeze(2)
-    zeros_reshaped = zeros.transpose(0, 1).unsqueeze(2)
-
-    zeros_scaled = zeros_reshaped * scales_reshaped * -1
-    return torch.cat((scales_reshaped, zeros_scaled), dim=2)
+    return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
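For reference, the packing convention used by the new code can be exercised in isolation. The snippet below is not part of the diff and the helper names are illustrative; it assumes the same convention as the function above: values are shifted by the +8 zero point, the odd column lands in the high nibble, and the even column in the low nibble.

```python
import torch

def pack_4bit_reference(w: torch.Tensor) -> torch.Tensor:
    # Shift values from [-8, 7] into [0, 15] (the +8 zero point), then place the
    # odd column in the high nibble and the even column in the low nibble,
    # mirroring the packing in the diff above.
    w = w.to(torch.uint8) + 8
    return w[:, 1::2] << 4 | w[:, ::2]

def unpack_4bit_reference(packed: torch.Tensor) -> torch.Tensor:
    # Recover both nibbles and undo the +8 shift.
    lo = (packed & 0x0F).to(torch.int8) - 8
    hi = (packed >> 4).to(torch.int8) - 8
    # Re-interleave so even columns come from the low nibble, odd from the high.
    return torch.stack((lo, hi), dim=-1).flatten(start_dim=-2)

w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # K already a multiple of 8
packed = pack_4bit_reference(w)
assert packed.shape == (4, 4)
assert torch.equal(unpack_4bit_reference(packed), w)
```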


 ##
 ## Pattern Replacement
 ##


-def make_linear_q4ga_op(
+def make_linear_q4gsw_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
     match: QuantizedLinearMatch,
     weight_tensor: torch.Tensor,
     weight_scales_tensor: torch.Tensor,
-    weight_zeros_tensor: torch.Tensor,
 ):
-    packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor)
-    utils.update_program_state_dict(
-        ep, match.weight_node.name, packed_quantized_weight_tensor
-    )
-    # Need to make sure corresponding FakeTensor has same size
-    match.weight_node.meta["val"] = match.weight_node.meta["val"][:, ::2].to(
-        torch.uint8
-    )
-
-    group_size = weight_tensor.shape[1] // weight_scales_tensor.shape[1]
-
-    combined_scales_zeros_tensor = make_combined_scales_and_zeros_tensor(
-        weight_scales_tensor, weight_zeros_tensor
+    num_groups = weight_scales_tensor.shape[-1]
+    in_channels = weight_tensor.shape[-1]
+    group_size = in_channels // num_groups
+
+    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
+    # Use this function for convenience to update the state dict with the packed
+    # weight tensor. Alignment will already have been done in the above function.
+    weight_tensor = utils.align_width_and_update_state_dict(
+        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
     )

-    combined_scales_zeros_name = f"{match.weight_node.name}_scales_zeros"
-    graph_module.register_parameter(
-        combined_scales_zeros_name, torch.nn.Parameter(combined_scales_zeros_tensor)
+    # Also transpose the weight scales tensor to shape [num_groups, N]
+    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
+    # Align to multiple of 8 to ensure that data loads from the weight scales
+    # tensor do not go out of bounds. Each thread computes 8 output channels.
+    utils.align_width_and_update_state_dict(
+        ep,
+        match.weight_scales_node,
+        weight_scales_tensor,
+        align_to=8,
+        force_update=True,
     )

     with graph_module.graph.inserting_before(match.output_node):
-        combined_scales_zeros = graph_module.graph.get_attr(combined_scales_zeros_name)
-        linear_q4ga_node = graph_module.graph.create_node(
+        linear_q4gsw_node = graph_module.graph.create_node(
             "call_function",
-            exir_ops.edge.et_vk.linear_weight_int4.default,
+            exir_ops.edge.et_vk.linear_q4gsw.default,
             args=(
                 match.fp_input_node,
                 match.weight_node,
+                match.weight_scales_node,
                 group_size,
-                combined_scales_zeros,
-                1,
             ),
         )

-        linear_q4ga_node.meta["val"] = match.output_node.meta["val"]
-        match.output_node.replace_all_uses_with(linear_q4ga_node)
+        linear_q4gsw_node.meta["val"] = match.output_node.meta["val"]
+        match.output_node.replace_all_uses_with(linear_q4gsw_node)
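As a shape walk-through of the new group_size derivation and scales layout (numbers here are illustrative, not taken from the diff): a weight of shape [N=64, K=128] with per-group scales of shape [64, 4] yields group_size = 128 // 4 = 32, and the transposed scales have shape [4, 64].

```python
import torch

N, K, num_groups = 64, 128, 4
weight_tensor = torch.randint(-8, 8, (N, K), dtype=torch.int8)
weight_scales_tensor = torch.rand(N, num_groups)

# Same derivation as the new code: one scale per group of input channels.
group_size = weight_tensor.shape[-1] // weight_scales_tensor.shape[-1]
assert group_size == 32

# Scales are transposed to [num_groups, N] before being written back.
scales_for_shader = weight_scales_tensor.transpose(0, 1).contiguous()
assert scales_for_shader.shape == (num_groups, N)
```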


 def make_linear_q8ta_q8csw_custom_op(
@@ -373,13 +351,8 @@ def replace_quantized_linear_patterns(
         and match.is_weight_pergroup_quantized()
         and utils.is_in_4bit_range(weight_tensor)
     ):
-        make_linear_q4ga_op(
-            ep,
-            graph_module,
-            match,
-            weight_tensor,
-            weight_scales_tensor,
-            weight_zeros_tensor,
+        make_linear_q4gsw_op(
+            ep, graph_module, match, weight_tensor, weight_scales_tensor
         )
     elif (
         match.is_input_static_per_tensor_quantized()
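The implementation of utils.is_in_4bit_range is not shown in this diff; assuming it mirrors the range assert in pack_4bit_weight_tensor above, a minimal equivalent check would look like:

```python
import torch

def is_in_4bit_range(weight: torch.Tensor) -> bool:
    # Signed 4-bit values span [-8, 7]; sketch only, assuming the actual
    # utils.is_in_4bit_range performs an equivalent check.
    return weight.min().item() >= -8 and weight.max().item() <= 7
```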