@@ -121,25 +121,6 @@ def build_model(
121121    return  export_llama (modelname , args )
122122
123123
124- def  _is_valid_torchao_qmode_type (value ):
125-     if  not  isinstance (value , str ):
126-         return  False 
127- 
128-     if  not  value .startswith ("torchao:" ):
129-         return  False 
130- 
131-     patterns  =  [
132-         r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)" ,
133-         r"emb.(\d+),(\d+)" ,
134-         r"lin8da.(\d+),(\d+)" ,
135-     ]
136-     for  pattern  in  patterns :
137-         matches  =  re .findall (pattern , value )
138-         if  len (matches ) ==  1 :
139-             return  True 
140-     return  False 
141- 
142- 
143124def  build_args_parser () ->  argparse .ArgumentParser :
144125    ckpt_dir  =  f"{ Path (__file__ ).absolute ().parent .as_posix ()}  
145126    parser  =  argparse .ArgumentParser ()
@@ -173,20 +154,6 @@ def build_args_parser() -> argparse.ArgumentParser:
173154        help = "Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding." ,
174155    )
175156
176-     def  _qmode_type (value ):
177-         choices  =  ["int8" , "8da4w" , "8da4w-gptq" , "vulkan_4w" ]
178-         if  not  (value  in  choices  or  _is_valid_torchao_qmode_type (value )):
179-             raise  argparse .ArgumentTypeError (
180-                 f"Got qmode { value } { choices }  
181-                 +  "\n \t * torchao:emb.{embed_bitwidth},{embed_groupsize}" 
182-                 +  "\n \t \t  (e.g., torchao:emb.4,32)" 
183-                 +  "\n \t * torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}" 
184-                 +  "\n \t \t  (e.g., torchao:emb.4,32&lin8da.4,128)" 
185-                 +  "\n \t * torchao:lin8da.{linear_bitwidth},{linear_groupsize}" 
186-                 +  "\n t\t \t  (e.g., torchao:lin8da.4,128)" 
187-             )
188-         return  value 
189- 
190157    parser .add_argument (
191158        "-qmode" ,
192159        "--quantization_mode" ,
@@ -601,6 +568,40 @@ def get_quantizer_and_quant_params(args):
601568    return  pt2e_quant_params , quantizers , quant_dtype 
602569
603570
571+ def  _is_valid_torchao_qmode_type (value ):
572+     if  not  isinstance (value , str ):
573+         return  False 
574+ 
575+     if  not  value .startswith ("torchao:" ):
576+         return  False 
577+ 
578+     patterns  =  [
579+         r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)" ,
580+         r"emb.(\d+),(\d+)" ,
581+         r"lin8da.(\d+),(\d+)" ,
582+     ]
583+     for  pattern  in  patterns :
584+         matches  =  re .findall (pattern , value )
585+         if  len (matches ) ==  1 :
586+             return  True 
587+     return  False 
588+ 
589+ 
def _qmode_type(value):
    """argparse ``type=`` validator for the --quantization_mode flag.

    Accepts one of the fixed choices or a torchao qmode string (as
    validated by ``_is_valid_torchao_qmode_type``).

    Args:
        value: the raw command-line string.

    Returns:
        str: ``value`` unchanged when it is valid.

    Raises:
        argparse.ArgumentTypeError: when the value is neither a known
            choice nor a valid torchao pattern; the message lists the
            accepted torchao formats with examples.
    """
    choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    if not (value in choices or _is_valid_torchao_qmode_type(value)):
        # NOTE(review): the original message line was truncated in the diff
        # view; reconstructed wording below — confirm against upstream.
        # Also fixed the "\n t\t\t" escape typo on the last example line.
        raise argparse.ArgumentTypeError(
            f"Got qmode {value}, but expected one of {choices} or a valid torchao pattern:"
            + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
            + "\n\t\t (e.g., torchao:emb.4,32)"
            + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
            + "\n\t\t (e.g., torchao:emb.4,32&lin8da.4,128)"
            + "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
            + "\n\t\t (e.g., torchao:lin8da.4,128)"
        )
    return value
603+ 
604+ 
604605def  _validate_args (args ):
605606    """ 
606607    TODO: Combine all the backends under --backend args 
@@ -618,6 +619,7 @@ def _validate_args(args):
618619        if  args .enable_dynamic_shape :
619620            raise  ValueError (
620621                "Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape." 
622+                 "If you need this feature, please file an issue." 
621623            )
622624
623625
0 commit comments