
Commit d4f6cd2

up
1 parent a24487d commit d4f6cd2

2 files changed: +32 -26 lines

examples/models/llama/export_llama_lib.py

Lines changed: 11 additions & 16 deletions
@@ -157,22 +157,17 @@ def build_args_parser() -> argparse.ArgumentParser:
 def _is_valid_torchao_qmode_type(value):
     if not value.startswith("torchao:"):
         return False
-
-    linear_pattern = r"lin.8da(\d+)b(\d+)gw"
-    linear_matches = re.findall(linear_pattern, value)
-    print("LINEAR MATCHES", linear_matches)
-
-    if len(linear_matches) > 1:
-        return False
-
-    embedding_pattern = r"emb.(\d+)b(\d+)gw"
-    embedding_matches = re.findall(embedding_pattern, value)
-    print("EMBEDDING MATCHES", embedding_matches)
-    if len(embedding_matches) > 1:
-        return False
-    if len(linear_matches) + len(embedding_matches) == 0:
-        return False
-    return True
+
+    patterns = [
+        r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",
+        r"emb.(\d+),(\d+)",
+        r"lin8da.(\d+),(\d+)",
+    ]
+    for pattern in patterns:
+        matches = re.findall(pattern, value)
+        if len(matches) == 1:
+            return True
+    return False
 
 def _qmode_type(value):
     choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
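
Note: the following is a minimal usage sketch, not part of the commit, showing how the rewritten validator treats the new qmode format. The qmode strings are illustrative; per the quantize.py changes below, the first captured number after each prefix is read as the bitwidth and the second as the groupsize.

import re

# The three accepted shapes, checked in order; exactly one match is required.
patterns = [
    r"emb.(\d+),(\d+)&lin8da.(\d+),(\d+)",  # embedding + linear settings
    r"emb.(\d+),(\d+)",                     # embedding-only
    r"lin8da.(\d+),(\d+)",                  # linear-only (8-bit dynamic activations)
]

# Illustrative qmode strings (not taken from the commit).
for qmode in [
    "torchao:emb.4,32&lin8da.4,128",
    "torchao:lin8da.4,256",
    "torchao:8da4w",
]:
    ok = qmode.startswith("torchao:") and any(
        len(re.findall(p, qmode)) == 1 for p in patterns
    )
    print(qmode, "->", "valid" if ok else "invalid")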

examples/models/llama/source_transformation/quantize.py

Lines changed: 21 additions & 10 deletions
@@ -19,6 +19,7 @@
 from executorch.extension.llm.export.builder import DType
 
 from sentencepiece import SentencePieceProcessor
+from torch.nn.modules import linear
 
 try:
     from fairseq2.nn.embedding import (
@@ -75,33 +76,43 @@ def quantize( # noqa C901
     elif qmode.startswith("torchao:"):
         import os
         import glob
-        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*")))
+        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*")))
         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
         logging.info(f"Loading custom ops library: {libs[0]}")
         torch.ops.load_library(libs[0])
 
         logging.warning(
             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
         )
-        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        embedding_pattern = r"emb.(\d+),(\d+)"
+        linear_pattern = r"lin8da.(\d+),(\d+)"
+
         linear_matches = re.findall(linear_pattern, qmode)
         if linear_matches:
+            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"
             bitwidth = int(linear_matches[0][0])
-            group_size = int(linear_matches[0][1])
-            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
-
-            model = Int8DynActIntxWeightQuantizer(
+            groupsize = int(linear_matches[0][1])
+            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+            model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
                 precision=torch_dtype,
-                groupsize=group_size,
+                groupsize=groupsize,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
             ).quantize(model)
-
-        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+
         embedding_matches = re.findall(embedding_pattern, qmode)
         if embedding_matches:
-            pass # TODO: add when embedding PR lands in torchao
+            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"
+            bitwidth = int(embedding_matches[0][0])
+            groupsize = int(embedding_matches[0][1])
+            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+            model = IntxWeightEmbeddingQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                bitwidth=bitwidth,
+                groupsize=groupsize,
+            ).quantize(model)
 
     if verbose:
         print("quantized model:", model)
