backends/arm/quantizer/quantization_annotator.py (24 changes: 15 additions & 9 deletions)
@@ -6,7 +6,7 @@
 import logging
 import operator
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Sequence
+from typing import Callable, cast, List, Optional, Sequence
 
 import torch
 import torch.fx
@@ -137,11 +137,18 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
     node since histc op (in HistogramObserver) only works for values up to certain upper
     bound.
     """
+    HISTC_UPPER_BOUND = 3.4028235e15
     if node.op == "get_attr" and isinstance(node.target, str):
         tensor = _get_node_target(gm, node.target)
         # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
         return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    if node.op == "call_function" and node.target in (
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ):
+        fill_value = cast(float, node.args[1])
+        return abs(fill_value) > HISTC_UPPER_BOUND
     return False


@@ -358,9 +365,6 @@ def _match_pattern(
     torch.ops.aten.permute_copy.default,
     torch.ops.aten.avg_pool2d.default,
     torch.ops.aten.max_pool2d.default,
-    torch.ops.aten.full.default,
-    torch.ops.aten.full,
-    torch.ops.aten.fill_.Scalar,
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
@@ -518,9 +522,6 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target in _one_to_one_shared_input_or_input_act_qspec:
-        if not isinstance(node.args[0], Node):
-            return None
-
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
             if is_output_annotated(node.args[0])  # type: ignore
@@ -578,7 +579,12 @@
             ),
         ]
         quant_properties.quant_output = None
-    elif node.target in [torch.ops.aten.scalar_tensor.default]:
+    elif node.target in [
+        torch.ops.aten.scalar_tensor.default,
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in [operator.getitem]:
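Note on the new `call_function` branch: constant-producing ops (`aten.full`, `aten.fill_`) carry their scalar as `args[1]`, and a value above the histc bound is now treated as a large scalar and excluded from quantization, since HistogramObserver relies on `torch.histc`. A minimal sketch of the check in isolation follows; the helper name `is_large_fill_value` is illustrative only and not part of this PR.

```python
import torch

# torch.histc, which HistogramObserver uses internally, is only reliable for
# values up to this bound, so larger constants are skipped during annotation.
HISTC_UPPER_BOUND = 3.4028235e15


def is_large_fill_value(fill_value: float) -> bool:
    """Hypothetical standalone version of the new call_function check."""
    return abs(fill_value) > HISTC_UPPER_BOUND


# A node like torch.ops.aten.full.default((2, 2), 3.0e38) carries its fill
# value as args[1]; values past the bound are excluded from quantization.
print(is_large_fill_value(1.0))     # False: ordinary constant, safe to observe
print(is_large_fill_value(3.0e38))  # True: would be skipped by _is_large_scalar
```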