Skip to content

Commit 9f955c9

Browse files
committed
up
1 parent e35bd21 commit 9f955c9

File tree

2 files changed

+40
-36
lines changed

2 files changed

+40
-36
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,25 +121,6 @@ def build_model(
121121
return export_llama(modelname, args)
122122

123123

124-
def _is_valid_torchao_qmode_type(value):
125-
if not isinstance(value, str):
126-
return False
127-
128-
if not value.startswith("torchao:"):
129-
return False
130-
131-
patterns = [
132-
r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
133-
r"emb.(\d+),(\d+)",
134-
r"lin8da.(\d+),(\d+)",
135-
]
136-
for pattern in patterns:
137-
matches = re.findall(pattern, value)
138-
if len(matches) == 1:
139-
return True
140-
return False
141-
142-
143124
def build_args_parser() -> argparse.ArgumentParser:
144125
ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
145126
parser = argparse.ArgumentParser()
@@ -173,20 +154,6 @@ def build_args_parser() -> argparse.ArgumentParser:
173154
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
174155
)
175156

176-
def _qmode_type(value):
177-
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
178-
if not (value in choices or _is_valid_torchao_qmode_type(value)):
179-
raise argparse.ArgumentTypeError(
180-
f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
181-
+ "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
182-
+ "\n\t\t (e.g., torchao:emb.4,32)"
183-
+ "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
184-
+ "\n\t\t (e.g., torchao:emb.4,32&lin8da.4,128)"
185-
+ "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
186-
+ "\nt\t\t (e.g., torchao:lin8da.4,128)"
187-
)
188-
return value
189-
190157
parser.add_argument(
191158
"-qmode",
192159
"--quantization_mode",
@@ -601,6 +568,40 @@ def get_quantizer_and_quant_params(args):
601568
return pt2e_quant_params, quantizers, quant_dtype
602569

603570

571+
def _is_valid_torchao_qmode_type(value):
572+
if not isinstance(value, str):
573+
return False
574+
575+
if not value.startswith("torchao:"):
576+
return False
577+
578+
patterns = [
579+
r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
580+
r"emb.(\d+),(\d+)",
581+
r"lin8da.(\d+),(\d+)",
582+
]
583+
for pattern in patterns:
584+
matches = re.findall(pattern, value)
585+
if len(matches) == 1:
586+
return True
587+
return False
588+
589+
590+
def _qmode_type(value):
591+
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
592+
if not (value in choices or _is_valid_torchao_qmode_type(value)):
593+
raise argparse.ArgumentTypeError(
594+
f"Got qmode {value}, but expected one of: {choices} or a valid torchao quantization pattern such as:"
595+
+ "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}"
596+
+ "\n\t\t (e.g., torchao:emb.4,32)"
597+
+ "\n\t* torchao:emb.{embed_bitwidth},{embed_groupsize}&lin8da.{linear_bitwidth},{linear_groupsize}"
598+
+ "\n\t\t (e.g., torchao:emb.4,32&lin8da.4,128)"
599+
+ "\n\t* torchao:lin8da.{linear_bitwidth},{linear_groupsize}"
600+
+ "\nt\t\t (e.g., torchao:lin8da.4,128)"
601+
)
602+
return value
603+
604+
604605
def _validate_args(args):
605606
"""
606607
TODO: Combine all the backends under --backend args
@@ -618,6 +619,7 @@ def _validate_args(args):
618619
if args.enable_dynamic_shape:
619620
raise ValueError(
620621
"Dynamic shape is not currently supported with torchao qmode. Please use --disable_dynamic_shape."
622+
"If you need this feature, please file an issue."
621623
)
622624

623625

examples/models/llama/source_transformation/quantize.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,14 @@ def quantize( # noqa C901
7979
libs = glob.glob(
8080
os.path.abspath(
8181
os.path.join(
82-
os.path.dirname(__file__),
83-
"../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*",
82+
os.environ.get("CMAKE_INSTALL_PREFIX", ""),
83+
"lib/libtorchao_ops_aten.*",
8484
)
8585
)
8686
)
87-
assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
87+
assert (
88+
len(libs) == 1
89+
), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
8890
logging.info(f"Loading custom ops library: {libs[0]}")
8991
torch.ops.load_library(libs[0])
9092

0 commit comments

Comments
 (0)