@@ -121,25 +121,6 @@ def build_model(
     return export_llama(modelname, args)


-def _is_valid_torchao_qmode_type(value):
-    if not isinstance(value, str):
-        return False
-
-    if not value.startswith("torchao:"):
-        return False
-
-    patterns = [
-        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
-        r"emb.(\d+),(\d+)",
-        r"lin8da.(\d+),(\d+)",
-    ]
-    for pattern in patterns:
-        matches = re.findall(pattern, value)
-        if len(matches) == 1:
-            return True
-    return False
-
-
 def build_args_parser() -> argparse.ArgumentParser:
     ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
@@ -173,20 +154,6 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )

-    def _qmode_type(value):
-        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-        if not (value in choices or _is_valid_torchao_qmode_type(value)):
-            raise argparse.ArgumentTypeError(
-                f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
-                + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
-                + "\n\t\t(e.g., torchao:emb.4,32)"
-                + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
-                + "\n\t\t(e.g., torchao:emb.4,32&lin8da.4,128)"
-                + "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
-                + "\n\t\t(e.g., torchao:lin8da.4,128)"
-            )
-        return value
-
     parser.add_argument(
         "-qmode",
         "--quantization_mode",
@@ -601,6 +568,40 @@ def get_quantizer_and_quant_params(args):
     return pt2e_quant_params, quantizers, quant_dtype


+def _is_valid_torchao_qmode_type(value):
+    if not isinstance(value, str):
+        return False
+
+    if not value.startswith("torchao:"):
+        return False
+
+    patterns = [
+        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
+        r"emb.(\d+),(\d+)",
+        r"lin8da.(\d+),(\d+)",
+    ]
+    for pattern in patterns:
+        matches = re.findall(pattern, value)
+        if len(matches) == 1:
+            return True
+    return False
+
+
+def _qmode_type(value):
+    choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+    if not (value in choices or _is_valid_torchao_qmode_type(value)):
+        raise argparse.ArgumentTypeError(
+            f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
+            + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
+            + "\n\t\t(e.g., torchao:emb.4,32)"
+            + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
+            + "\n\t\t(e.g., torchao:emb.4,32&lin8da.4,128)"
+            + "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
+            + "\n\t\t(e.g., torchao:lin8da.4,128)"
+        )
+    return value
+
+
 def _validate_args(args):
     """
     TODO: Combine all the backends under --backend args
@@ -618,6 +619,7 @@ def _validate_args(args):
         if args.enable_dynamic_shape:
             raise ValueError(
                 "Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape."
+                "If you need this feature, please file an issue."
             )


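
For reference, a minimal standalone sketch of what the relocated validator accepts, and how it can back the -qmode/--quantization_mode flag through argparse's type= hook. The add_argument call is truncated in the second hunk, so the type= wiring is an assumption; the patterns and example values are taken from the diff above.

import argparse
import re

# Standalone sketch: mirrors the relocated _is_valid_torchao_qmode_type /
# _qmode_type pair. Only the patterns and example values below come from the
# diff; the type= hookup into argparse is an assumption.
TORCHAO_QMODE_PATTERNS = [
    r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
    r"emb.(\d+),(\d+)",
    r"lin8da.(\d+),(\d+)",
]


def is_valid_torchao_qmode(value: str) -> bool:
    # Same acceptance rule as _is_valid_torchao_qmode_type: a "torchao:" prefix
    # plus exactly one match of one of the patterns above.
    if not isinstance(value, str) or not value.startswith("torchao:"):
        return False
    return any(len(re.findall(p, value)) == 1 for p in TORCHAO_QMODE_PATTERNS)


def qmode_type(value: str) -> str:
    # Same shape as _qmode_type: accept the fixed choices or a torchao pattern.
    choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    if value in choices or is_valid_torchao_qmode(value):
        return value
    raise argparse.ArgumentTypeError(f"invalid qmode: {value}")


parser = argparse.ArgumentParser()
parser.add_argument("-qmode", "--quantization_mode", type=qmode_type, default=None)

# Valid values drawn from the choices list and the error message in the diff.
for q in ("8da4w", "torchao:emb.4,32", "torchao:lin8da.4,128", "torchao:emb.4,32&lin8da.4,128"):
    assert parser.parse_args(["-qmode", q]).quantization_mode == q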