Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class _QuantProperty:
"""Specify how the input/output at 'index' must be quantized."""

index: int
qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]
qspec: QuantizationSpecBase | List[QuantizationSpecBase]
optional: bool = False
mark_annotated: bool = False

Expand Down Expand Up @@ -510,24 +510,24 @@ def any_or_hardtanh_min_zero(n: Node):
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type]
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
),
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
elif node.target in (torch.ops.aten.where.self,):
shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type]
quant_properties.quant_inputs = [
_QuantProperty(1, shared_qspec), # type: ignore[arg-type]
_QuantProperty(2, shared_qspec), # type: ignore[arg-type]
_QuantProperty(1, shared_qspec),
_QuantProperty(2, shared_qspec),
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
input_qspec = (
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
if is_output_annotated(node.args[0]) # type: ignore
if is_output_annotated(node.args[0]) # type: ignore[arg-type]
else input_act_qspec
)
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type]
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
quant_properties.quant_output = _QuantProperty(
0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
)
Expand All @@ -545,7 +545,7 @@ def any_or_hardtanh_min_zero(n: Node):
if len(node.args[0]) == 0:
raise ValueError("Expected non-empty list for node.args[0]")

shared_qspec = SharedQuantizationSpec((node.args[0][0], node))
shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type]
quant_properties.quant_inputs = [
_QuantProperty(
0,
Expand All @@ -555,7 +555,7 @@ def any_or_hardtanh_min_zero(n: Node):
],
)
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
elif node.target in _one_to_one:
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
Expand All @@ -575,7 +575,7 @@ def any_or_hardtanh_min_zero(n: Node):
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type]
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
),
]
quant_properties.quant_output = None
Expand All @@ -588,11 +588,11 @@ def any_or_hardtanh_min_zero(n: Node):
quant_properties.quant_inputs = []
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in [operator.getitem]:
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type]
if not is_output_annotated(node.args[0]): # type: ignore[arg-type]
return None
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type]
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
else:
return None

Expand Down
Loading