@@ -6,7 +6,7 @@
 import logging
 import operator
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Sequence
+from typing import Callable, cast, List, Optional, Sequence
 
 import torch
 import torch.fx
@@ -137,11 +137,18 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
     node since histc op (in HistogramObserver) only works for values up to certain upper
     bound.
     """
+    HISTC_UPPER_BOUND = 3.4028235e15
     if node.op == "get_attr" and isinstance(node.target, str):
         tensor = _get_node_target(gm, node.target)
         # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
         return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    if node.op == "call_function" and node.target in (
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ):
+        fill_value = cast(float, node.args[1])
+        return abs(fill_value) > HISTC_UPPER_BOUND
     return False
 
 
@@ -358,9 +365,6 @@ def _match_pattern(
     torch.ops.aten.permute_copy.default,
     torch.ops.aten.avg_pool2d.default,
     torch.ops.aten.max_pool2d.default,
-    torch.ops.aten.full.default,
-    torch.ops.aten.full,
-    torch.ops.aten.fill_.Scalar,
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
@@ -518,9 +522,6 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target in _one_to_one_shared_input_or_input_act_qspec:
-        if not isinstance(node.args[0], Node):
-            return None
-
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
             if is_output_annotated(node.args[0])  # type: ignore
@@ -578,7 +579,12 @@ def any_or_hardtanh_min_zero(n: Node):
             ),
         ]
         quant_properties.quant_output = None
-    elif node.target in [torch.ops.aten.scalar_tensor.default]:
+    elif node.target in [
+        torch.ops.aten.scalar_tensor.default,
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in [operator.getitem]:
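For reference, the new `call_function` branch of `_is_large_scalar` can be exercised in isolation. The sketch below re-implements the predicate outside of an FX graph; `FakeNode` and `is_large_fill_scalar` are hypothetical names used only for illustration and are not part of this patch.

```python
from dataclasses import dataclass
from typing import Any, Tuple

import torch

# Upper bound beyond which torch.histc (used by HistogramObserver) breaks down.
HISTC_UPPER_BOUND = 3.4028235e15


@dataclass
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node, just enough for this check."""

    op: str
    target: Any
    args: Tuple[Any, ...]


def is_large_fill_scalar(node: FakeNode) -> bool:
    # Mirrors the new branch: aten.full/fill_ carry the fill value as the
    # second positional argument (args[1]).
    if node.op == "call_function" and node.target in (
        torch.ops.aten.full.default,
        torch.ops.aten.full,
        torch.ops.aten.fill_.Scalar,
    ):
        return abs(float(node.args[1])) > HISTC_UPPER_BOUND
    return False


small = FakeNode("call_function", torch.ops.aten.full.default, ((2, 2), 1.0))
large = FakeNode("call_function", torch.ops.aten.full.default, ((2, 2), 1.0e16))
print(is_large_fill_scalar(small))  # False -> safe to observe with histc
print(is_large_fill_scalar(large))  # True  -> too large for HistogramObserver
```

Reading the fill value from `args[1]` matches the `aten::full(SymInt[] size, Scalar fill_value, ...)` and `aten::fill_.Scalar(Tensor self, Scalar value)` schemas, which is presumably why the patch can treat all three overloads uniformly.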
|