Skip to content

Commit b4d664e

Browse files
Arm backend: Fix Mypy error related to _QuantProperty.qspec
Previously, _QuantProperty.qspec had the type hint `type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]`, which implies that _QuantProperty.qspec should be a class object. However, in torchao the class `QuantizationAnnotation` has this property: `output_qspec: Optional[QuantizationSpecBase] = None` which is set to _QuantProperty.qspec through a series of function calls. `output_qspec` should, as the type hinting implies, be an instance of a class of `QuantizationSpecBase`, not a class object. Therefore, change `type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]` to `QuantizationSpecBase | List[QuantizationSpecBase]`. This allows us to remove a bunch of mypy ignores. Change-Id: Idbb5ca4ba9ab17e8805b1e4d647e46e86f434b69 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 3b16bc1 commit b4d664e

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

backends/arm/quantizer/quantization_annotator.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class _QuantProperty:
3737
"""Specify how the input/output at 'index' must be quantized."""
3838

3939
index: int
40-
qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]
40+
qspec: QuantizationSpecBase | List[QuantizationSpecBase]
4141
optional: bool = False
4242
mark_annotated: bool = False
4343

@@ -510,24 +510,24 @@ def any_or_hardtanh_min_zero(n: Node):
510510
quant_properties.quant_inputs = [
511511
_QuantProperty(0, input_act_qspec),
512512
_QuantProperty(
513-
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type]
513+
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
514514
),
515515
]
516-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
516+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
517517
elif node.target in (torch.ops.aten.where.self,):
518518
shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type]
519519
quant_properties.quant_inputs = [
520-
_QuantProperty(1, shared_qspec), # type: ignore[arg-type]
521-
_QuantProperty(2, shared_qspec), # type: ignore[arg-type]
520+
_QuantProperty(1, shared_qspec),
521+
_QuantProperty(2, shared_qspec),
522522
]
523-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
523+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
524524
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
525525
input_qspec = (
526526
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
527-
if is_output_annotated(node.args[0]) # type: ignore
527+
if is_output_annotated(node.args[0]) # type: ignore[arg-type]
528528
else input_act_qspec
529529
)
530-
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type]
530+
quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)]
531531
quant_properties.quant_output = _QuantProperty(
532532
0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type]
533533
)
@@ -545,7 +545,7 @@ def any_or_hardtanh_min_zero(n: Node):
545545
if len(node.args[0]) == 0:
546546
raise ValueError("Expected non-empty list for node.args[0]")
547547

548-
shared_qspec = SharedQuantizationSpec((node.args[0][0], node))
548+
shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type]
549549
quant_properties.quant_inputs = [
550550
_QuantProperty(
551551
0,
@@ -555,7 +555,7 @@ def any_or_hardtanh_min_zero(n: Node):
555555
],
556556
)
557557
]
558-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
558+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
559559
elif node.target in _one_to_one:
560560
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
561561
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
@@ -575,7 +575,7 @@ def any_or_hardtanh_min_zero(n: Node):
575575
quant_properties.quant_inputs = [
576576
_QuantProperty(0, input_act_qspec),
577577
_QuantProperty(
578-
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type]
578+
1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec
579579
),
580580
]
581581
quant_properties.quant_output = None
@@ -588,11 +588,11 @@ def any_or_hardtanh_min_zero(n: Node):
588588
quant_properties.quant_inputs = []
589589
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
590590
elif node.target in [operator.getitem]:
591-
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type]
591+
if not is_output_annotated(node.args[0]): # type: ignore[arg-type]
592592
return None
593593
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
594-
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type]
595-
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
594+
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]
595+
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
596596
else:
597597
return None
598598

0 commit comments

Comments
 (0)