@@ -121,6 +121,22 @@ def build_model(
     return export_llama(modelname, args)


+def _is_valid_torchao_qmode_type(value):
+    if not isinstance(value, str) or not value.startswith("torchao:"):
+        return False
+
+    patterns = [
+        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
+        r"emb.(\d+),(\d+)",
+        r"lin8da.(\d+),(\d+)",
+    ]
+    for pattern in patterns:
+        matches = re.findall(pattern, value)
+        if len(matches) == 1:
+            return True
+    return False
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
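As a quick standalone sanity check of which strings the new validator accepts, the snippet below restates the function as added above and exercises the three documented patterns (the test values are illustrative, not from this change):

    import re

    def _is_valid_torchao_qmode_type(value):
        # Same logic as the function added in the hunk above.
        if not isinstance(value, str) or not value.startswith("torchao:"):
            return False
        patterns = [
            r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
            r"emb.(\d+),(\d+)",
            r"lin8da.(\d+),(\d+)",
        ]
        for pattern in patterns:
            if len(re.findall(pattern, value)) == 1:
                return True
        return False

    assert _is_valid_torchao_qmode_type("torchao:emb.4,32")
    assert _is_valid_torchao_qmode_type("torchao:lin8da.4,128")
    assert _is_valid_torchao_qmode_type("torchao:emb.4,32&lin8da.4,128")
    assert not _is_valid_torchao_qmode_type("8da4w")        # missing torchao: prefix
    assert not _is_valid_torchao_qmode_type("torchao:foo")  # no recognizable pattern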
@@ -154,26 +170,17 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )

-    def _is_valid_torchao_qmode_type(value):
-        if not value.startswith("torchao:"):
-            return False
-
-        patterns = [
-            r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
-            r"emb.(\d+),(\d+)",
-            r"lin8da.(\d+),(\d+)",
-        ]
-        for pattern in patterns:
-            matches = re.findall(pattern, value)
-            if len(matches) == 1:
-                return True
-        return False
-
     def _qmode_type(value):
         choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
         if not (value in choices or _is_valid_torchao_qmode_type(value)):
             raise argparse.ArgumentTypeError(
-                f"Value must be one of: {choices} or a valid torchao regex"
+                f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
+                + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
+                + "\n\t\t(e.g., torchao:emb.4,32)"
+                + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
+                + "\n\t\t(e.g., torchao:emb.4,32&lin8da.4,128)"
+                + "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
+                + "\n\t\t(e.g., torchao:lin8da.4,128)"
             )
         return value

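Downstream consumers then need to recover the bitwidth/groupsize pairs from a validated string. A minimal sketch of that step, assuming a hypothetical helper _parse_torchao_qmode (neither the name nor the return shape is part of this change):

    import re

    def _parse_torchao_qmode(value):
        # value is assumed to have passed _is_valid_torchao_qmode_type,
        # e.g. "torchao:emb.4,32&lin8da.4,128".
        spec = value[len("torchao:"):]
        emb = re.search(r"emb\.(\d+),(\d+)", spec)
        lin = re.search(r"lin8da\.(\d+),(\d+)", spec)
        # Each entry is (bitwidth, groupsize), or None if that piece is absent.
        return (
            tuple(int(g) for g in emb.groups()) if emb else None,
            tuple(int(g) for g in lin.groups()) if lin else None,
        )

    print(_parse_torchao_qmode("torchao:emb.4,32&lin8da.4,128"))  # ((4, 32), (4, 128))
    print(_parse_torchao_qmode("torchao:lin8da.4,128"))           # (None, (4, 128))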
@@ -604,6 +611,12 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")

+    if _is_valid_torchao_qmode_type(args.quantization_mode):
+        if args.enable_dynamic_shape:
+            raise ValueError(
+                "Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape."
+            )
+

 def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
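For completeness, a small sketch of the new guard in action, reusing _is_valid_torchao_qmode_type from the first sketch above; the wrapper name _check_dynamic_shape is hypothetical, and SimpleNamespace stands in for the parsed args:

    from types import SimpleNamespace

    def _check_dynamic_shape(args):
        # Mirrors the check added to _validate_args above.
        if _is_valid_torchao_qmode_type(args.quantization_mode):
            if args.enable_dynamic_shape:
                raise ValueError(
                    "Dynamic shape is not currently supported with torchao "
                    "qmode. Please use --disable_dynamic_shape."
                )

    # OK: torchao qmode with dynamic shape disabled.
    _check_dynamic_shape(
        SimpleNamespace(quantization_mode="torchao:lin8da.4,128", enable_dynamic_shape=False)
    )

    # Raises ValueError: torchao qmode with dynamic shape enabled.
    _check_dynamic_shape(
        SimpleNamespace(quantization_mode="torchao:lin8da.4,128", enable_dynamic_shape=True)
    )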