@@ -124,11 +124,22 @@ class VulkanQuantizer(Quantizer):
def __init__(self) -> None:
    """Create a quantizer with no node filter and no global config set."""
    super().__init__()
    # Optional predicate over nodes. If specified, only nodes for which it
    # returns True are quantized.
    self.filter_fn: Optional[Callable[[Node], bool]] = None
    # Quantization config stored by set_global(); None until it is called.
    self.global_config: Optional[QuantizationConfig] = None
127130
def set_global(
    self, quantization_config: QuantizationConfig
) -> VulkanQuantizer:
    """
    Store ``quantization_config`` as the global config of this quantizer.

    Returns the quantizer itself so that setter calls can be chained.
    """
    self.global_config = quantization_config
    return self
131134
def set_filter_function(
    self, filter_fn: Callable[[Node], bool]
) -> VulkanQuantizer:
    """
    Set the filter function. We only quantize nodes that return True for
    the filter function.

    Args:
        filter_fn: Predicate that is called with a node and returns a bool.

    Returns:
        This quantizer, so that setter calls can be chained in the same
        way as with set_global().
    """
    self.filter_fn = filter_fn
    return self
142+
132143 def transform_for_annotation (
133144 self , model : torch .fx .GraphModule
134145 ) -> torch .fx .GraphModule :
@@ -149,8 +160,14 @@ def _annotate_all_patterns(
149160 if quantization_config is None :
150161 return model
151162
163+ # Create a combined filter function, which returns True only when
164+ # both filter_fn and self.filter_fn return True.
def combined_filter_fn(n: Node) -> bool:
    """Return True only if every filter that is set accepts ``n``."""
    # Filters that are None are skipped; the first rejecting filter
    # short-circuits the check, so later filters are not called.
    for candidate in (self.filter_fn, filter_fn):
        if candidate is not None and not candidate(n):
            return False
    return True
168+
152169 for op in _SUPPORTED_OPS :
153- OP_TO_ANNOTATOR [op ](model , quantization_config , filter_fn )
170+ OP_TO_ANNOTATOR [op ](model , quantization_config , combined_filter_fn )
154171 return model
155172
156173 def _annotate_for_quantization_config (
0 commit comments