 # Utility functions for TOSAQuantizer
 #
 
-from typing import cast, Sequence
+from typing import cast
 
-import torch
-from torch._subclasses import FakeTensor
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
@@ -45,62 +43,3 @@ def mark_node_as_annotated(node: Node) -> None:
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
     node.meta[Q_ANNOTATION_KEY]._annotated = True
-
-
-def is_ok_for_quantization(node: Node, gm: GraphModule):
-    """Check if an node can be quantized. The node can not be quantized if:
-    - The node does not output a float tensor or,
-    - The node outputs a large scalar.
-    """
-    return not (is_non_float_tensor(node) or is_large_scalar(node, gm))
-
-
-def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
-    targets = target_str.split(".")
-    for target in targets[:-1]:
-        module = module.get_submodule(target)
-    return getattr(module, targets[-1])
-
-
-def is_large_scalar(node: Node, gm: GraphModule):
-    """Check if input is a large scalar value. So that we can skip quantization for the node
-    since histc op (in HistogramObserver) only works for values up to certain upper bound
-    """
-    if node.op == "get_attr" and isinstance(node.target, str):
-        tensor = get_node_target(gm, node.target)
-        # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
-        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
-    return False
-
-
-def is_non_float_tensor(node: Node) -> bool:
-    """Check if the output of a node has a data type other than `torch.float32`.
-
-    If the output is not `torch.float32`, quantization cannot be performed, as
-    observers only work with floating-point tensors.
-
-    Args:
-        node (Node): The node to check the output(s) for.
-
-    Returns:
-        bool: `True` if the data type is not float32, otherwise `False`.
-
-    Note:
-        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
-          element is **not** an instance of `FakeTensor` or does **not** have
-          `torch.float32` as its data type.
-        - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
-          function returns True.
-    """
-    if "val" in node.meta and isinstance(node.meta["val"], Sequence):
-        return any(
-            not isinstance(fake_tensor, FakeTensor)
-            or fake_tensor.dtype != torch.float32
-            for fake_tensor in node.meta["val"]
-        )
-
-    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
-        return True
-
-    return node.meta["val"].dtype != torch.float32
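
For context, here is a minimal usage sketch of the `mark_node_as_annotated` helper that this patch keeps. It is not part of the patch: the toy `AddOne` module and the tracing boilerplate are illustrative assumptions, while `Q_ANNOTATION_KEY` and `QuantizationAnnotation` come from the imports retained above.

```python
# Illustrative sketch only, not from this patch: flag every call_function
# node in a traced toy module as annotated, the way an annotation pass would
# use mark_node_as_annotated to record nodes it has already handled.
import torch
from torch.fx import Node, symbolic_trace

from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


def mark_node_as_annotated(node: Node) -> None:
    # Same helper the diff above keeps in place.
    if Q_ANNOTATION_KEY not in node.meta:
        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
    node.meta[Q_ANNOTATION_KEY]._annotated = True


class AddOne(torch.nn.Module):  # hypothetical toy module for the sketch
    def forward(self, x):
        return x + 1.0


gm = symbolic_trace(AddOne())
for node in gm.graph.nodes:
    if node.op == "call_function":
        mark_node_as_annotated(node)
        assert node.meta[Q_ANNOTATION_KEY]._annotated
```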