14
14
15
15
import warnings
16
16
from enum import Enum
17
- from typing import Any , Dict , Optional , Union
17
+ from typing import Any , Dict , List , Optional , Union
18
18
19
19
import torch
20
20
from compressed_tensors .utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
153
153
:param symmetric: whether or not quantization scale is symmetric about zero-point
154
154
:param strategy: string id determining the scope of scale/zero-point to apply
155
155
:param group_size: group length to use for the group strategy
156
- :param block_structure: 2d block structure to use for the block strategy, must be
157
- of the format "2x4", "8x16", etc .
156
+ :param block_structure: 2d block structure to use for the block strategy; must be
157
+ a list of two ints [rows, cols] like [128, 128] .
158
158
:param dynamic: set True to perform dynamic quantization - values will not be
159
159
calibrated during calibration phase, instead during inference new quantization
160
160
ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
169
169
symmetric : bool = True
170
170
group_size : Optional [int ] = None
171
171
strategy : Optional [QuantizationStrategy ] = None
172
- block_structure : Optional [str ] = None
172
+ block_structure : Optional [List [ int ] ] = None
173
173
dynamic : Union [DynamicType , bool ] = False
174
174
actorder : Union [ActivationOrdering , bool , None ] = None
175
175
observer : Optional [str ] = Field (
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]:
207
207
208
208
return value
209
209
210
+ @field_validator ("block_structure" , mode = "before" )
211
+ def validate_block_structure (cls , value ) -> Optional [List [int ]]:
212
+ if value is None :
213
+ return value
214
+ # For backward compatibility, allow string format "2x4", "8x16", etc.
215
+ if isinstance (value , str ):
216
+ try :
217
+ return [int (x ) for x in value .split ("x" )]
218
+ except Exception :
219
+ raise ValueError (
220
+ f"Invalid block_structure '{ value } '. Must be a list of two ints [rows, cols]."
221
+ )
222
+ if isinstance (value , (list , tuple )):
223
+ if len (value ) != 2 or not all (isinstance (v , int ) for v in value ):
224
+ raise ValueError (
225
+ f"Invalid block_structure '{ value } '. Must be a list of two ints [rows, cols]."
226
+ )
227
+ return list (value )
228
+ raise ValueError (
229
+ f"Invalid block_structure '{ value } '. Must be a list of two ints [rows, cols]."
230
+ )
231
+
210
232
@field_validator ("strategy" , mode = "before" )
211
233
def validate_strategy (cls , value ) -> Union [QuantizationStrategy , None ]:
212
234
if isinstance (value , str ):
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
277
299
278
300
# infer observer w.r.t. dynamic
279
301
if dynamic :
280
- if strategy not in (
302
+ supported_strategies = (
281
303
QuantizationStrategy .TOKEN ,
282
304
QuantizationStrategy .TENSOR ,
283
305
QuantizationStrategy .TENSOR_GROUP ,
284
- ):
306
+ QuantizationStrategy .GROUP ,
307
+ )
308
+ if strategy not in supported_strategies :
285
309
raise ValueError (
286
- f"One of { (QuantizationStrategy .TOKEN , QuantizationStrategy .TENSOR , QuantizationStrategy .TENSOR_GROUP )} "
287
- "must be used for dynamic quantization" ,
310
+ f"One of { supported_strategies } must be used for dynamic quantization"
288
311
)
289
312
290
313
if (
0 commit comments