@@ -89,19 +89,39 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
+            # Assumed scale shape - [out_channels, in_channels / group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2D weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
+            # Prefer per_channel over per_channel_group when group_size == input_channels, for non-int4 cases only.
+            # The int4 case needs more fixes to map qb4w to qc4w; incorrect scales are being passed down to XNNPACK.
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )
+
+            if not self.per_channel_group:
+                if cast(torch.Tensor, scale).ndim == 2:
+                    # TODO: don't reshape scale for per_channel cases
+                    assert (
+                        cast(torch.Tensor, scale).shape[1] == 1
+                    ), "Invalid scale shape for per channel quantization"
+                    scale = cast(torch.Tensor, scale).squeeze(1)
 
-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
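For intuition on the shape check and the new `per_channel_group` decision above, here is a minimal standalone sketch; the shapes, the `weight`/`scale` variables, and the `is_qc4w` flag are illustrative values, not taken from the actual `QuantParams` constructor.

```python
import torch

# Hypothetical example: a 2D weight of shape [out_channels, in_channels].
out_channels, in_channels, group_size = 8, 64, 32
weight = torch.randn(out_channels, in_channels)

# Groupwise scales are expected as [out_channels, in_channels / group_size].
scale = torch.rand(out_channels, in_channels // group_size)  # shape (8, 2)

# Mirrors the assertion in the diff: scale.shape[1] * group_size must equal in_channels.
input_channels = scale.shape[1] * group_size
assert weight.shape[1] == input_channels, "Invalid input channels for groupwise quant"

# Mirrors the per_channel_group decision: when one group spans the whole row
# (group_size == input_channels), plain per-channel quantization is preferred,
# except in the int4 (qc4w) case.
is_qc4w = False  # illustrative flag
per_channel_group = (
    group_size <= input_channels if is_qc4w else group_size < input_channels
)
print(per_channel_group)  # True here: group_size (32) < input_channels (64)
```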
@@ -110,6 +130,39 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"
 
+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
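As a usage sketch for the new `__str__`: the `log_qparams` helper, logger setup, and example output below are illustrative and not part of this PR, but they show how the string form can be dropped directly into debug logging.

```python
import logging

logger = logging.getLogger(__name__)


def log_qparams(qparams) -> None:
    # `qparams` is assumed to be a QuantParams instance obtained during lowering.
    # With __str__ defined, formatting it yields a one-line summary roughly like:
    # "QuantParams(per_channel=True, per_channel_group=False, scale=tensor(8,), zp=tensor(8,), ...)"
    logger.debug("Quantizing weight with %s", qparams)
```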