@@ -436,14 +436,14 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = Non
436436        else :
437437            self .quant_type_kwargs  =  kwargs 
438438
439-         _STR_TO_METHOD  =  self ._get_torchao_quant_type_to_method ()
440-         if  self .quant_type  not  in _STR_TO_METHOD .keys ():
439+         TORCHAO_QUANT_TYPE_METHODS  =  self ._get_torchao_quant_type_to_method ()
440+         if  self .quant_type  not  in TORCHAO_QUANT_TYPE_METHODS .keys ():
441441            raise  ValueError (
442442                f"Requested quantization type: { self .quant_type }  
443443                f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." 
444444            )
445445
446-         method  =  _STR_TO_METHOD [self .quant_type ]
446+         method  =  TORCHAO_QUANT_TYPE_METHODS [self .quant_type ]
447447        signature  =  inspect .signature (method )
448448        all_kwargs  =  {
449449            param .name 
@@ -481,24 +481,6 @@ def _get_torchao_quant_type_to_method(cls):
481481            # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers 
482482            from  torchao .quantization .observer  import  PerRow , PerTensor 
483483
484-             INT4_QUANTIZATION_TYPES  =  {
485-                 # int4 weight + bfloat16/float16 activation 
486-                 "int4wo" : int4_weight_only ,
487-                 "int4_weight_only" : int4_weight_only ,
488-                 # int4 weight + int8 activation 
489-                 "int4dq" : int8_dynamic_activation_int4_weight ,
490-                 "int8_dynamic_activation_int4_weight" : int8_dynamic_activation_int4_weight ,
491-             }
492- 
493-             INT8_QUANTIZATION_TYPES  =  {
494-                 # int8 weight + bfloat16/float16 activation 
495-                 "int8wo" : int8_weight_only ,
496-                 "int8_weight_only" : int8_weight_only ,
497-                 # int8 weight + int8 activation 
498-                 "int8dq" : int8_dynamic_activation_int8_weight ,
499-                 "int8_dynamic_activation_int8_weight" : int8_dynamic_activation_int8_weight ,
500-             }
501- 
502484            def  generate_float8dq_types (dtype : torch .dtype ):
503485                name  =  "e5m2"  if  dtype  ==  torch .float8_e5m2  else  "e4m3" 
504486                types  =  {}
@@ -529,6 +511,41 @@ def generate_fpx_quantization_types(bits: int):
529511
530512                return  types 
531513
514+             def  generate_uintx_quantization_types (bits : int ):
515+                 UINTX_TO_DTYPE  =  {
516+                     1 : torch .uint1 ,
517+                     2 : torch .uint2 ,
518+                     3 : torch .uint3 ,
519+                     4 : torch .uint4 ,
520+                     5 : torch .uint5 ,
521+                     6 : torch .uint6 ,
522+                     7 : torch .uint7 ,
523+                     # 8: torch.uint8,  # uint8 quantization is not supported 
524+                 }
525+ 
526+                 types  =  {}
527+                 types [f"uint{ bits }  ] =  partial (uintx_weight_only , dtype = UINTX_TO_DTYPE [bits ])
528+                 types [f"uint{ bits }  ] =  partial (uintx_weight_only , dtype = UINTX_TO_DTYPE [bits ])
529+                 return  types 
530+ 
531+             INT4_QUANTIZATION_TYPES  =  {
532+                 # int4 weight + bfloat16/float16 activation 
533+                 "int4wo" : int4_weight_only ,
534+                 "int4_weight_only" : int4_weight_only ,
535+                 # int4 weight + int8 activation 
536+                 "int4dq" : int8_dynamic_activation_int4_weight ,
537+                 "int8_dynamic_activation_int4_weight" : int8_dynamic_activation_int4_weight ,
538+             }
539+ 
540+             INT8_QUANTIZATION_TYPES  =  {
541+                 # int8 weight + bfloat16/float16 activation 
542+                 "int8wo" : int8_weight_only ,
543+                 "int8_weight_only" : int8_weight_only ,
544+                 # int8 weight + int8 activation 
545+                 "int8dq" : int8_dynamic_activation_int8_weight ,
546+                 "int8_dynamic_activation_int8_weight" : int8_dynamic_activation_int8_weight ,
547+             }
548+ 
532549            # TODO(aryan): handle torch 2.2/2.3 
533550            FLOATX_QUANTIZATION_TYPES  =  {
534551                # float8_e5m2 weight + bfloat16/float16 activation 
@@ -540,7 +557,7 @@ def generate_fpx_quantization_types(bits: int):
540557                # float8_e5m2 weight + float8 activation (dynamic) 
541558                "float8dq" : float8_dynamic_activation_float8_weight ,
542559                "float8_dynamic_activation_float8_weight" : float8_dynamic_activation_float8_weight ,
543-                 # ===== Matrix multiplication is not supported in float8_e5m2 so the following error  out. 
560+                 # ===== Matrix multiplication is not supported in float8_e5m2 so the following errors  out. 
544561                # However, changing activation_dtype=torch.float8_e4m3 might work here ===== 
545562                # "float8dq_e5m2": partial( 
546563                #     float8_dynamic_activation_float8_weight, 
@@ -566,27 +583,8 @@ def generate_fpx_quantization_types(bits: int):
566583                ** generate_fpx_quantization_types (5 ),
567584                ** generate_fpx_quantization_types (6 ),
568585                ** generate_fpx_quantization_types (7 ),
569-                 # ===== Errors out with "torch.cat(): expected a non-empty list of Tensors" ===== 
570-                 # **generate_fpx_quantization_types(8), 
571-             }
572- 
573-             UINTX_TO_DTYPE  =  {
574-                 1 : torch .uint1 ,
575-                 2 : torch .uint2 ,
576-                 3 : torch .uint3 ,
577-                 4 : torch .uint4 ,
578-                 5 : torch .uint5 ,
579-                 6 : torch .uint6 ,
580-                 7 : torch .uint7 ,
581-                 # 8: torch.uint8,  # uint8 quantization is not supported 
582586            }
583587
584-             def  generate_uintx_quantization_types (bits : int ):
585-                 types  =  {}
586-                 types [f"uint{ bits }  ] =  partial (uintx_weight_only , dtype = UINTX_TO_DTYPE [bits ])
587-                 types [f"uint{ bits }  ] =  partial (uintx_weight_only , dtype = UINTX_TO_DTYPE [bits ])
588-                 return  types 
589- 
590588            UINTX_QUANTIZATION_DTYPES  =  {
591589                "uintx" : uintx_weight_only ,
592590                "uintx_weight_only" : uintx_weight_only ,
@@ -614,6 +612,7 @@ def generate_uintx_quantization_types(bits: int):
614612                "uint_a16w7" : partial (uintx_weight_only , dtype = torch .uint7 ),
615613                # "uint_a16w8": partial(uintx_weight_only, dtype=torch.uint8),  # uint8 quantization is not supported 
616614            }
615+ 
617616            SHORTHAND_FLOAT_QUANTIZATION_TYPES  =  {
618617                "float_e5m2_a16w8" : partial (float8_weight_only , weight_dtype = torch .float8_e5m2 ),
619618                "float_e4m3_a16w8" : partial (float8_weight_only , weight_dtype = torch .float8_e4m3fn ),
@@ -645,8 +644,6 @@ def generate_uintx_quantization_types(bits: int):
645644    @staticmethod  
646645    def  _is_cuda_capability_atleast_8_9 () ->  bool :
647646        if  not  torch .cuda .is_available ():
648-             if  torch .mps .is_available ():
649-                 return  False 
650647            raise  RuntimeError ("TorchAO requires a CUDA compatible GPU and installation of PyTorch." )
651648
652649        major , minor  =  torch .cuda .get_device_capability ()
@@ -655,9 +652,23 @@ def _is_cuda_capability_atleast_8_9() -> bool:
655652        return  major  >=  9 
656653
657654    def  get_apply_tensor_subclass (self ):
658-         _STR_TO_METHOD  =  self ._get_torchao_quant_type_to_method ()
659-         return  _STR_TO_METHOD [self .quant_type ](** self .quant_type_kwargs )
655+         TORCHAO_QUANT_TYPE_METHODS  =  self ._get_torchao_quant_type_to_method ()
656+         return  TORCHAO_QUANT_TYPE_METHODS [self .quant_type ](** self .quant_type_kwargs )
660657
661658    def  __repr__ (self ):
659+         r""" 
660+         Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`: 
661+ 
662+         ``` 
663+         TorchAoConfig { 
664+             "modules_to_not_convert": null, 
665+             "quant_method": "torchao", 
666+             "quant_type": "uint_a16w4", 
667+             "quant_type_kwargs": { 
668+                 "group_size": 32 
669+             } 
670+         } 
671+         ``` 
672+         """ 
662673        config_dict  =  self .to_dict ()
663674        return  f"{ self .__class__ .__name__ } { json .dumps (config_dict , indent = 2 , sort_keys = True )} \n " 
0 commit comments