33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6+ """Provide quantization configuration helpers for the Arm backend.
7+
8+ Define a small dataclass to carry activation/weight/bias specs and helper
9+ accessors that validate specs before use. Use this module to build and validate
10+ quantization specs consumed by the annotator.
11+
12+ """
613
714# pyre-unsafe
815
1926
2027@dataclass (eq = True , frozen = True )
2128class QuantizationConfig :
29+ """Provide a container for quantization specs.
30+
31+ Hold optional specs for input/output activations, weights, and bias, and
32+ expose validated accessors.
33+
34+ Attributes:
35+ input_activation (QuantizationSpec | None): Spec for input activations.
36+ output_activation (QuantizationSpec | None): Spec for output activations.
37+ weight (QuantizationSpec | None): Spec for weights.
38+ bias (QuantizationSpec | None): Spec for bias values.
39+
40+ """
41+
2242 input_activation : QuantizationSpec | None
2343 output_activation : QuantizationSpec | None
2444 weight : QuantizationSpec | None
2545 bias : QuantizationSpec | None
2646
2747 def get_input_act_qspec (self ) -> QuantizationSpec | None :
28- """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
48+ """Get the validated input activation spec.
49+
50+ Validate that the input activation qscheme is supported before
51+ returning the spec.
52+
53+ Returns:
54+ QuantizationSpec | None: Input activation spec, or ``None`` when
55+ unset.
56+
57+ Raises:
58+ ValueError: If the qscheme is not per-tensor affine or symmetric.
59+
60+ """
2961 if self .input_activation is None :
3062 return None
3163 # Validate that input_activation uses a supported qscheme
@@ -39,7 +71,19 @@ def get_input_act_qspec(self) -> QuantizationSpec | None:
3971 return self .input_activation
4072
4173 def get_output_act_qspec (self ) -> QuantizationSpec | None :
42- """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
74+ """Get the validated output activation spec.
75+
76+ Validate that the output activation qscheme is supported before
77+ returning the spec.
78+
79+ Returns:
80+ QuantizationSpec | None: Output activation spec, or ``None`` when
81+ unset.
82+
83+ Raises:
84+ ValueError: If the qscheme is not per-tensor affine or symmetric.
85+
86+ """
4387 if self .output_activation is None :
4488 return None
4589 # Validate that output_activation uses a supported qscheme
@@ -53,7 +97,18 @@ def get_output_act_qspec(self) -> QuantizationSpec | None:
5397 return self .output_activation
5498
5599 def get_weight_qspec (self ) -> QuantizationSpec | None :
56- """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
100+ """Get the validated weight spec.
101+
102+ Validate that the weight qscheme is supported (per-tensor or
103+ per-channel symmetric) before returning the spec.
104+
105+ Returns:
106+ QuantizationSpec | None: Weight spec, or ``None`` when unset.
107+
108+ Raises:
109+ ValueError: If the qscheme is not a supported symmetric scheme.
110+
111+ """
57112 if self .weight is None :
58113 return None
59114 # Validate that weight uses a supported qscheme
@@ -65,11 +120,46 @@ def get_weight_qspec(self) -> QuantizationSpec | None:
65120 return self .weight
66121
67122 def get_bias_qspec (self , node : torch .fx .Node ) -> QuantizationSpec | None :
68- """Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
123+ """Get the derived or validated bias spec.
124+
125+ For conv/linear ops, derive bias qparams from the input/weight observers.
126+ Otherwise, validate a user-provided floating-point bias spec.
127+
128+ Args:
129+ node (torch.fx.Node): Node whose bias spec is requested.
130+
131+ Returns:
132+ QuantizationSpec | None: Derived or provided bias spec, or ``None``
133+ when unset.
134+
135+ Raises:
136+ ValueError: If deriving qparams sees an unexpected number of
137+ observers/fake-quantizers, or if a provided bias dtype is not
138+ floating-point.
139+
140+ """
69141
70142 def _derive_qparams_fn (
71143 obs_or_fqs : list [ObserverOrFakeQuantize ],
72144 ) -> tuple [torch .Tensor , torch .Tensor ]:
145+ """Compute bias scale/zero-point from activation/weight observers.
146+
147+ Expect two observers or fake-quantize modules: one for the input
148+ activation and one for the weight. The bias scale is the product of
149+ input and weight scales, and the zero-point is a tensor of zeros.
150+
151+ Args:
152+ obs_or_fqs (list[ObserverOrFakeQuantize]): Observers/fake-quant
153+ in order ``[act, weight]``.
154+
155+ Returns:
156+ Tuple[torch.Tensor, torch.Tensor]: Bias scale tensor and
157+ integer zero-point tensor.
158+
159+ Raises:
160+ ValueError: If the list does not contain exactly two items.
161+
162+ """
73163 # Validate expected number of observers/fake-quantizes
74164 if len (obs_or_fqs ) != 2 :
75165 raise ValueError (
0 commit comments