File tree Expand file tree Collapse file tree 3 files changed +36
-0
lines changed Expand file tree Collapse file tree 3 files changed +36
-0
lines changed Original file line number Diff line number Diff line change 54
54
GemliteUIntXWeightOnlyConfig ,
55
55
Int4DynamicActivationInt4WeightConfig ,
56
56
Int4WeightOnlyConfig ,
57
+ Int8DynActOnlyConfig ,
57
58
Int8DynamicActivationInt4WeightConfig ,
58
59
Int8DynamicActivationInt8WeightConfig ,
59
60
Int8DynamicActivationIntxWeightConfig ,
144
145
"Int8DynamicActivationIntxWeightConfig" ,
145
146
"Int4WeightOnlyConfig" ,
146
147
"Float8DynamicActivationInt4WeightConfig" ,
148
+ "Int8DynActOnlyConfig" ,
147
149
"Int8WeightOnlyConfig" ,
148
150
"Float8WeightOnlyConfig" ,
149
151
"Float8DynamicActivationFloat8WeightConfig" ,
Original file line number Diff line number Diff line change @@ -292,6 +292,7 @@ def _infer_fake_quantize_configs(
292
292
# avoid circular imports
293
293
from torchao .quantization import (
294
294
Int4WeightOnlyConfig ,
295
+ Int8DynActOnlyConfig ,
295
296
Int8DynamicActivationInt4WeightConfig ,
296
297
)
297
298
@@ -315,5 +316,12 @@ def _infer_fake_quantize_configs(
315
316
zero_point_domain = base_config .zero_point_domain ,
316
317
)
317
318
return (None , weight_config )
319
+ elif isinstance (base_config , Int8DynActOnlyConfig ):
320
+ act_config = IntxFakeQuantizeConfig (
321
+ dtype = torch .int8 ,
322
+ granularity = "per_token" ,
323
+ is_symmetric = base_config .is_symmetric ,
324
+ )
325
+ return (act_config , None )
318
326
else :
319
327
raise ValueError ("Unexpected base config: %s" % base_config )
Original file line number Diff line number Diff line change 148
148
"gemlite_uintx_weight_only" ,
149
149
"float8_dynamic_activation_float8_weight" ,
150
150
"float8_static_activation_float8_weight" ,
151
+ "Int8DynActOnlyConfig" ,
151
152
"Int8DynActInt4WeightQuantizer" ,
152
153
"Float8DynamicActivationFloat8SemiSparseWeightConfig" ,
153
154
"ModuleFqnToConfig" ,
@@ -1312,6 +1313,31 @@ def _float8_cutlass_quant_sparse(
1312
1313
)
1313
1314
1314
1315
1316
@dataclass
class Int8DynActOnlyConfig(AOBaseConfig):
    """
    Configuration for applying int8 dynamic per-token activation quantization
    to linear layers. Weights are left unquantized; only the activations are
    dynamically quantized at runtime.

    Args:
        is_symmetric: bool = False - Whether to use symmetric (reduced-range)
            quantization for activations. The default is asymmetric
            quantization.
    """

    is_symmetric: bool = False
1325
+
1326
+
1327
@register_quantize_module_handler(Int8DynActOnlyConfig)
def _int8_dynamic_activation_transform(
    module: torch.nn.Module, config: Int8DynActOnlyConfig
) -> torch.nn.Module:
    """Attach int8 dynamic per-token activation quantization to ``module``.

    The weight tensor itself is not quantized; it is wrapped with
    ``to_linear_activation_quantized`` so that activations are quantized
    dynamically on each forward pass.

    Args:
        module: the linear module whose ``weight`` is wrapped in place.
        config: selects symmetric (reduced-range) vs. asymmetric activation
            quantization via ``config.is_symmetric``.

    Returns:
        The same ``module``, with ``module.weight`` replaced by a frozen
        ``torch.nn.Parameter`` wrapping the activation-quantized weight.
    """
    weight = module.weight
    # Bug fix: the original compared the bool `config.is_symmetric` against
    # `MappingType.SYMMETRIC` (an enum member), which is always False, making
    # the symmetric branch unreachable. Test the boolean flag directly.
    if config.is_symmetric:
        input_quant_func = _int8_symm_per_token_reduced_range_quant
    else:
        input_quant_func = _int8_asymm_per_token_quant
    weight = to_linear_activation_quantized(weight, input_quant_func)
    module.weight = torch.nn.Parameter(weight, requires_grad=False)
    return module
1339
+
1340
+
1315
1341
@dataclass
1316
1342
class Int8DynamicActivationInt8WeightConfig (AOBaseConfig ):
1317
1343
"""
You can’t perform that action at this time.
0 commit comments