@@ -210,278 +210,6 @@ def fuse_into_linear_qcnw_node(
     graph_module.graph.erase_node(dq_weight_node)
 
 
-########################
-## linear_qta8a_qga4w ##
-########################
-
-
-def _is_dequantize_affine_node(node: torch.fx.Node) -> bool:
-    """Check if a node is a dequantize_affine operation."""
-    return (
-        node.op == "call_function"
-        and node.target is not None
-        and hasattr(node.target, "__name__")
-        and "dequantize_affine" in getattr(node.target, "__name__", "")
-    )
-
-
-def _is_view_copy_node(node: torch.fx.Node) -> bool:
-    """Check if a node is a view_copy operation."""
-    return (
-        node.op == "call_function"
-        and node.target is not None
-        and hasattr(node.target, "__name__")
-        and "view_copy" in getattr(node.target, "__name__", "")
-    )
-
-
-def _validate_qta8a_qga4w_nodes(
-    input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument
-) -> Optional[torch.fx.Node]:
-    """
-    Validate input and weight nodes for QTA8A_QGA4W pattern.
-    Returns the actual input node (after handling view operations) or None if invalid.
-    """
-    # Type checking - ensure we have torch.fx.Node objects
-    if not isinstance(weight_node, torch.fx.Node) or not isinstance(
-        input_node, torch.fx.Node
-    ):
-        return None
-
-    # Input may be preprocessed with a view node
-    actual_input_node = input_node
-    if _is_view_copy_node(input_node):
-        actual_input_node = input_node.args[0]
-        if not isinstance(actual_input_node, torch.fx.Node):
-            return None
-
-    # Check if input is dequantized with dequantize_affine (from dynamic quantization)
-    if not _is_dequantize_affine_node(actual_input_node):
-        return None
-
-    # Check if weight is dequantized with dequantize_affine
-    if not _is_dequantize_affine_node(weight_node):
-        return None
-
-    return actual_input_node
-
-
-def _extract_weight_params(
-    program: ExportedProgram, weight_node: torch.fx.Node
-) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]:
-    """Extract and validate weight parameters from dequantize_affine node."""
-    # Get the original quantized weight and quantization parameters
-    if len(weight_node.args) < 4:
-        return None
-
-    orig_weight = weight_node.args[0]
-    weight_scales = weight_node.args[2]
-    weight_zeros = weight_node.args[3]
-
-    # Type checking
-    if not isinstance(orig_weight, torch.fx.Node) or not is_param_node(
-        program, orig_weight
-    ):
-        return None
-    if not isinstance(weight_scales, torch.fx.Node) or not is_param_node(
-        program, weight_scales
-    ):
-        return None
-    if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node(
-        program, weight_zeros
-    ):
-        return None
-
-    return orig_weight, weight_scales, weight_zeros
-
-
-def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool:
-    """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range)."""
-    quant_min = weight_tensor.min().item()
-    quant_max = weight_tensor.max().item()
-    return quant_min >= -8 and quant_max <= 7
-
-
-def _calculate_group_size(
-    orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor
-) -> Optional[int]:
-    """Calculate and validate group size from weight and scales tensors."""
-    out_features, in_features = orig_weight_tensor.shape
-
-    if len(weight_scales_tensor.shape) != 2:
-        return None
-
-    scales_out_features, num_groups = weight_scales_tensor.shape
-
-    if scales_out_features != out_features:
-        return None
-
-    group_size = in_features // num_groups
-    if in_features % group_size != 0:
-        return None
-
-    return group_size
-
-
-def matches_linear_qta8a_qga4w_pattern(
-    program: ExportedProgram, node: torch.fx.Node
-) -> Optional[Tuple[int, int]]:
-    """
-    Checks if the nodes surrounding a linear node match the pattern for dynamic
-    activation + grouped weight quantized linear (QTA8A_QGA4W).
-
-    This pattern involves:
-    1. Dynamic quantization of input activations (8-bit)
-    2. Grouped quantization of weights (4-bit with group size)
-
-    The expected pattern from Int8DynActInt4WeightQuantizer is:
-        scale, zero_point = choose_qparams_affine(input)
-        quantized_input = quantize_affine(input, scale, zero_point)
-        dequantized_input = dequantize_affine(quantized_input, ...)
-        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
-        output = linear(dequantized_input, dequantized_weight)
-
-    If the pattern matches, return (group_size, weight_bits), otherwise None.
-    """
-    if not utils.is_linear_node(node):
-        return None
-
-    input_node = node.args[0]
-    weight_node = node.args[1]
-
-    # Validate nodes and get actual input node
-    actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node)
-    if actual_input_node is None:
-        return None
-
-    # Extract weight parameters
-    if not isinstance(weight_node, torch.fx.Node):
-        return None
-    weight_params = _extract_weight_params(program, weight_node)
-    if weight_params is None:
-        return None
-
-    orig_weight, weight_scales, weight_zeros = weight_params
-
-    # Get tensors to analyze the quantization scheme
-    orig_weight_tensor = get_param_tensor(program, orig_weight)
-    weight_scales_tensor = get_param_tensor(program, weight_scales)
-    weight_zeros_tensor = get_param_tensor(program, weight_zeros)
-
-    if not isinstance(orig_weight_tensor, torch.Tensor):
-        return None
-    if not isinstance(weight_scales_tensor, torch.Tensor):
-        return None
-    if not isinstance(weight_zeros_tensor, torch.Tensor):
-        return None
-
-    # Check if weight is quantized to 4 bits
-    if not _validate_4bit_quantization(orig_weight_tensor):
-        return None
-
-    # Calculate group size
-    group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor)
-    if group_size is None:
-        return None
-
-    # This pattern only supports 4-bit grouped quantization
-    weight_bits = 4
-
-    return group_size, weight_bits
-
-
-def fuse_into_linear_qta8a_qga4w_node(
-    program: ExportedProgram,
-    graph_module: torch.fx.GraphModule,
-    linear_node: torch.fx.Node,
-    group_size: int,
-    weight_bits: int,
-) -> None:
-    """
-    Fuse the dynamic activation + grouped weight quantized linear pattern into
-    a single linear_qta8a_qga4w operator.
-
-    The pattern:
-        dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...)
-        dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...)
-        output = linear(dequantized_input, dequantized_weight)
-
-    Becomes:
-        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
-                                    weight, group_size, weight_scales, weight_zeros)
-    """
-    dq_input_node = linear_node.args[0]
-    dq_weight_node = linear_node.args[1]
-
-    assert isinstance(dq_input_node, torch.fx.Node)
-
-    input_view_node = None
-    # Input may be preprocessed with a view node
-    if (
-        dq_input_node.op == "call_function"
-        and dq_input_node.target is not None
-        and hasattr(dq_input_node.target, "__name__")
-        and "view_copy" in getattr(dq_input_node.target, "__name__", "")
-    ):
-        input_view_node = dq_input_node
-        dq_input_node = dq_input_node.args[0]
-        assert isinstance(dq_input_node, torch.fx.Node)
-
-    assert isinstance(dq_input_node, torch.fx.Node)
-    assert isinstance(dq_weight_node, torch.fx.Node)
-
-    # Get the quantized input and quantization parameters from the input dequantize_affine node
-    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
-    quantized_input = dq_input_node.args[0]
-    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
-    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
-
-    # Get the weight and its quantization parameters from dequantize_affine
-    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
-    orig_weight = dq_weight_node.args[0]
-    weight_scales = dq_weight_node.args[2]
-    weight_zeros = dq_weight_node.args[3]
-
-    # Pack the 4-bit weight tensor for efficient storage
-    assert isinstance(orig_weight, torch.fx.Node)
-    orig_weight_tensor = get_param_tensor(program, orig_weight)
-    assert isinstance(orig_weight_tensor, torch.Tensor)
-    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
-    utils.update_program_state_dict(
-        program,
-        orig_weight.name,
-        packed_weight_tensor,
-    )
-    # Update the metadata to reflect the new packed shape
-    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
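-    # Note: pack_4bit_weight_tensor packs two 4-bit values into each uint8
-    # byte, so the packed weight has half as many columns; the [:, ::2] slice
-    # above reproduces that halved shape in the node's "val" metadata.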
-
-    # Create the linear_qta8a_qga4w node
-    with graph_module.graph.inserting_before(linear_node):
-        linear_qta8a_qga4w_node = graph_module.graph.create_node(
-            "call_function",
-            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
-            (
-                quantized_input,  # quantized input (int8)
-                input_scale,  # mat1_scale
-                input_zero_point,  # mat1_zero_point
-                orig_weight,  # mat2_data (packed 4-bit weights)
-                group_size,  # group_size (int)
-                weight_scales,  # weight_scales
-                weight_zeros,  # weight_zeros
-            ),
-        )
-
-    # Replace the linear node with the new fused node
-    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
-
-    # Erase nodes in the correct order (users first, then dependencies)
-    graph_module.graph.erase_node(linear_node)
-    if input_view_node is not None:
-        graph_module.graph.erase_node(input_view_node)
-    graph_module.graph.erase_node(dq_weight_node)
-    graph_module.graph.erase_node(dq_input_node)
-
-
 class FuseQuantizedOpsTransform(ExportPass):
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
@@ -498,15 +226,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 )
                 continue
 
-            # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
-            qta8a_qga4w_details = None
-            if qta8a_qga4w_details is not None:
-                group_size, weight_bits = qta8a_qga4w_details
-                fuse_into_linear_qta8a_qga4w_node(
-                    self.program, graph_module, node, group_size, weight_bits
-                )
-                continue
-
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
 