@@ -210,19 +210,302 @@ def fuse_into_linear_qcnw_node(
     graph_module.graph.erase_node(dq_weight_node)
 
 
+#########################
+## linear_qta8a_qga4w  ##
+#########################
+
+
+def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a dequantize_affine operation."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "dequantize_affine" in getattr(node.target, "__name__", "")
+    )
+
+
+def _is_view_copy_node(node: torch.fx.Node) -> bool:
+    """Check if a node is a view_copy operation."""
+    return (
+        node.op == "call_function"
+        and node.target is not None
+        and hasattr(node.target, "__name__")
+        and "view_copy" in getattr(node.target, "__name__", "")
+    )
+
+
+def _validate_qta8a_qga4w_nodes(
+    input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
+) -> Optional[torch.fx.Node]:
+    """
+    Validate the input and weight nodes for the QTA8A_QGA4W pattern.
+    Returns the actual input node (after handling view operations) or None if invalid.
+    """
+    # Type checking - ensure we have torch.fx.Node objects
+    if not isinstance(weight_node, torch.fx.Node) or not isinstance(
+        input_node, torch.fx.Node
+    ):
+        return None
+
+    # Input may be preprocessed with a view node
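+    # (e.g. a [B, S, D] activation flattened to [B*S, D] ahead of the linear op)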
+    actual_input_node = input_node
+    if _is_view_copy_node(input_node):
+        actual_input_node = input_node.args[0]
+        if not isinstance(actual_input_node, torch.fx.Node):
+            return None
+
+    # Check if the input is dequantized with dequantize_affine (from dynamic quantization)
+    if not _is_dequantize_affine_node(actual_input_node):
+        return None
+
+    # Check if the weight is dequantized with dequantize_affine
+    if not _is_dequantize_affine_node(weight_node):
+        return None
+
+    return actual_input_node
+
+
+def _extract_weight_params(
+    program: ExportedProgram, weight_node: torch.fx.Node
+) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
+    """Extract and validate weight parameters from a dequantize_affine node."""
+    # Get the original quantized weight and quantization parameters.
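+    # dequantize_affine args are (weight, block_size, weight_scales, weight_zeros, ...),
+    # so at least four positional arguments are expected.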
+    if len(weight_node.args) < 4:
+        return None
+
+    orig_weight = weight_node.args[0]
+    weight_scales = weight_node.args[2]
+    weight_zeros = weight_node.args[3]
+
+    # Type checking
+    if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
+        program, orig_weight
+    ):
+        return None
+    if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
+        program, weight_scales
+    ):
+        return None
+    if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
+        program, weight_zeros
+    ):
+        return None
+
+    return orig_weight, weight_scales, weight_zeros
+
+
+def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
+    """Check if the weight tensor is quantized to 4 bits (values in the [-8, 7] range)."""
+    quant_min = weight_tensor.min().item()
+    quant_max = weight_tensor.max().item()
+    return quant_min >= -8 and quant_max <= 7
+
+
+def _calculate_group_size(
+    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
+) -> Optional[int]:
+    """Calculate and validate the group size from the weight and scales tensors."""
+    out_features, in_features = orig_weight_tensor.shape
+
+    if len(weight_scales_tensor.shape) != 2:
+        return None
+
+    scales_out_features, num_groups = weight_scales_tensor.shape
+
+    if scales_out_features != out_features:
+        return None
+
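+    # Each group of group_size input channels shares one scale; e.g. a [64, 4096]
+    # weight with [64, 32] scales implies group_size = 4096 // 32 = 128.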
+    # The input channels must divide evenly into the quantization groups
+    if in_features % num_groups != 0:
+        return None
+    group_size = in_features // num_groups
+
+    return group_size
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Check if the nodes surrounding a linear node match the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with a group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits); otherwise return None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Validate the nodes and get the actual input node
+    actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
+    if actual_input_node is None:
+        return None
+
+    # Extract the weight parameters
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    weight_params = _extract_weight_params(program, weight_node)
+    if weight_params is None:
+        return None
+
+    orig_weight, weight_scales, weight_zeros = weight_params
+
+    # Get the tensors to analyze the quantization scheme
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    weight_scales_tensor = get_param_tensor(program, weight_scales)
+    weight_zeros_tensor = get_param_tensor(program, weight_zeros)
+
+    if not isinstance(orig_weight_tensor, torch.Tensor):
+        return None
+    if not isinstance(weight_scales_tensor, torch.Tensor):
+        return None
+    if not isinstance(weight_zeros_tensor, torch.Tensor):
+        return None
+
+    # Check that the weight is quantized to 4 bits
+    if not _validate_4bit_quantization(orig_weight_tensor):
+        return None
+
+    # Calculate the group size
+    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
+    if group_size is None:
+        return None
+
+    # Only 4-bit grouped quantization is matched above, so weight_bits is fixed
+    weight_bits = 4
+
+    return group_size, weight_bits
+
+
+def fuse_into_linear_qta8a_qga4w_node(
+    program: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    linear_node: torch.fx.Node,
+    group_size: int,
+    weight_bits: int,
+) -> None:
+    """
+    Fuse the dynamic activation + grouped weight quantized linear pattern into
+    a single linear_qta8a_qga4w operator.
+
+    The pattern:
+        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
+        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
+        output = linear(dequantized_input, dequantized_weight)
+
+    becomes:
+        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
+                                    weight, group_size, weight_scales, weight_zeros)
+    """
+    dq_input_node = linear_node.args[0]
+    dq_weight_node = linear_node.args[1]
+
+    assert isinstance(dq_input_node, torch.fx.Node)
+
+    input_view_node = None
+    # Input may be preprocessed with a view node
+    if _is_view_copy_node(dq_input_node):
+        input_view_node = dq_input_node
+        dq_input_node = dq_input_node.args[0]
+        assert isinstance(dq_input_node, torch.fx.Node)
+
+    assert isinstance(dq_weight_node, torch.fx.Node)
+
+    # Get the quantized input and quantization parameters from the input dequantize_affine node
+    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
+    quantized_input = dq_input_node.args[0]
+    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
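+    # zero_point may be absent (e.g. for symmetric quantization), so guard the index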
+    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
+
+    # Get the weight and its quantization parameters from dequantize_affine
+    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
+    orig_weight = dq_weight_node.args[0]
+    weight_scales = dq_weight_node.args[2]
+    weight_zeros = dq_weight_node.args[3]
+
+    # Pack the 4-bit weight tensor for efficient storage
+    assert isinstance(orig_weight, torch.fx.Node)
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    assert isinstance(orig_weight_tensor, torch.Tensor)
+    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
+    utils.update_program_state_dict(
+        program,
+        orig_weight.name,
+        packed_weight_tensor,
+    )
+    # Update the metadata to reflect the new packed shape
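+    # Packing stores two 4-bit values per uint8 byte, so the packed weight has
+    # shape [out_features, in_features // 2]; the [:, ::2] slice below yields a
+    # meta tensor with that halved width.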
+    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
+
+    # Create the linear_qta8a_qga4w node
+    with graph_module.graph.inserting_before(linear_node):
+        linear_qta8a_qga4w_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
+            (
+                quantized_input,  # quantized input (int8)
+                input_scale,  # mat1_scale
+                input_zero_point,  # mat1_zero_point
+                orig_weight,  # mat2_data (packed 4-bit weights)
+                group_size,  # group_size (int)
+                weight_scales,  # weight_scales
+                weight_zeros,  # weight_zeros
+            ),
+        )
+
+    # Replace the linear node with the new fused node
+    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
+
+    # Erase nodes in the correct order (users first, then dependencies)
+    graph_module.graph.erase_node(linear_node)
+    if input_view_node is not None:
+        graph_module.graph.erase_node(input_view_node)
+    graph_module.graph.erase_node(dq_weight_node)
+    graph_module.graph.erase_node(dq_input_node)
+
+
 class FuseQuantizedOpsTransform(ExportPass):
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.program = exported_program
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
+            # Check for the linear_qcnw pattern (weight-only quantization)
             qcnw_details = matches_linear_qcnw_pattern(self.program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
                     self.program, graph_module, node, qcnw_method, qcnw_nbits
                 )
+                continue
+
+            # Check for the linear_qta8a_qga4w pattern
+            # (dynamic activation + grouped weight quantization)
+            qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
+            if qta8a_qga4w_details is not None:
+                group_size, weight_bits = qta8a_qga4w_details
+                fuse_into_linear_qta8a_qga4w_node(
+                    self.program, graph_module, node, group_size, weight_bits
+                )
+                continue
 
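+        # The choose_qparams / quantize_affine nodes feeding the erased dequantize
+        # nodes are now dead and are cleaned up by the DCE pass below.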
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)