|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +from typing import Callable, Optional |
| 10 | + |
| 11 | +import torch |
| 12 | +from torch.fx import Node |
| 13 | +from torchao.quantization.pt2e.quantizer import ( |
| 14 | + annotate_input_qspec_map, |
| 15 | + annotate_output_qspec, |
| 16 | + get_bias_qspec, |
| 17 | + get_input_act_qspec, |
| 18 | + get_output_act_qspec, |
| 19 | + get_weight_qspec, |
| 20 | + QuantizationAnnotation, |
| 21 | + QuantizationConfig, |
| 22 | + SharedQuantizationSpec, |
| 23 | +) |
| 24 | +from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix |
| 25 | + |
| 26 | +__all__ = [ |
| 27 | + "OP_TO_ANNOTATOR", |
| 28 | + "propagate_annotation", |
| 29 | + "_convert_scalars_to_attrs", |
| 30 | +] |
| 31 | + |
| 32 | + |
| 33 | +AnnotatorType = Callable[ |
| 34 | + [ |
| 35 | + torch.fx.GraphModule, |
| 36 | + Optional[QuantizationConfig], |
| 37 | + Optional[Callable[[Node], bool]], |
| 38 | + ], |
| 39 | + Optional[list[list[Node]]], |
| 40 | +] |
| 41 | +OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {} |
| 42 | + |
| 43 | + |
| 44 | +def register_annotator(op: str) -> Callable[[AnnotatorType], None]: |
| 45 | + def decorator(annotator: AnnotatorType) -> None: |
| 46 | + OP_TO_ANNOTATOR[op] = annotator |
| 47 | + |
| 48 | + return decorator |
| 49 | + |
| 50 | + |
| 51 | +def _is_annotated(nodes: list[Node]) -> bool: |
| 52 | + """ |
| 53 | + Given a list of nodes (that represents an operator pattern), |
| 54 | + check if any of the node is annotated, return True if any of the node |
| 55 | + is annotated, otherwise return False |
| 56 | + """ |
| 57 | + annotated = False |
| 58 | + for node in nodes: |
| 59 | + annotated = annotated or ( |
| 60 | + "quantization_annotation" in node.meta |
| 61 | + and node.meta["quantization_annotation"]._annotated |
| 62 | + ) |
| 63 | + return annotated |
| 64 | + |
| 65 | + |
| 66 | +def _mark_nodes_as_annotated(nodes: list[Node]) -> None: |
| 67 | + for node in nodes: |
| 68 | + if node is not None: |
| 69 | + if "quantization_annotation" not in node.meta: |
| 70 | + node.meta["quantization_annotation"] = QuantizationAnnotation() |
| 71 | + node.meta["quantization_annotation"]._annotated = True |
| 72 | + |
| 73 | + |
| 74 | +@register_annotator("linear") |
| 75 | +def _annotate_linear( |
| 76 | + gm: torch.fx.GraphModule, |
| 77 | + quantization_config: Optional[QuantizationConfig], |
| 78 | + filter_fn: Optional[Callable[[Node], bool]] = None, |
| 79 | +) -> Optional[list[list[Node]]]: |
| 80 | + annotated_partitions = [] |
| 81 | + input_act_qspec = get_input_act_qspec(quantization_config) |
| 82 | + output_act_qspec = get_output_act_qspec(quantization_config) |
| 83 | + weight_qspec = get_weight_qspec(quantization_config) |
| 84 | + bias_qspec = get_bias_qspec(quantization_config) |
| 85 | + for node in gm.graph.nodes: |
| 86 | + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: |
| 87 | + continue |
| 88 | + if filter_fn and not filter_fn(node): |
| 89 | + continue |
| 90 | + act_node = node.args[0] |
| 91 | + weight_node = node.args[1] |
| 92 | + bias_node = None |
| 93 | + if len(node.args) > 2: |
| 94 | + bias_node = node.args[2] |
| 95 | + |
| 96 | + if _is_annotated([node]) is False: # type: ignore[list-item] |
| 97 | + annotate_input_qspec_map( |
| 98 | + node, |
| 99 | + act_node, |
| 100 | + input_act_qspec, |
| 101 | + ) |
| 102 | + annotate_input_qspec_map( |
| 103 | + node, |
| 104 | + weight_node, |
| 105 | + weight_qspec, |
| 106 | + ) |
| 107 | + nodes_to_mark_annotated = [node, weight_node] |
| 108 | + if bias_node: |
| 109 | + annotate_input_qspec_map( |
| 110 | + node, |
| 111 | + bias_node, |
| 112 | + bias_qspec, |
| 113 | + ) |
| 114 | + nodes_to_mark_annotated.append(bias_node) |
| 115 | + annotate_output_qspec(node, output_act_qspec) |
| 116 | + _mark_nodes_as_annotated(nodes_to_mark_annotated) |
| 117 | + annotated_partitions.append(nodes_to_mark_annotated) |
| 118 | + |
| 119 | + return annotated_partitions |
| 120 | + |
| 121 | + |
| 122 | +def _is_share_obs_or_fq_op(op: Callable[..., torch.Tensor]) -> bool: |
| 123 | + return op in [ |
| 124 | + torch.ops.aten.relu.default, |
| 125 | + torch.ops.aten.hardtanh.default, |
| 126 | + torch.ops.aten.hardtanh_.default, |
| 127 | + torch.ops.aten.max_pool2d.default, |
| 128 | + torch.ops.aten.mean.default, |
| 129 | + torch.ops.aten.mean.dim, |
| 130 | + torch.ops.aten.permute.default, |
| 131 | + torch.ops.aten.permute_copy.default, |
| 132 | + torch.ops.aten.squeeze.dim, |
| 133 | + torch.ops.aten.squeeze_copy.dim, |
| 134 | + torch.ops.aten.adaptive_avg_pool2d.default, |
| 135 | + torch.ops.aten.view_copy.default, |
| 136 | + torch.ops.aten.view.default, |
| 137 | + torch.ops.aten.slice_copy.Tensor, |
| 138 | + torch.ops.aten.flatten.using_ints, |
| 139 | + ] |
| 140 | + |
| 141 | + |
| 142 | +def propagate_annotation(model: torch.fx.GraphModule) -> None: |
| 143 | + for n in model.graph.nodes: |
| 144 | + if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target): |
| 145 | + continue |
| 146 | + |
| 147 | + prev_node = n.args[0] |
| 148 | + if not isinstance(prev_node, Node): |
| 149 | + continue |
| 150 | + |
| 151 | + quantization_annotation = prev_node.meta.get("quantization_annotation", None) |
| 152 | + if not quantization_annotation: |
| 153 | + continue |
| 154 | + |
| 155 | + output_qspec = quantization_annotation.output_qspec |
| 156 | + if not output_qspec: |
| 157 | + continue |
| 158 | + |
| 159 | + # make sure current node is not annotated |
| 160 | + if ( |
| 161 | + "quantization_annotation" in n.meta |
| 162 | + and n.meta["quantization_annotation"]._annotated |
| 163 | + ): |
| 164 | + continue |
| 165 | + |
| 166 | + shared_qspec = SharedQuantizationSpec(prev_node) |
| 167 | + # propagate the previous output_qspec to the current node |
| 168 | + n.meta["quantization_annotation"] = QuantizationAnnotation( |
| 169 | + input_qspec_map={ |
| 170 | + prev_node: shared_qspec, |
| 171 | + }, |
| 172 | + output_qspec=shared_qspec, |
| 173 | + _annotated=True, |
| 174 | + ) |
| 175 | + |
| 176 | + |
| 177 | +def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 178 | + for n in model.graph.nodes: |
| 179 | + if n.op != "call_function" or n.target not in [ |
| 180 | + torch.ops.aten.add.Tensor, |
| 181 | + torch.ops.aten.mul.Tensor, |
| 182 | + ]: |
| 183 | + continue |
| 184 | + args = list(n.args) |
| 185 | + new_args = [] |
| 186 | + for i in range(len(args)): |
| 187 | + if isinstance(args[i], torch.fx.Node): |
| 188 | + new_args.append(args[i]) |
| 189 | + continue |
| 190 | + prefix = "_tensor_constant_" |
| 191 | + get_new_attr_name = get_new_attr_name_with_prefix(prefix) |
| 192 | + tensor_constant_name = get_new_attr_name(model) |
| 193 | + float_tensor = torch.tensor(float(args[i])) |
| 194 | + model.register_buffer(tensor_constant_name, float_tensor) |
| 195 | + fake_mode = n.meta["val"].fake_mode |
| 196 | + with model.graph.inserting_before(n): |
| 197 | + get_attr_node = model.graph.create_node( |
| 198 | + "get_attr", tensor_constant_name, (), {} |
| 199 | + ) |
| 200 | + get_attr_node.meta["val"] = fake_mode.from_tensor( |
| 201 | + float_tensor, static_shapes=True |
| 202 | + ) |
| 203 | + new_args.append(get_attr_node) |
| 204 | + n.args = tuple(new_args) |
| 205 | + model.recompile() |
| 206 | + return model |
0 commit comments