@@ -210,278 +210,6 @@ def fuse_into_linear_qcnw_node(
    graph_module.graph.erase_node(dq_weight_node)


-#########################
-## linear_qta8a_qga4w ##
-#########################
-
-
-def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
-    """Check if a node is a dequantize_affine operation."""
-    return (
-        node.op == "call_function"
-        and node.target is not None
-        and hasattr(node.target, "__name__")
-        and "dequantize_affine" in getattr(node.target, "__name__", "")
-    )
-
-
-def _is_view_copy_node(node: torch.fx.Node) -> bool:
-    """Check if a node is a view_copy operation."""
-    return (
-        node.op == "call_function"
-        and node.target is not None
-        and hasattr(node.target, "__name__")
-        and "view_copy" in getattr(node.target, "__name__", "")
-    )
-
-
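For reference, the two predicates above match nodes purely by the target's `__name__`. A minimal, self-contained sketch (not part of this diff) of that matching style on an ordinary traced graph, using `torch.relu` as a hypothetical stand-in for the dequantize_affine/view_copy targets:

    import torch
    import torch.fx

    def _name_contains(node: torch.fx.Node, fragment: str) -> bool:
        # Same structure as _is_dequantize_affine_node / _is_view_copy_node.
        return (
            node.op == "call_function"
            and node.target is not None
            and fragment in getattr(node.target, "__name__", "")
        )

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(M())
    print([n.name for n in gm.graph.nodes if _name_contains(n, "relu")])  # ['relu']
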
-def _validate_qta8a_qga4w_nodes(
-    input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
-) -> Optional[torch.fx.Node]:
-    """
-    Validate the input and weight nodes for the QTA8A_QGA4W pattern.
-    Returns the actual input node (after handling view operations) or None if invalid.
-    """
-    # Type checking - ensure we have torch.fx.Node objects
-    if not isinstance(weight_node, torch.fx.Node) or not isinstance(
-        input_node, torch.fx.Node
-    ):
-        return None
-
-    # Input may be preprocessed with a view node
-    actual_input_node = input_node
-    if _is_view_copy_node(input_node):
-        actual_input_node = input_node.args[0]
-        if not isinstance(actual_input_node, torch.fx.Node):
-            return None
-
-    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
-    if not _is_dequantize_affine_node(actual_input_node):
-        return None
-
-    # Check if weight is dequantized with dequantize_affine
-    if not _is_dequantize_affine_node(weight_node):
-        return None
-
-    return actual_input_node
-
-
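A sketch (illustrative only, using plain Python functions as stand-ins for the real torchao/aten targets) of the unwrap-one-view behavior this helper implements:

    import torch
    import torch.fx

    def dequantize_affine(x):  # stand-in; the real op comes from torchao
        return x

    def view_copy(x):  # stand-in for the aten view_copy op
        return x

    g = torch.fx.Graph()
    x = g.placeholder("x")
    dq = g.call_function(dequantize_affine, (x,))
    v = g.call_function(view_copy, (dq,))

    # _is_view_copy_node(v) is True, so validation unwraps one level and then
    # checks that v.args[0] is a dequantize_affine node, returning it on success.
    print(v.target.__name__, v.args[0] is dq)  # view_copy True
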
-def _extract_weight_params(
-    program: ExportedProgram, weight_node: torch.fx.Node
-) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
-    """Extract and validate weight parameters from a dequantize_affine node."""
-    # Get the original quantized weight and quantization parameters
-    if len(weight_node.args) < 4:
-        return None
-
-    orig_weight = weight_node.args[0]
-    weight_scales = weight_node.args[2]
-    weight_zeros = weight_node.args[3]
-
-    # Type checking
-    if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
-        program, orig_weight
-    ):
-        return None
-    if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
-        program, weight_scales
-    ):
-        return None
-    if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
-        program, weight_zeros
-    ):
-        return None
-
-    return orig_weight, weight_scales, weight_zeros
-
-
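The indices used above follow the dequantize_affine argument order documented later in this file. A stand-in tuple (hypothetical values, for illustration only) makes the layout concrete:

    # (input, block_size, scales, zero_points, input_dtype, quant_min, quant_max, ...)
    dq_args = ("weight", (1, 32), "weight_scales", "weight_zeros", "int8", -8, 7)
    orig_weight = dq_args[0]    # args[0] -> quantized weight
    weight_scales = dq_args[2]  # args[2] -> per-group scales
    weight_zeros = dq_args[3]   # args[3] -> per-group zero points
    print(orig_weight, weight_scales, weight_zeros)
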
-def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
-    """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range)."""
-    quant_min = weight_tensor.min().item()
-    quant_max = weight_tensor.max().item()
-    return quant_min >= -8 and quant_max <= 7
-
-
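Note this is a heuristic on the observed value range: an int8 weight whose values happen to fall inside [-8, 7] would also pass. A quick illustrative check:

    import torch

    int4_in_int8 = torch.tensor([[-8, 0, 7]], dtype=torch.int8)
    true_int8 = torch.tensor([[-128, 0, 127]], dtype=torch.int8)

    def looks_4bit(t: torch.Tensor) -> bool:
        return t.min().item() >= -8 and t.max().item() <= 7

    print(looks_4bit(int4_in_int8), looks_4bit(true_int8))  # True False
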
-def _calculate_group_size(
-    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
-) -> Optional[int]:
-    """Calculate and validate group size from weight and scales tensors."""
-    out_features, in_features = orig_weight_tensor.shape
-
-    if len(weight_scales_tensor.shape) != 2:
-        return None
-
-    scales_out_features, num_groups = weight_scales_tensor.shape
-
-    if scales_out_features != out_features:
-        return None
-
-    group_size = in_features // num_groups
-    if in_features % group_size != 0:
-        return None
-
-    return group_size
-
-
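A worked example of the arithmetic (illustrative shapes): a [64, 128] quantized weight with per-group scales of shape [64, 4] implies 128 // 4 = 32 elements per group.

    import torch

    orig_weight = torch.zeros(64, 128, dtype=torch.int8)  # [out_features, in_features]
    weight_scales = torch.zeros(64, 4)                    # [out_features, num_groups]

    out_features, in_features = orig_weight.shape
    scales_out_features, num_groups = weight_scales.shape
    group_size = in_features // num_groups                # 128 // 4 = 32
    print(group_size, in_features % group_size == 0)      # 32 True
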
-def matches_linear_qta8a_qga4w_pattern(
-    program: ExportedProgram, node: torch.fx.Node
-) -> Optional[Tuple[int, int]]:
-    """
-    Checks if the nodes surrounding a linear node match the pattern for dynamic
-    activation + grouped weight quantized linear (QTA8A_QGA4W).
-
-    This pattern involves:
-    1. Dynamic quantization of input activations (8-bit)
-    2. Grouped quantization of weights (4-bit with a group size)
-
-    The expected pattern from Int8DynActInt4WeightQuantizer is:
-        scale, zero_point = choose_qparams_affine(input)
-        quantized_input = quantize_affine(input, scale, zero_point)
-        dequantized_input = dequantize_affine(quantized_input, ...)
-        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
-        output = linear(dequantized_input, dequantized_weight)
-
-    If the pattern matches, return (group_size, weight_bits); otherwise return None.
-    """
-    if not utils.is_linear_node(node):
-        return None
-
-    input_node = node.args[0]
-    weight_node = node.args[1]
-
-    # Validate nodes and get actual input node
-    actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
-    if actual_input_node is None:
-        return None
-
-    # Extract weight parameters
-    if not isinstance(weight_node, torch.fx.Node):
-        return None
-    weight_params = _extract_weight_params(program, weight_node)
-    if weight_params is None:
-        return None
-
-    orig_weight, weight_scales, weight_zeros = weight_params
-
-    # Get tensors to analyze the quantization scheme
-    orig_weight_tensor = get_param_tensor(program, orig_weight)
-    weight_scales_tensor = get_param_tensor(program, weight_scales)
-    weight_zeros_tensor = get_param_tensor(program, weight_zeros)
-
-    if not isinstance(orig_weight_tensor, torch.Tensor):
-        return None
-    if not isinstance(weight_scales_tensor, torch.Tensor):
-        return None
-    if not isinstance(weight_zeros_tensor, torch.Tensor):
-        return None
-
-    # Check if weight is quantized to 4 bits
-    if not _validate_4bit_quantization(orig_weight_tensor):
-        return None
-
-    # Calculate group size
-    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
-    if group_size is None:
-        return None
-
-    # Weight bits is fixed at 4; the 4-bit range was validated above
-    weight_bits = 4
-
-    return group_size, weight_bits
-
-
-def fuse_into_linear_qta8a_qga4w_node(
-    program: ExportedProgram,
-    graph_module: torch.fx.GraphModule,
-    linear_node: torch.fx.Node,
-    group_size: int,
-    weight_bits: int,
-) -> None:
-    """
-    Fuse the dynamic activation + grouped weight quantized linear pattern into
-    a single linear_qta8a_qga4w operator.
-
-    The pattern:
-        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
-        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
-        output = linear(dequantized_input, dequantized_weight)
-
-    Becomes:
-        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
-                                    weight, group_size, weight_scales, weight_zeros)
-    """
-    dq_input_node = linear_node.args[0]
-    dq_weight_node = linear_node.args[1]
-
-    assert isinstance(dq_input_node, torch.fx.Node)
-
-    input_view_node = None
-    # Input may be preprocessed with a view node
-    if (
-        dq_input_node.op == "call_function"
-        and dq_input_node.target is not None
-        and hasattr(dq_input_node.target, "__name__")
-        and "view_copy" in getattr(dq_input_node.target, "__name__", "")
-    ):
-        input_view_node = dq_input_node
-        dq_input_node = dq_input_node.args[0]
-        assert isinstance(dq_input_node, torch.fx.Node)
-
-    assert isinstance(dq_input_node, torch.fx.Node)
-    assert isinstance(dq_weight_node, torch.fx.Node)
-
-    # Get the quantized input and quantization parameters from the input dequantize_affine node.
-    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
-    quantized_input = dq_input_node.args[0]
-    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
-    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
-
-    # Get the weight and its quantization parameters from dequantize_affine.
-    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
-    orig_weight = dq_weight_node.args[0]
-    weight_scales = dq_weight_node.args[2]
-    weight_zeros = dq_weight_node.args[3]
-
-    # Pack the 4-bit weight tensor for efficient storage
-    assert isinstance(orig_weight, torch.fx.Node)
-    orig_weight_tensor = get_param_tensor(program, orig_weight)
-    assert isinstance(orig_weight_tensor, torch.Tensor)
-    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
-    utils.update_program_state_dict(
-        program,
-        orig_weight.name,
-        packed_weight_tensor,
-    )
-    # Update the metadata to reflect the new packed shape
-    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
-
-    # Create the linear_qta8a_qga4w node
-    with graph_module.graph.inserting_before(linear_node):
-        linear_qta8a_qga4w_node = graph_module.graph.create_node(
-            "call_function",
-            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
-            (
-                quantized_input,  # quantized input (int8)
-                input_scale,  # mat1_scale
-                input_zero_point,  # mat1_zero_point
-                orig_weight,  # mat2_data (packed 4-bit weights)
-                group_size,  # group_size (int)
-                weight_scales,  # weight_scales
-                weight_zeros,  # weight_zeros
-            ),
-        )
-
-    # Replace the linear node with the new fused node
-    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
-
-    # Erase nodes in the correct order (users first, then dependencies)
-    graph_module.graph.erase_node(linear_node)
-    if input_view_node is not None:
-        graph_module.graph.erase_node(input_view_node)
-    graph_module.graph.erase_node(dq_weight_node)
-    graph_module.graph.erase_node(dq_input_node)
-
-
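The metadata update above (keeping every other column with `[:, ::2]` and switching to uint8) implies a two-nibbles-per-byte layout. A self-contained sketch of such a scheme, assuming pack_4bit_weight_tensor pairs adjacent int4 values along the last dimension; the real helper may differ in nibble order or padding:

    import torch

    def pack_4bit_pairs(w: torch.Tensor) -> torch.Tensor:
        # w: int8 tensor with values in [-8, 7] and an even number of columns.
        w = (w + 8).to(torch.uint8)           # shift int4 [-8, 7] -> [0, 15]
        return w[:, ::2] | (w[:, 1::2] << 4)  # two nibbles per byte

    w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
    packed = pack_4bit_pairs(w)
    print(w.shape, packed.shape)  # torch.Size([4, 8]) torch.Size([4, 4])

Under this assumption the packed tensor halves the column count, which is exactly what the `meta["val"][:, ::2]` update records.
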
class FuseQuantizedOpsTransform(ExportPass):
    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
@@ -498,15 +226,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                )
                continue

-            # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
-            qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
-            if qta8a_qga4w_details is not None:
-                group_size, weight_bits = qta8a_qga4w_details
-                fuse_into_linear_qta8a_qga4w_node(
-                    self.program, graph_module, node, group_size, weight_bits
-                )
-                continue
-
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)