3 | 3 |
4 | 4 | import copy |
5 | 5 | import functools |
6 | | -from typing import Any, Callable, Optional, TYPE_CHECKING |
| 6 | +from dataclasses import dataclass |
| 7 | +from typing import Any, Callable, Optional, Set, TYPE_CHECKING |
7 | 8 |
8 | 9 | import torch |
9 | 10 | import torch._dynamo as torchdynamo |
@@ -235,37 +236,52 @@ def not_module_type_or_name_filter(n: Node) -> bool: |
235 | 236 | return not_module_type_or_name_filter |
236 | 237 |
237 | 238 |
238 | | -class XNNPACKQuantizer(Quantizer): |
239 | | - supported_config_and_operators = _get_supported_config_and_operators() |
240 | | - STATIC_QAT_ONLY_OPS = [ |
241 | | - "conv_bn_relu", |
242 | | - "conv_bn", |
243 | | - "conv_transpose_bn_relu", |
244 | | - "conv_transpose_bn", |
245 | | - ] |
| 239 | +@dataclass |
| 240 | +class QuantPattern: |
| 241 | + name: str |
| 242 | + is_dynamic: bool |
| 243 | + is_qat: bool |
| 244 | + op_overloads: Set[torch._ops.OpOverloadPacket] |
| 245 | + |
| 246 | + |
| 247 | +CONV_TARGETS = { |
| 248 | + torch.ops.aten.conv2d.default, |
| 249 | + torch.ops.aten.conv1d.default, |
| 250 | + torch.ops.aten.convolution.default, |
| 251 | +} |
| 252 | + |
| 253 | +LINEAR_TARGETS = { |
| 254 | + torch.ops.aten.linear.default, |
| 255 | +} |
| 256 | + |
| 257 | +ADAPTIVE_AVG_POOL2D_TARGETS = {torch.ops.aten.adaptive_avg_pool2d.default} |
| 258 | + |
| 259 | +ADD_TARGETS = {torch.ops.aten.add.Tensor} |
| 260 | + |
| 261 | +MUL_TARGETS = {torch.ops.aten.mul.Tensor} |
| 262 | + |
| 263 | +CAT_TARGETS = {torch.ops.aten.cat.default} |
246 | 264 |
247 | | - # static quantization ops (both PTQ and QAT) |
248 | | - # Preserve the order that fusions come before singular ops |
249 | | - STATIC_OPS = [ |
250 | | - "linear_relu", |
251 | | - "linear", |
252 | | - "conv", |
253 | | - "conv_transpose", |
254 | | - "conv_relu", |
255 | | - "conv_transpose_relu", |
256 | | - "adaptive_avg_pool2d", |
257 | | - # TODO: move this to BoltNNQuantizer? |
258 | | - "gru_io_only", |
259 | | - "add_relu", |
260 | | - "add", |
261 | | - "mul_relu", |
262 | | - "mul", |
263 | | - "cat", |
264 | | - ] |
265 | 265 |
266 | | - DYNAMIC_OPS = [ |
267 | | - "linear", |
268 | | - "conv", |
| 266 | +class XNNPACKQuantizer(Quantizer): |
| 267 | + supported_config_and_operators = _get_supported_config_and_operators() |
| 268 | + SUPPORTED_PATTERNS = [ |
| 269 | + QuantPattern("conv_bn_relu", False, True, CONV_TARGETS), |
| 270 | + QuantPattern("conv_bn", False, True, CONV_TARGETS), |
| 271 | + QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS), |
| 272 | + QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS), |
| 273 | + QuantPattern("linear_relu", False, False, LINEAR_TARGETS), |
| 274 | + QuantPattern("linear", True, False, LINEAR_TARGETS), |
| 275 | + QuantPattern("conv", True, False, CONV_TARGETS), |
| 276 | + QuantPattern("conv_transpose", False, False, CONV_TARGETS), |
| 277 | + QuantPattern("conv_relu", False, False, CONV_TARGETS), |
| 278 | + QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS), |
| 279 | + QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS), |
| 280 | + QuantPattern("add_relu", False, False, ADD_TARGETS), |
| 281 | + QuantPattern("add", False, False, ADD_TARGETS), |
| 282 | + QuantPattern("mul_relu", False, False, MUL_TARGETS), |
| 283 | + QuantPattern("mul", False, False, MUL_TARGETS), |
| 284 | + QuantPattern("cat", False, False, CAT_TARGETS), |
269 | 285 | ] |
270 | 286 |
271 | 287 | def __init__(self) -> None: |
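The per-pattern `is_dynamic` / `is_qat` flags let a single annotation pass serve mixed configurations, which the old global static-vs-dynamic branch could not. A rough usage sketch, assuming the existing `XNNPACKQuantizer` helpers (`get_symmetric_quantization_config`, `set_global`, `set_module_name`) keep their current signatures:

```python
import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
# Dynamic quantization globally: only patterns flagged is_dynamic
# ("linear", "conv") are annotated for this config.
quantizer.set_global(get_symmetric_quantization_config(is_dynamic=True))
# Static quantization for one named submodule ("classifier" is a
# hypothetical module name); its annotations go through the static path
# of _annotate_all_patterns instead.
quantizer.set_module_name(
    "classifier", get_symmetric_quantization_config(is_dynamic=False)
)
```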
@@ -347,83 +363,58 @@ def transform_for_annotation( |
347 | 363 |
348 | 364 | def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
349 | 365 | """just handling global spec for now""" |
350 | | - # hacked for handling dynamic linear quant. will fix later. |
351 | | - if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] |
352 | | - model = self._annotate_for_dynamic_quantization_config(model) |
353 | | - else: |
354 | | - model = self._annotate_for_static_quantization_config(model) |
| 366 | + model = self._annotate_for_quantization_config(model) |
355 | 367 | propagate_annotation(model) |
356 | 368 | return model |
357 | 369 |
358 | | - def _annotate_all_static_patterns( |
| 370 | + def _annotate_all_patterns( |
359 | 371 | self, |
360 | 372 | model: torch.fx.GraphModule, |
361 | 373 | quantization_config: Optional[QuantizationConfig], |
362 | 374 | filter_fn: Optional[Callable[[Node], bool]] = None, |
363 | | - ) -> torch.fx.GraphModule: |
| 375 | + operator_target: Optional[torch._ops.OpOverloadPacket] = None, |
| 376 | + ): |
364 | 377 | # TODO: implement the support for None to be canceling out previous annotations |
365 | 378 | if quantization_config is None: |
366 | 379 | return model |
367 | 380 |
368 | | - if quantization_config.is_qat: |
369 | | - for op in self.STATIC_QAT_ONLY_OPS: |
370 | | - OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) |
371 | | - for op in self.STATIC_OPS: |
372 | | - OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) |
373 | | - return model |
| 381 | + for pattern in self.SUPPORTED_PATTERNS: |
| 382 | + if operator_target and operator_target not in pattern.op_overloads: |
| 383 | + # if operator_target is specified, skip patterns that aren't |
| 384 | + # associated with that target |
| 385 | + continue |
| 386 | + if quantization_config.input_activation.is_dynamic and pattern.is_dynamic: |
| 387 | + OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn) |
| 388 | + elif quantization_config.is_qat and pattern.is_qat: |
| 389 | + OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn) |
| 390 | + elif not quantization_config.input_activation.is_dynamic: |
| 391 | + OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn) |
374 | 392 |
375 | | - def _annotate_all_dynamic_patterns( |
376 | | - self, |
377 | | - model: torch.fx.GraphModule, |
378 | | - quantization_config: Optional[QuantizationConfig], |
379 | | - filter_fn: Optional[Callable[[Node], bool]] = None, |
380 | | - ) -> torch.fx.GraphModule: |
381 | | - # TODO: implement the support for None to be canceling out previous annotations |
382 | | - if quantization_config is None: |
383 | | - return model |
384 | | - |
385 | | - for op in self.DYNAMIC_OPS: |
386 | | - OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) |
387 | 393 | return model |
388 | 394 |
389 | | - def _annotate_for_static_quantization_config( |
| 395 | + def _annotate_for_quantization_config( |
390 | 396 | self, model: torch.fx.GraphModule |
391 | 397 | ) -> torch.fx.GraphModule: |
392 | 398 | module_name_list = list(self.module_name_config.keys()) |
393 | 399 | for module_name, config in self.module_name_config.items(): |
394 | | - self._annotate_all_static_patterns( |
| 400 | + self._annotate_all_patterns( |
395 | 401 | model, config, _get_module_name_filter(module_name) |
396 | 402 | ) |
397 | 403 |
398 | 404 | tp_list = list(self.module_type_config.keys()) |
399 | 405 | for module_type, config in self.module_type_config.items(): |
400 | | - self._annotate_all_static_patterns( |
| 406 | + self._annotate_all_patterns( |
401 | 407 | model, config, _get_module_type_filter(module_type) |
402 | 408 | ) |
403 | 409 |
404 | | - self._annotate_all_static_patterns( |
405 | | - model, |
406 | | - self.global_config, |
407 | | - _get_not_module_type_or_name_filter(tp_list, module_name_list), |
408 | | - ) |
409 | | - return model |
410 | | - |
411 | | - def _annotate_for_dynamic_quantization_config( |
412 | | - self, model: torch.fx.GraphModule |
413 | | - ) -> torch.fx.GraphModule: |
414 | | - module_name_list = list(self.module_name_config.keys()) |
415 | | - for module_name, config in self.module_name_config.items(): |
416 | | - self._annotate_all_dynamic_patterns( |
417 | | - model, config, _get_module_name_filter(module_name) |
418 | | - ) |
419 | | - |
420 | | - tp_list = list(self.module_type_config.keys()) |
421 | | - for module_type, config in self.module_type_config.items(): |
422 | | - self._annotate_all_dynamic_patterns( |
423 | | - model, config, _get_module_type_filter(module_type) |
| 410 | + for op, config in self.operator_type_config.items(): |
| 411 | + self._annotate_all_patterns( |
| 412 | + model, |
| 413 | + config, |
| 414 | + _get_not_module_type_or_name_filter(tp_list, module_name_list), |
| 415 | + op, |
424 | 416 | ) |
425 | | - |
426 | | - self._annotate_all_dynamic_patterns( |
| 417 | + self._annotate_all_patterns( |
427 | 418 | model, |
428 | 419 | self.global_config, |
429 | 420 | _get_not_module_type_or_name_filter(tp_list, module_name_list), |
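Below is a rough sketch, not part of this commit, of how the new `operator_target` parameter and the `op_overloads` sets are meant to interact for per-operator-type configs; it assumes `set_operator_type` populates `self.operator_type_config` as the loop above expects:

```python
import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
# Per-operator-type config: forwarded to _annotate_all_patterns together
# with torch.ops.aten.linear.default, so only patterns whose op_overloads
# contain that target ("linear_relu", "linear") are annotated with it.
quantizer.set_operator_type(
    torch.ops.aten.linear.default,
    get_symmetric_quantization_config(is_per_channel=True),
)
# Everything not matched above falls back to the global config.
quantizer.set_global(get_symmetric_quantization_config())
```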