
Commit 52f7a1f

Commit message: up

1 parent: c82f77b
File tree

examples/models/llama2/export_llama_lib.py
examples/models/llama2/source_transformation/quantize.py

2 files changed, 32 insertions(+), 26 deletions(-)


examples/models/llama2/export_llama_lib.py

Lines changed: 11 additions & 16 deletions
@@ -156,22 +156,17 @@ def build_args_parser() -> argparse.ArgumentParser:
 def _is_valid_torchao_qmode_type(value):
     if not value.startswith("torchao:"):
         return False
-
-    linear_pattern = r"lin.8da(\d+)b(\d+)gw"
-    linear_matches = re.findall(linear_pattern, value)
-    print("LINEAR MATCHES", linear_matches)
-
-    if len(linear_matches) > 1:
-        return False
-
-    embedding_pattern = r"emb.(\d+)b(\d+)gw"
-    embedding_matches = re.findall(embedding_pattern, value)
-    print("EMBEDDING MATCHES", embedding_matches)
-    if len(embedding_matches) > 1:
-        return False
-    if len(linear_matches) + len(embedding_matches) == 0:
-        return False
-    return True
+
+    patterns = [
+        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
+        r"emb.(\d+),(\d+)",
+        r"lin8da.(\d+),(\d+)",
+    ]
+    for pattern in patterns:
+        matches = re.findall(pattern, value)
+        if len(matches) == 1:
+            return True
+    return False
 
 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq"]
examples/models/llama2/source_transformation/quantize.py

Lines changed: 21 additions & 10 deletions
@@ -17,6 +17,7 @@
 from executorch.extension.llm.export.builder import DType
 
 from sentencepiece import SentencePieceProcessor
+from torch.nn.modules import linear
 
 try:
     from fairseq2.nn.embedding import (
@@ -73,33 +74,43 @@ def quantize(  # noqa: C901
     elif qmode.startswith("torchao:"):
         import os
         import glob
-        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*")))
+        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*")))
         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
         logging.info(f"Loading custom ops library: {libs[0]}")
         torch.ops.load_library(libs[0])
 
         logging.warning(
             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
         )
-        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        embedding_pattern = r"emb.(\d+),(\d+)"
+        linear_pattern = r"lin8da.(\d+),(\d+)"
+
         linear_matches = re.findall(linear_pattern, qmode)
         if linear_matches:
+            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"
             bitwidth = int(linear_matches[0][0])
-            group_size = int(linear_matches[0][1])
-            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
-
-            model = Int8DynActIntxWeightQuantizer(
+            groupsize = int(linear_matches[0][1])
+            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+            model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
                 precision=torch_dtype,
-                groupsize=group_size,
+                groupsize=groupsize,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
             ).quantize(model)
-
-        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+
         embedding_matches = re.findall(embedding_pattern, qmode)
         if embedding_matches:
-            pass  # TODO: add when embedding PR lands in torchao
+            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"
+            bitwidth = int(embedding_matches[0][0])
+            groupsize = int(embedding_matches[0][1])
+            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+            model = IntxWeightEmbeddingQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                bitwidth=bitwidth,
+                groupsize=groupsize,
+            ).quantize(model)
 
         if verbose:
             print("quantized model:", model)
