@@ -568,39 +568,20 @@ def get_quantizer_and_quant_params(args):
568568    return  pt2e_quant_params , quantizers , quant_dtype 
569569
570570
571- def  _is_valid_torchao_qmode_type (value ):
572-     if  not  isinstance (value , str ):
573-         return  False 
574- 
575-     if  not  value .startswith ("torchao:" ):
576-         return  False 
577- 
578-     patterns  =  [
579-         r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)" ,
580-         r"emb.(\d+),(\d+)" ,
581-         r"lin8da.(\d+),(\d+)" ,
582-     ]
583-     for  pattern  in  patterns :
584-         matches  =  re .findall (pattern , value )
585-         if  len (matches ) ==  1 :
586-             return  True 
587-     return  False 
588- 
589- 
590571def  _qmode_type (value ):
591572    choices  =  ["int8" , "8da4w" , "8da4w-gptq" , "vulkan_4w" ]
592-     if  not  (value  in  choices  or  _is_valid_torchao_qmode_type (value )):
593-         raise  argparse .ArgumentTypeError (
594-             f"Got qmode { value } { choices }  
595-             +  "\n \t * torchao:emb.{embed_bitwidth},{embed_groupsize}" 
596-             +  "\n \t \t  (e.g., torchao:emb.4,32)" 
597-             +  "\n \t * torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}" 
598-             +  "\n \t \t  (e.g., torchao:emb.4,32&lin8da.4,128)" 
599-             +  "\n \t * torchao:lin8da.{linear_bitwidth},{linear_groupsize}" 
600-             +  "\n t\t \t  (e.g., torchao:lin8da.4,128)" 
601-         )
602-     return  value 
573+     patterns  =  [r"torchao:8da{\d+}w" ]
603574
575+     if  value  in  choices :
576+         return  value 
577+ 
578+     for  pattern  in  patterns :
579+         matches  =  re .findall (pattern , value )
580+         if  len (matches ) ==  1 :
581+             return  value 
582+     raise  argparse .ArgumentTypeError (
583+             f"Got qmode { value } { choices } { patterns }  
584+     )
604585
605586def  _validate_args (args ):
606587    """ 
@@ -615,10 +596,10 @@ def _validate_args(args):
615596    if  args .num_sharding  >  0  and  not  args .qnn :
616597        raise  ValueError ("Model shard is only supported with qnn backend now." )
617598
618-     if  _is_valid_torchao_qmode_type ( args .quantization_mode ):
599+     if  args .quantization_mode . startswith ( "torchao:" )  or   args . embedding_quantize . startswith ( "torchao:" ):
619600        if  args .enable_dynamic_shape :
620601            raise  ValueError (
621-                 "Dynamic shape is not currently supported with torchao qmode . Please use --disable_dynamic_shape." 
602+                 "Dynamic shape is not currently supported with torchao ops . Please use --disable_dynamic_shape." 
622603                "If you need this feature, please file an issue." 
623604            )
624605
0 commit comments