@@ -400,13 +400,15 @@ class TorchAoConfig(QuantizationConfigMixin):
 
     Args:
         quant_type (`str`):
-            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
+            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and
+            `int8_dynamic_activation_int8_weight`.
         modules_to_not_convert (`list`, *optional*, default to `None`):
-            The list of modules to not quantize, useful for quantizing models that explicitly require to have
-            some modules left in their original precision.
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+            modules left in their original precision.
         kwargs (`Dict[str, Any]`, *optional*):
-            The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments
-            `group_size` and `inner_k_tiles` currently. More API examples and documentation of arguments can be found in
+            The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
+            supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
+            documentation of arguments can be found in
             https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
 
     Example:
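
For orientation, a minimal sketch of the API the reflowed docstring describes; the import path is an assumption based on the `AutoModelForCausalLM` example in the hunk below:

```python
# Minimal usage sketch (import path assumed from the docstring example below):
from transformers import TorchAoConfig

# `group_size` is one of the int4_weight_only kwargs forwarded through **kwargs.
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
```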
@@ -415,15 +417,17 @@ class TorchAoConfig(QuantizationConfigMixin):
     TODO(aryan): update
     quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
     # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config
+    )
     ```
     """
 
     def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
-
+
         # When we load from serialized config, "quant_type_kwargs" will be the key
         if "quant_type_kwargs" in kwargs:
             self.quant_type_kwargs = kwargs["quant_type_kwargs"]
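
The `quant_type_kwargs` branch above distinguishes two construction paths. A sketch of both, under the assumption that the branch not shown in this hunk stores the remaining kwargs directly:

```python
# Fresh construction: extra keyword arguments become the quant-type kwargs
# (assumes the else-branch, not shown in this hunk, stores them as-is).
config = TorchAoConfig("int4_weight_only", group_size=32)

# Reloading a serialized config: the same kwargs arrive pre-bundled under
# the "quant_type_kwargs" key, which the branch above unpacks.
config = TorchAoConfig("int4_weight_only", quant_type_kwargs={"group_size": 32})
assert config.quant_type_kwargs == {"group_size": 32}
```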
@@ -448,7 +452,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non
 
         if len(unsupported_kwargs) > 0:
             raise ValueError(
-                f"The quantization method \"{method}\" does not supported the following keyword arguments: "
+                f'The quantization method "{method}" does not support the following keyword arguments: '
                 f"{unsupported_kwargs}. The following keyword arguments are supported: {all_kwargs}."
             )
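
To see the error message above in practice: passing a kwarg the chosen quant type does not accept should raise during construction. A hypothetical repro:

```python
import torch

try:
    # int4_weight_only takes group_size/inner_k_tiles, not activation_dtype,
    # so this should trip the ValueError above (hypothetical repro).
    TorchAoConfig("int4_weight_only", activation_dtype=torch.float8_e5m2)
except ValueError as err:
    print(err)  # lists both the unsupported and the supported kwargs
```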
@@ -460,16 +464,17 @@ def _get_torchao_quant_type_to_method(cls):
 
         if is_torchao_available():
             from torchao.quantization import (
-                int4_weight_only,
-                int8_dynamic_activation_int8_weight,
-                int8_dynamic_activation_int4_weight,
-                int8_weight_only,
                 float8_dynamic_activation_float8_weight,
                 float8_static_activation_float8_weight,
                 float8_weight_only,
                 fpx_weight_only,
+                int4_weight_only,
+                int8_dynamic_activation_int4_weight,
+                int8_dynamic_activation_int8_weight,
+                int8_weight_only,
                 uintx_weight_only,
             )
+
             # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
             from torchao.quantization.observer import PerRow, PerTensor
 
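The re-sorted torchao imports combine later in this method; as a hedged sketch, the observer classes feed the float8 constructor as an (activation, weight) granularity pair, matching the tuple form visible further down in this diff:

```python
from torchao.quantization import float8_dynamic_activation_float8_weight
from torchao.quantization.observer import PerRow

# Per-row granularity for both activations and weights; the tuple form
# mirrors the granularity=(...) call in the hunks below.
quantize_fn = float8_dynamic_activation_float8_weight(granularity=(PerRow(), PerRow()))
```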
@@ -502,8 +507,10 @@ def _get_torchao_quant_type_to_method(cls):
             def generate_float8dq_types(dtype: torch.dtype):
                 name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
                 types = {}
-
-                types[f"float8dq_{name}_a8w8"] = partial(float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype)
+
+                types[f"float8dq_{name}_a8w8"] = partial(
+                    float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype
+                )
                 for activation_granularity_cls in [PerTensor, PerRow]:
                     for weight_granularity_cls in [PerTensor, PerRow]:
                         activation_name = "t" if activation_granularity_cls is PerTensor else "r"
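
The nested loop spanning this hunk and the next builds one mapping entry per granularity combination. A standalone sketch of the key names it plausibly generates; the exact format is an assumption extrapolated from the `_a8w8` key and the "t"/"r" suffixes above:

```python
from itertools import product

# Hypothetical reconstruction of the generated keys for the e4m3 case:
for a, w in product("tr", repeat=2):
    print(f"float8dq_e4m3_a{a}w{w}")
# -> float8dq_e4m3_atwt, float8dq_e4m3_atwr, float8dq_e4m3_arwt, float8dq_e4m3_arwr
```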
@@ -526,22 +533,22 @@ def generate_float8dq_types(dtype: torch.dtype):
                             weight_dtype=dtype,
                             granularity=(activation_granularity_cls(), weight_granularity_cls()),
                         )
-
+
                 return types
 
             def generate_fpx_quantization_types(bits: int):
                 types = {}
-
+
                 for ebits in range(1, bits):
                     mbits = bits - ebits - 1
                     types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
                     types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
-
+
                 non_sign_bits = bits - 1
                 default_ebits = (non_sign_bits + 1) // 2
                 default_mbits = non_sign_bits - default_ebits
                 types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
-
+
                 return types
 
             # TODO(aryan): handle cuda capability and torch 2.2/2.3
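
A worked example of what `generate_fpx_quantization_types` registers for `bits=6`, reproducing the arithmetic above in isolation:

```python
bits = 6
for ebits in range(1, bits):  # ebits runs 1..5
    mbits = bits - ebits - 1  # one bit is reserved for the sign
    print(f"fp{bits}_e{ebits}m{mbits}")
# -> fp6_e1m4, fp6_e2m3, fp6_e3m2, fp6_e4m1, fp6_e5m0

# Default split: 5 non-sign bits -> default_ebits = (5 + 1) // 2 = 3,
# default_mbits = 2, so the bare "fp6" alias is equivalent to "fp6_e3m2".
```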
@@ -561,11 +568,19 @@ def generate_fpx_quantization_types(bits: int):
                 # float8_e5m2 weight + float8 activation (dynamic)
                 "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
                 "float8dq": float8_dynamic_activation_float8_weight,
-                "float8dq_e5m2": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2),
+                "float8dq_e5m2": partial(
+                    float8_dynamic_activation_float8_weight,
+                    activation_dtype=torch.float8_e5m2,
+                    weight_dtype=torch.float8_e5m2,
+                ),
                 "float8_a8w8": float8_dynamic_activation_float8_weight,
                 **generate_float8dq_types(torch.float8_e5m2),
                 # float8_e4m3 weight + float8 activation (dynamic)
-                "float8dq_e4m3": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn),
+                "float8dq_e4m3": partial(
+                    float8_dynamic_activation_float8_weight,
+                    activation_dtype=torch.float8_e4m3fn,
+                    weight_dtype=torch.float8_e4m3fn,
+                ),
                 **generate_float8dq_types(torch.float8_e4m3fn),
                 # float8 weight + float8 activation (static)
                 "float8_static_activation_float8_weight": float8_static_activation_float8_weight,
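
Tying the mapping back to the docstring example, a hedged end-to-end sketch; `model_id` is an illustrative placeholder and the import paths are assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

model_id = "meta-llama/Llama-3.1-8B"  # illustrative checkpoint, swap freely

# "float8dq_e4m3" resolves through the mapping above to
# float8_dynamic_activation_float8_weight with e4m3 weights and activations.
quantization_config = TorchAoConfig("float8dq_e4m3")
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config
)
```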