@@ -210,19 +210,302 @@ def fuse_into_linear_qcnw_node(
    graph_module.graph.erase_node(dq_weight_node)


+#########################
+## linear_qta8a_qga4w ##
+#########################
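+# Naming note (inferred from the pattern handled below, not stated in the
+# source): QTA8A = 8-bit affine-quantized (dynamic) activations, QGA4W =
+# 4-bit group-affine-quantized weights.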
+
+
+def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a dequantize_affine operation."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "dequantize_affine" in getattr(node.target, "__name__", "")
+    )
+
+
+def _is_view_copy_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a view_copy operation."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "view_copy" in getattr(node.target, "__name__", "")
+    )
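+# Note: these helpers match on the target's __name__ rather than on a specific
+# operator overload; presumably this catches both the aten and edge-dialect
+# variants of these ops (an inference, not stated in the source).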
+
+
+def _validate_qta8a_qga4w_nodes(
+    input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
+) -> Optional[torch.fx.Node]:
+    """
+    Validate the input and weight nodes for the QTA8A_QGA4W pattern.
+    Returns the actual input node (after handling view operations) or None if invalid.
+    """
+    # Type checking - ensure we have torch.fx.Node objects
+    if not isinstance(weight_node, torch.fx.Node) or not isinstance(
+        input_node, torch.fx.Node
+    ):
+        return None
+
+    # Input may be preprocessed with a view node
+    actual_input_node = input_node
+    if _is_view_copy_node(input_node):
+        actual_input_node = input_node.args[0]
+        if not isinstance(actual_input_node, torch.fx.Node):
+            return None
+
+    # Check if the input is dequantized with dequantize_affine (from dynamic quantization)
+    if not _is_dequantize_affine_node(actual_input_node):
+        return None
+
+    # Check if the weight is dequantized with dequantize_affine
+    if not _is_dequantize_affine_node(weight_node):
+        return None
+
+    return actual_input_node
+
+
+def _extract_weight_params(
+    program: ExportedProgram, weight_node: torch.fx.Node
+) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
+    """Extract and validate weight parameters from a dequantize_affine node."""
+    # Get the original quantized weight and quantization parameters
+    if len(weight_node.args) < 4:
+        return None
+
+    orig_weight = weight_node.args[0]
+    weight_scales = weight_node.args[2]
+    weight_zeros = weight_node.args[3]
+
+    # Type checking
+    if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
+        program, orig_weight
+    ):
+        return None
+    if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
+        program, weight_scales
+    ):
+        return None
+    if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
+        program, weight_zeros
+    ):
+        return None
+
+    return orig_weight, weight_scales, weight_zeros
+
+
+def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
+    """Check if a weight tensor is quantized to 4 bits (values in the [-8, 7] range)."""
+    quant_min = weight_tensor.min().item()
+    quant_max = weight_tensor.max().item()
+    return quant_min >= -8 and quant_max <= 7
+
+
+def _calculate_group_size(
+    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
+) -> Optional[int]:
+    """Calculate and validate group size from weight and scales tensors."""
+    out_features, in_features = orig_weight_tensor.shape
+
+    if len(weight_scales_tensor.shape) != 2:
+        return None
+
+    scales_out_features, num_groups = weight_scales_tensor.shape
+
+    if scales_out_features != out_features:
+        return None
+
+    # in_features must split evenly into num_groups groups
+    if in_features % num_groups != 0:
+        return None
+
+    group_size = in_features // num_groups
+    return group_size
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Check if the nodes surrounding a linear node match the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with a group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits); otherwise, return None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Validate the nodes and get the actual input node
+    actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
+    if actual_input_node is None:
+        return None
+
+    # Extract the weight parameters
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    weight_params = _extract_weight_params(program, weight_node)
+    if weight_params is None:
+        return None
+
+    orig_weight, weight_scales, weight_zeros = weight_params
+
+    # Get the tensors to analyze the quantization scheme
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    weight_scales_tensor = get_param_tensor(program, weight_scales)
+    weight_zeros_tensor = get_param_tensor(program, weight_zeros)
+
+    if not isinstance(orig_weight_tensor, torch.Tensor):
+        return None
+    if not isinstance(weight_scales_tensor, torch.Tensor):
+        return None
+    if not isinstance(weight_zeros_tensor, torch.Tensor):
+        return None
+
+    # Check if the weight is quantized to 4 bits
+    if not _validate_4bit_quantization(orig_weight_tensor):
+        return None
+
+    # Calculate the group size
+    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
+    if group_size is None:
+        return None
+
+    # Weight bits is fixed at 4 for this pattern (validated above)
+    weight_bits = 4
+
+    return group_size, weight_bits
+
+
+def fuse_into_linear_qta8a_qga4w_node(
+    program: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    linear_node: torch.fx.Node,
+    group_size: int,
+    weight_bits: int,
+) -> None:
+    """
+    Fuse the dynamic activation + grouped weight quantized linear pattern into
+    a single linear_qta8a_qga4w operator.
+
+    The pattern:
+        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
+        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
+        output = linear(dequantized_input, dequantized_weight)
+
+    becomes:
+        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
+                                    weight, group_size, weight_scales, weight_zeros)
+    """
+    dq_input_node = linear_node.args[0]
+    dq_weight_node = linear_node.args[1]
+
+    assert isinstance(dq_input_node, torch.fx.Node)
+
+    input_view_node = None
+    # Input may be preprocessed with a view node
+    if _is_view_copy_node(dq_input_node):
+        input_view_node = dq_input_node
+        dq_input_node = dq_input_node.args[0]
+
+    assert isinstance(dq_input_node, torch.fx.Node)
+    assert isinstance(dq_weight_node, torch.fx.Node)
+
+    # Get the quantized input and quantization parameters from the input
+    # dequantize_affine node. Its args are: (input, block_size, scale,
+    # zero_point, input_dtype, quant_min, quant_max, output_dtype)
+    quantized_input = dq_input_node.args[0]
+    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
+    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
+
+    # Get the weight and its quantization parameters from dequantize_affine.
+    # Its args are: (weight, block_size, weight_scales, weight_zeros,
+    # input_dtype, quant_min, quant_max, output_dtype)
+    orig_weight = dq_weight_node.args[0]
+    weight_scales = dq_weight_node.args[2]
+    weight_zeros = dq_weight_node.args[3]
+
+    # Pack the 4-bit weight tensor for efficient storage
+    assert isinstance(orig_weight, torch.fx.Node)
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    assert isinstance(orig_weight_tensor, torch.Tensor)
+    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
+    utils.update_program_state_dict(
+        program,
+        orig_weight.name,
+        packed_weight_tensor,
+    )
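+    # pack_4bit_weight_tensor packs two 4-bit values into each byte, halving
+    # the innermost dimension; the [:, ::2] slice and uint8 cast below mirror
+    # that packed shape on the node's fake tensor.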
+    # Update the metadata to reflect the new packed shape
+    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
+
+    # Create the linear_qta8a_qga4w node
+    with graph_module.graph.inserting_before(linear_node):
+        linear_qta8a_qga4w_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
+            (
+                quantized_input,  # quantized input (int8)
+                input_scale,  # mat1_scale
+                input_zero_point,  # mat1_zero_point
+                orig_weight,  # mat2_data (packed 4-bit weights)
+                group_size,  # group_size (int)
+                weight_scales,  # weight_scales
+                weight_zeros,  # weight_zeros
+            ),
+        )
+
+    # Replace the linear node with the new fused node
+    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
+
+    # Erase nodes in the correct order (users first, then dependencies)
+    graph_module.graph.erase_node(linear_node)
+    if input_view_node is not None:
+        graph_module.graph.erase_node(input_view_node)
+    graph_module.graph.erase_node(dq_weight_node)
+    graph_module.graph.erase_node(dq_input_node)
+
+
class FuseQuantizedOpsTransform(ExportPass):
    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        self.program = exported_program

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
+            # Check for the linear_qcnw pattern (weight-only quantization)
            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
            if qcnw_details is not None:
                qcnw_method, qcnw_nbits = qcnw_details
                fuse_into_linear_qcnw_node(
                    self.program, graph_module, node, qcnw_method, qcnw_nbits
                )
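+                # The matched nodes were fused and erased, so skip the
+                # remaining pattern checks for this node.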
+                continue
+
+            # Check for the linear_qta8a_qga4w pattern
+            # (dynamic activation + grouped weight quantization)
+            qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
+            if qta8a_qga4w_details is not None:
+                group_size, weight_bits = qta8a_qga4w_details
+                fuse_into_linear_qta8a_qga4w_node(
+                    self.program, graph_module, node, group_size, weight_bits
+                )
+                continue

        graph_module.recompile()
        dead_code_elimination_pass(graph_module)