@@ -568,39 +568,20 @@ def get_quantizer_and_quant_params(args):
     return pt2e_quant_params, quantizers, quant_dtype
 
 
-def _is_valid_torchao_qmode_type(value):
-    if not isinstance(value, str):
-        return False
-
-    if not value.startswith("torchao:"):
-        return False
-
-    patterns = [
-        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
-        r"emb.(\d+),(\d+)",
-        r"lin8da.(\d+),(\d+)",
-    ]
-    for pattern in patterns:
-        matches = re.findall(pattern, value)
-        if len(matches) == 1:
-            return True
-    return False
-
-
 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-    if not (value in choices or _is_valid_torchao_qmode_type(value)):
-        raise argparse.ArgumentTypeError(
-            f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
-            + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
-            + "\n\t\t(e.g., torchao:emb.4,32)"
-            + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
-            + "\n\t\t(e.g., torchao:emb.4,32&lin8da.4,128)"
-            + "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
-            + "\nt\t\t(e.g., torchao:lin8da.4,128)"
-        )
-    return value
+    patterns = [r"torchao:8da(\d+)w"]
 
+    if value in choices:
+        return value
+
+    for pattern in patterns:
+        matches = re.findall(pattern, value)
+        if len(matches) == 1:
+            return value
+    raise argparse.ArgumentTypeError(
+        f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
+    )
 
 def _validate_args(args):
     """
@@ -615,10 +596,10 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if _is_valid_torchao_qmode_type(args.quantization_mode):
+    if args.quantization_mode.startswith("torchao:") or args.embedding_quantize.startswith("torchao:"):
         if args.enable_dynamic_shape:
             raise ValueError(
-                "Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape."
+                "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape."
                 "If you need this feature, please file an issue."
             )
 
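
For context, below is a minimal, self-contained sketch (not part of the commit) of how the new validator behaves once wired into argparse. The function body mirrors the lines added above; the `-qmode`/`--quantization_mode` flag wiring and the sample values are illustrative assumptions, not taken from this diff.

import argparse
import re


def _qmode_type(value):
    # Mirrors the validator added in this commit: fixed choices plus the
    # torchao:8da{bitwidth}w regex family.
    choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    patterns = [r"torchao:8da(\d+)w"]

    if value in choices:
        return value

    for pattern in patterns:
        matches = re.findall(pattern, value)
        if len(matches) == 1:
            return value
    raise argparse.ArgumentTypeError(
        f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
    )


# Hypothetical wiring, assumed to resemble export_llama's CLI.
parser = argparse.ArgumentParser()
parser.add_argument("-qmode", "--quantization_mode", type=_qmode_type, default=None)

print(parser.parse_args(["-qmode", "8da4w"]).quantization_mode)          # fixed choice: accepted
print(parser.parse_args(["-qmode", "torchao:8da6w"]).quantization_mode)  # regex match: accepted
# The old emb/lin8da spellings (e.g. "torchao:emb.4,32") no longer validate
# and would make parse_args exit with the ArgumentTypeError message above.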