Skip to content

Commit 65d71ba

Browse files
committed
up
1 parent af6f818 commit 65d71ba

File tree

1 file changed

+13
-32
lines changed

1 file changed

+13
-32
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -568,39 +568,20 @@ def get_quantizer_and_quant_params(args):
568568
return pt2e_quant_params, quantizers, quant_dtype
569569

570570

571-
def _is_valid_torchao_qmode_type(value):
572-
if not isinstance(value, str):
573-
return False
574-
575-
if not value.startswith("torchao:"):
576-
return False
577-
578-
patterns = [
579-
r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
580-
r"emb.(\d+),(\d+)",
581-
r"lin8da.(\d+),(\d+)",
582-
]
583-
for pattern in patterns:
584-
matches = re.findall(pattern, value)
585-
if len(matches) == 1:
586-
return True
587-
return False
588-
589-
590571
def _qmode_type(value):
591572
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
592-
if not (value in choices or _is_valid_torchao_qmode_type(value)):
593-
raise argparse.ArgumentTypeError(
594-
f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
595-
+ "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
596-
+ "\n\t\t (e.g., torchao:emb.4,32)"
597-
+ "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
598-
+ "\n\t\t (e.g., torchao:emb.4,32&lin8da.4,128)"
599-
+ "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
600-
+ "\nt\t\t (e.g., torchao:lin8da.4,128)"
601-
)
602-
return value
573+
patterns = [r"torchao:8da{\d+}w"]
603574

575+
if value in choices:
576+
return value
577+
578+
for pattern in patterns:
579+
matches = re.findall(pattern, value)
580+
if len(matches) == 1:
581+
return value
582+
raise argparse.ArgumentTypeError(
583+
f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
584+
)
604585

605586
def _validate_args(args):
606587
"""
@@ -615,10 +596,10 @@ def _validate_args(args):
615596
if args.num_sharding > 0 and not args.qnn:
616597
raise ValueError("Model shard is only supported with qnn backend now.")
617598

618-
if _is_valid_torchao_qmode_type(args.quantization_mode):
599+
if args.quantization_mode.startswith("torchao:") or args.embedding_quantize.startswith("torchao:"):
619600
if args.enable_dynamic_shape:
620601
raise ValueError(
621-
"Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape."
602+
"Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape."
622603
"If you need this feature, please file an issue."
623604
)
624605

0 commit comments

Comments (0)