@@ -89,19 +89,39 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
+            # Assumed scale shape - [out_channels, in_channels/group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2D weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
+            # Prefer per_channel over per_channel_group when group_size == input_channels, for non-int4 cases only.
+            # The int4 case needs more fixes to map qb4w to qc4w; incorrect scales are being passed down to xnnpack.
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )
+
+            if not self.per_channel_group:
+                if cast(torch.Tensor, scale).ndim == 2:
+                    # TODO: don't reshape scale for per_channel cases
+                    assert (
+                        cast(torch.Tensor, scale).shape[1] == 1
+                    ), "Invalid scale shape for per channel quantization"
+                    scale = cast(torch.Tensor, scale).squeeze(1)
 
-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +130,39 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"
 
+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
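
For reference, here is a minimal standalone sketch (not part of the diff) of the shape arithmetic the new groupwise checks rely on. The `weight` and `scale` tensors, the sizes, and the local variable names are hypothetical; only the relationships mirror the diff.

```python
import torch

# Hypothetical 2D weight tensor: [out_channels, in_channels] = [8, 64]
weight = torch.randn(8, 64)
group_size = 16

# Groupwise scale shape assumed by the diff: [out_channels, in_channels/group_size] = [8, 4]
scale = torch.ones(weight.shape[0], weight.shape[1] // group_size)

# Recover the input-channel count from the scale shape, mirroring the new assert
input_channels = scale.shape[1] * group_size
assert weight.shape[1] == input_channels  # 4 * 16 == 64

# Decision rule from the diff: on the non-int4 path, fall back to plain
# per-channel quantization when a single group spans all input channels.
is_qc4w = False
per_channel_group = (
    group_size <= input_channels if is_qc4w else group_size < input_channels
)
print(per_channel_group)  # True here, since 16 < 64
```

With `group_size == 64` in this example, the non-int4 path would instead choose plain per-channel quantization and squeeze the 2D scale down to 1D, as the diff does.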