
Commit 934ceb0

update export lib
1 parent 96056eb commit 934ceb0

File tree

1 file changed (+29, -16 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 29 additions & 16 deletions
@@ -121,6 +121,22 @@ def build_model(
     return export_llama(modelname, args)
 
 
+def _is_valid_torchao_qmode_type(value):
+    if not value.startswith("torchao:"):
+        return False
+
+    patterns = [
+        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
+        r"emb.(\d+),(\d+)",
+        r"lin8da.(\d+),(\d+)",
+    ]
+    for pattern in patterns:
+        matches = re.findall(pattern, value)
+        if len(matches) == 1:
+            return True
+    return False
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
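
For reference, the three regex patterns accept exactly the qmode shapes itemized in the updated error message in the next hunk. A quick sanity check of the moved validator (a sketch; it assumes the ExecuTorch repo root is on sys.path so the module is importable, and that the file already imports re at module scope):

from examples.models.llama.export_llama_lib import _is_valid_torchao_qmode_type

# Embeddings only, linears only, or both joined with "&":
assert _is_valid_torchao_qmode_type("torchao:emb.4,32")
assert _is_valid_torchao_qmode_type("torchao:lin8da.4,128")
assert _is_valid_torchao_qmode_type("torchao:emb.4,32&lin8da.4,128")
# Values without the "torchao:" prefix are rejected:
assert not _is_valid_torchao_qmode_type("8da4w")

Note that the "." in the patterns is an unescaped regex wildcard, so matching is slightly looser than the literal "emb."/"lin8da." spellings suggest.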
@@ -154,26 +170,17 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
     )
 
-    def _is_valid_torchao_qmode_type(value):
-        if not value.startswith("torchao:"):
-            return False
-
-        patterns = [
-            r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
-            r"emb.(\d+),(\d+)",
-            r"lin8da.(\d+),(\d+)",
-        ]
-        for pattern in patterns:
-            matches = re.findall(pattern, value)
-            if len(matches) == 1:
-                return True
-        return False
-
     def _qmode_type(value):
         choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
         if not (value in choices or _is_valid_torchao_qmode_type(value)):
             raise argparse.ArgumentTypeError(
-                f"Value must be one of: {choices} or a valid torchao regex"
+                f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
+                + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
+                + "\n\t\t (e.g., torchao:emb.4,32)"
+                + "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
+                + "\n\t\t (e.g., torchao:emb.4,32&lin8da.4,128)"
+                + "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
+                + "\n\t\t (e.g., torchao:lin8da.4,128)"
             )
         return value
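
With this change an invalid qmode now fails fast during argument parsing with the itemized message above instead of the old one-line hint. A minimal round trip (a sketch; the --quantization_mode flag spelling is assumed from the quantization_mode destination used in _validate_args, and the remaining arguments are assumed to have defaults):

from examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()

# A well-formed torchao qmode passes through _qmode_type unchanged:
args = parser.parse_args(["--quantization_mode", "torchao:emb.4,32&lin8da.4,128"])
print(args.quantization_mode)  # torchao:emb.4,32&lin8da.4,128

# A malformed value exits with the formatted ArgumentTypeError message,
# listing the fixed choices and the three torchao pattern templates:
parser.parse_args(["--quantization_mode", "torchao:bogus"])  # SystemExit(2)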

@@ -604,6 +611,12 @@ def _validate_args(args):
     if args.num_sharding > 0 and not args.qnn:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
+    if _is_valid_torchao_qmode_type(args.quantization_mode):
+        if args.enable_dynamic_shape:
+            raise ValueError(
+                "Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape."
+            )
+
 
 def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
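
The new guard rejects combining a torchao qmode with dynamic shapes. A minimal illustration (a sketch; the Namespace carries only the fields visible in this diff, while the real _validate_args may read more):

from argparse import Namespace
from examples.models.llama.export_llama_lib import _validate_args

args = Namespace(
    quantization_mode="torchao:lin8da.4,128",
    enable_dynamic_shape=True,  # the combination the new check forbids
    num_sharding=0,
    qnn=False,
)
_validate_args(args)  # ValueError: ... Please use --disable_dynamic_shape.

One caveat: _is_valid_torchao_qmode_type calls value.startswith, so this check assumes quantization_mode is a string; if the flag's default is None, the call would raise AttributeError rather than fall through.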
