Skip to content

Commit 44511ba

Browse files
Arm backend: Fix Mypy error related to _QuantProperty.qspec (pytorch#14814)
Previously, _QuantProperty.qspec had the type hint `type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]`, which implies that _QuantProperty.qspec should be a class object. However, in torchao the class `QuantizationAnnotation` has this property: `output_qspec: Optional[QuantizationSpecBase] = None` which is set to _QuantProperty.qspec through a series of function calls. `output_qspec` should, as the type hinting implies, be an instance of a class of `QuantizationSpecBase`, not a class object. Therefore, change `type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]` to `QuantizationSpecBase | List[QuantizationSpecBase]`. This allows us to remove a bunch of mypy ignores. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 4f38afb commit 44511ba

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

backends/arm/quantizer/quantization_annotator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class _QuantProperty:
3737
"""Specify how the input/output at 'index' must be quantized."""
3838

3939
index: int
40-
qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]
40+
qspec: QuantizationSpecBase | List[QuantizationSpecBase]
4141
optional: bool = False
4242
mark_annotated: bool = False
4343

@@ -515,24 +515,24 @@ def any_or_hardtanh_min_zero(n: Node):
515515
_QuantProperty(0, input_act_qspec),
516516
_QuantProperty(
517517
1,
518-
input_act_qspec if node.args[0] == node.args[1] else shared_qspec, # type: ignore[arg-type]
518+
input_act_qspec if node.args[0] == node.args[1] else shared_qspec,
519519
),
520520
]
521-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
521+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
522522
elif node.target in (torch.ops.aten.where.self,):
523523
shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type]
524524
quant_properties.quant_inputs = [
525-
_QuantProperty(1, shared_qspec), # type: ignore[arg-type]
526-
_QuantProperty(2, shared_qspec), # type: ignore[arg-type]
525+
_QuantProperty(1, shared_qspec),
526+
_QuantProperty(2, shared_qspec),
527527
]
528-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
528+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
529529
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
530530
input_qspec = (
531531
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
532-
if is_output_annotated(node.args[0]) # type: ignore
532+
if is_output_annotated(node.args[0]) # type: ignore[arg-type]
533533
else input_act_qspec
534534
)
535-
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type]
535+
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
536536
quant_properties.quant_output = _QuantProperty(
537537
0,
538538
SharedQuantizationSpec((node.args[0], node)), # type: ignore[arg-type]
@@ -551,7 +551,7 @@ def any_or_hardtanh_min_zero(n: Node):
551551
if len(node.args[0]) == 0:
552552
raise ValueError("Expected non-empty list for node.args[0]")
553553

554-
shared_qspec = SharedQuantizationSpec((node.args[0][0], node))
554+
shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type]
555555
quant_properties.quant_inputs = [
556556
_QuantProperty(
557557
0,
@@ -561,7 +561,7 @@ def any_or_hardtanh_min_zero(n: Node):
561561
],
562562
)
563563
]
564-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
564+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
565565
elif node.target in _one_to_one:
566566
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
567567
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
@@ -583,7 +583,7 @@ def any_or_hardtanh_min_zero(n: Node):
583583
_QuantProperty(0, input_act_qspec),
584584
_QuantProperty(
585585
1,
586-
input_act_qspec if node.args[0] == node.args[1] else shared_qspec, # type: ignore[arg-type]
586+
input_act_qspec if node.args[0] == node.args[1] else shared_qspec,
587587
),
588588
]
589589
quant_properties.quant_output = None
@@ -596,11 +596,11 @@ def any_or_hardtanh_min_zero(n: Node):
596596
quant_properties.quant_inputs = []
597597
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
598598
elif node.target in [operator.getitem]:
599-
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type]
599+
if not is_output_annotated(node.args[0]): # type: ignore[arg-type]
600600
return None
601601
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
602-
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type]
603-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
602+
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
603+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
604604
else:
605605
return None
606606

0 commit comments

Comments
 (0)