@@ -292,6 +292,9 @@ def __init__(self) -> None:
292292 ] = {}
293293 self .module_type_config : dict [Callable , Optional [QuantizationConfig ]] = {}
294294 self .module_name_config : dict [str , Optional [QuantizationConfig ]] = {}
295+ # If specified, only quantize nodes that return true for the filter
296+ # function.
297+ self .filter_fn : Optional [Callable [[Node ], bool ]] = None
295298
296299 @classmethod
297300 def get_supported_quantization_configs (cls ) -> list [QuantizationConfig ]:
@@ -355,6 +358,14 @@ def set_module_name(
355358 self .module_name_config [module_name ] = quantization_config
356359 return self
357360
361+ def set_filter_function (self , filter_fn : Callable [[Node ], bool ]):
362+ """
363+ Set the filter function. We only quantize nodes that return True for
364+ the filter function.
365+ """
366+ self .filter_fn = filter_fn
367+ return self
368+
358369 def transform_for_annotation (
359370 self , model : torch .fx .GraphModule
360371 ) -> torch .fx .GraphModule :
@@ -378,17 +389,29 @@ def _annotate_all_patterns(
378389 if quantization_config is None :
379390 return model
380391
392+ # Create a combined filter function, which returns True only when
393+ # both filter_fn and self.filter_fn return True.
394+ def combined_filter_fn (n : Node ) -> bool :
395+ combined_filter = [self .filter_fn , filter_fn ]
396+ return all (f (n ) for f in combined_filter if f is not None )
397+
381398 for pattern in self .SUPPORTED_PATTERNS :
382399 if operator_target and operator_target not in pattern .op_overloads :
383400 # if operator_target is specified, skip patterns that aren't
384401 # associated with that target
385402 continue
386403 if quantization_config .input_activation .is_dynamic and pattern .is_dynamic :
387- OP_TO_ANNOTATOR [pattern .name ](model , quantization_config , filter_fn )
404+ OP_TO_ANNOTATOR [pattern .name ](
405+ model , quantization_config , combined_filter_fn
406+ )
388407 elif quantization_config .is_qat and pattern .is_qat :
389- OP_TO_ANNOTATOR [pattern .name ](model , quantization_config , filter_fn )
408+ OP_TO_ANNOTATOR [pattern .name ](
409+ model , quantization_config , combined_filter_fn
410+ )
390411 elif not quantization_config .input_activation .is_dynamic :
391- OP_TO_ANNOTATOR [pattern .name ](model , quantization_config , filter_fn )
412+ OP_TO_ANNOTATOR [pattern .name ](
413+ model , quantization_config , combined_filter_fn
414+ )
392415
393416 return model
394417
0 commit comments