diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index d7c85447dd5..ebc91c22bbb 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -6,7 +6,7 @@
 import logging
 import operator
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Sequence
+from typing import Callable, cast, List, Optional, Sequence
 
 import torch
 import torch.fx
@@ -137,11 +137,18 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
     node since histc op (in HistogramObserver) only works for values up to
     certain upper bound.
     """
+    HISTC_UPPER_BOUND = 3.4028235e15
     if node.op == "get_attr" and isinstance(node.target, str):
         tensor = _get_node_target(gm, node.target)
         # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
         return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    if node.op == "call_function" and node.target in (
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ):
+        fill_value = cast(float, node.args[1])
+        return abs(fill_value) > HISTC_UPPER_BOUND
     return False
 
 
@@ -358,9 +365,6 @@ def _match_pattern(
     torch.ops.aten.permute_copy.default,
     torch.ops.aten.avg_pool2d.default,
     torch.ops.aten.max_pool2d.default,
-    torch.ops.aten.full.default,
-    torch.ops.aten.full,
-    torch.ops.aten.fill_.Scalar,
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
@@ -518,9 +522,6 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target in _one_to_one_shared_input_or_input_act_qspec:
-        if not isinstance(node.args[0], Node):
-            return None
-
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
             if is_output_annotated(node.args[0])  # type: ignore
@@ -578,7 +579,12 @@ def any_or_hardtanh_min_zero(n: Node):
             ),
         ]
         quant_properties.quant_output = None
-    elif node.target in [torch.ops.aten.scalar_tensor.default]:
+    elif node.target in [
+        torch.ops.aten.scalar_tensor.default,
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in [operator.getitem]: