
Commit 98afdeb

Allow passing filter function to Vulkan quantizer
Differential Revision: D85994594
Pull Request resolved: #15508
1 parent: 3f4f500

File tree

1 file changed: +18 -1 lines changed


backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 18 additions & 1 deletion
@@ -124,11 +124,22 @@ class VulkanQuantizer(Quantizer):
     def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
+        # If specified, only quantize nodes that return true for the filter
+        # function.
+        self.filter_fn: Optional[Callable[[Node], bool]] = None
 
     def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
         self.global_config = quantization_config
         return self
 
+    def set_filter_function(self, filter_fn: Callable[[Node], bool]):
+        """
+        Set the filter function. We only quantize nodes that return True for
+        the filter function.
+        """
+        self.filter_fn = filter_fn
+        return self
+
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
@@ -149,8 +160,14 @@ def _annotate_all_patterns(
         if quantization_config is None:
             return model
 
+        # Create a combined filter function, which returns True only when
+        # both filter_fn and self.filter_fn return True.
+        def combined_filter_fn(n: Node) -> bool:
+            combined_filter = [self.filter_fn, filter_fn]
+            return all(f(n) for f in combined_filter if f is not None)
+
         for op in _SUPPORTED_OPS:
-            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
+            OP_TO_ANNOTATOR[op](model, quantization_config, combined_filter_fn)
         return model
 
     def _annotate_for_quantization_config(
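
For reference, a minimal usage sketch of the new API (not part of the commit). It assumes the import path mirrors the file location in the repo, that a QuantizationConfig has already been built elsewhere, and that the only_linear_nodes filter and build_quantizer helper are hypothetical names chosen for illustration.

# Minimal usage sketch; assumptions noted above. Both setters return self,
# so the calls can be chained.
from executorch.backends.vulkan.quantizer.vulkan_quantizer import VulkanQuantizer
from torch.fx import Node


def only_linear_nodes(node: Node) -> bool:
    # Hypothetical filter: quantize only call_function nodes whose target
    # looks like a linear op; every other node is left unquantized.
    return node.op == "call_function" and "linear" in str(node.target)


def build_quantizer(quantization_config) -> VulkanQuantizer:
    # quantization_config is assumed to be a QuantizationConfig built elsewhere.
    return (
        VulkanQuantizer()
        .set_global(quantization_config)
        .set_filter_function(only_linear_nodes)
    )

With no filter set, behavior is unchanged: combined_filter_fn only consults self.filter_fn when it is not None, so all supported ops continue to be annotated according to the global config.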

0 commit comments
