Commit 871fe39

Arm backend: Update full quantization annotation (#14585)
full, full.default, and fill_.Scalar were previously part of _one_to_one_shared_input_or_input_act_qspec despite having no input nodes. As a result, these nodes were never annotated themselves and relied solely on the next node to annotate their output as its input. This patch annotates full, full.default, and fill_.Scalar in the same way as scalar_tensor.default, and also adds these targets to _is_large_scalar().

Signed-off-by: Oscar Andersson <[email protected]>
1 parent 649f92d commit 871fe39
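For context (not part of the commit): a minimal sketch of why full-style ops never benefited from the shared-input annotation path. When traced into an FX graph, torch.full is built from a size tuple and a scalar fill value only, so there is no producer Node among its arguments for a SharedQuantizationSpec to attach to. FullModule below is a made-up example, traced with plain torch.fx rather than the ExecuTorch flow.

import torch
import torch.fx


class FullModule(torch.nn.Module):
    def forward(self, x):
        ones = torch.full((2, 2), 1.5)  # built from literals only, no tensor inputs
        return x + ones


gm = torch.fx.symbolic_trace(FullModule())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.full:
        # args == ((2, 2), 1.5): no torch.fx.Node among them, so there is no
        # upstream producer whose output qspec could be shared with this node.
        print(node.name, node.args)
        print(any(isinstance(a, torch.fx.Node) for a in node.args))  # False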

File tree

1 file changed (+15, -9 lines)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 15 additions & 9 deletions
@@ -6,7 +6,7 @@
 import logging
 import operator
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Sequence
+from typing import Callable, cast, List, Optional, Sequence
 
 import torch
 import torch.fx
@@ -137,11 +137,18 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
     node since histc op (in HistogramObserver) only works for values up to certain upper
     bound.
     """
+    HISTC_UPPER_BOUND = 3.4028235e15
     if node.op == "get_attr" and isinstance(node.target, str):
         tensor = _get_node_target(gm, node.target)
         # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
         return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    if node.op == "call_function" and node.target in (
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ):
+        fill_value = cast(float, node.args[1])
+        return abs(fill_value) > HISTC_UPPER_BOUND
     return False
 
 
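The hunk above reads as: the histc bound now also guards fill values passed directly to full/fill_ calls, since those values never live in a get_attr tensor. Below is a rough standalone mirror of that check (a hypothetical helper, not the ExecuTorch function itself); torch.full stands in for the aten overloads because plain torch.fx tracing records the Python-level op.

import torch
import torch.fx

HISTC_UPPER_BOUND = 3.4028235e15  # bound quoted in the patch


def is_large_fill(node: torch.fx.Node) -> bool:
    # For a full-style call the scalar fill value sits at args[1], so the guard
    # can read it directly instead of fetching a get_attr tensor from the module.
    if node.op == "call_function" and node.target is torch.full:
        return abs(float(node.args[1])) > HISTC_UPPER_BOUND
    return False


class Huge(torch.nn.Module):
    def forward(self, x):
        return x + torch.full((2, 2), 4.0e15)


gm = torch.fx.symbolic_trace(Huge())
print([is_large_fill(n) for n in gm.graph.nodes if n.op == "call_function"])
# Expected: [True, False] -> the full node is flagged, the add is not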

@@ -358,9 +365,6 @@ def _match_pattern(
     torch.ops.aten.permute_copy.default,
     torch.ops.aten.avg_pool2d.default,
     torch.ops.aten.max_pool2d.default,
-    torch.ops.aten.full.default,
-    torch.ops.aten.full,
-    torch.ops.aten.fill_.Scalar,
     torch.ops.aten.flatten.using_ints,
     torch.ops.aten.dropout.default,
     torch.ops.aten.dropout_.default,
@@ -518,9 +522,6 @@ def any_or_hardtanh_min_zero(n: Node):
         ]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
     elif node.target in _one_to_one_shared_input_or_input_act_qspec:
-        if not isinstance(node.args[0], Node):
-            return None
-
         input_qspec = (
             SharedQuantizationSpec(node.args[0])  # type: ignore[arg-type]
             if is_output_annotated(node.args[0])  # type: ignore
@@ -578,7 +579,12 @@ def any_or_hardtanh_min_zero(n: Node):
             ),
         ]
         quant_properties.quant_output = None
-    elif node.target in [torch.ops.aten.scalar_tensor.default]:
+    elif node.target in [
+        torch.ops.aten.scalar_tensor.default,
+        torch.ops.aten.full.default,
+        torch.ops.aten.full,
+        torch.ops.aten.fill_.Scalar,
+    ]:
         quant_properties.quant_inputs = []
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif node.target in [operator.getitem]:
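For a sense of what the new branch produces, here is a rough sketch built on the generic PT2E quantizer classes that the Arm quantizer layers on top of (importable from torch.ao.quantization.quantizer in current PyTorch; newer stacks relocate them to torchao). The observer choice and quantization ranges are illustrative placeholders, not the Arm defaults.

import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationAnnotation, QuantizationSpec

# Illustrative int8 output spec; the real output_act_qspec comes from the Arm
# quantization config and will generally use different settings.
output_act_qspec = QuantizationSpec(
    dtype=torch.int8,
    observer_or_fake_quant_ctr=HistogramObserver,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
)

# For a node like aten.full.default the patch sets quant_inputs = [] and a
# single output property: nothing is annotated on the input side, and only
# the produced tensor carries a quantization spec, as for scalar_tensor.default.
annotation = QuantizationAnnotation(
    input_qspec_map={},
    output_qspec=output_act_qspec,
    _annotated=True,
)
print(annotation)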
