 from executorch.extension.llm.export.builder import DType

 from sentencepiece import SentencePieceProcessor
+from torch.nn.modules import linear

 try:
     from fairseq2.nn.embedding import (
@@ -75,33 +76,43 @@ def quantize( # noqa C901
     elif qmode.startswith("torchao:"):
         import os
         import glob
-        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*")))
+        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*")))
         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"
         logging.info(f"Loading custom ops library: {libs[0]}")
         torch.ops.load_library(libs[0])

         logging.warning(
             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
         )
-        linear_pattern = r"lin.8da(\d+)b(\d+)gw"
+        embedding_pattern = r"emb.(\d+),(\d+)"
+        linear_pattern = r"lin8da.(\d+),(\d+)"
+
         linear_matches = re.findall(linear_pattern, qmode)
         if linear_matches:
+            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"
             bitwidth = int(linear_matches[0][0])
-            group_size = int(linear_matches[0][1])
-            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer
-
-            model = Int8DynActIntxWeightQuantizer(
+            groupsize = int(linear_matches[0][1])
+            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+            model = Int8DynActIntxWeightLinearQuantizer(
                 device="cpu",
                 precision=torch_dtype,
-                groupsize=group_size,
+                groupsize=groupsize,
                 bitwidth=bitwidth,
                 has_weight_zeros=False,
             ).quantize(model)
-
-        embedding_pattern = r"emb.(\d+)b(\d+)gw"
+
         embedding_matches = re.findall(embedding_pattern, qmode)
         if embedding_matches:
-            pass  # TODO: add when embedding PR lands in torchao
+            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"
+            bitwidth = int(embedding_matches[0][0])
+            groupsize = int(embedding_matches[0][1])
+            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+            model = IntxWeightEmbeddingQuantizer(
+                device="cpu",
+                precision=torch_dtype,
+                bitwidth=bitwidth,
+                groupsize=groupsize,
+            ).quantize(model)

     if verbose:
         print("quantized model:", model)
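For reference, a minimal sketch of how the two new qmode patterns pull out bitwidth and groupsize. The combined qmode string below is hypothetical (the diff does not show the full qmode format that export_llama accepts); only the two regexes are taken from the change:

```python
import re

# Hypothetical qmode value; only the two patterns below come from the diff above.
qmode = "torchao:emb.4,32&lin8da.3,128"

embedding_pattern = r"emb.(\d+),(\d+)"
linear_pattern = r"lin8da.(\d+),(\d+)"

print(re.findall(linear_pattern, qmode))     # [('3', '128')] -> linear bitwidth=3, groupsize=128
print(re.findall(embedding_pattern, qmode))  # [('4', '32')]  -> embedding bitwidth=4, groupsize=32
```

Each match then feeds the corresponding quantizer (Int8DynActIntxWeightLinearQuantizer for linear layers, IntxWeightEmbeddingQuantizer for embeddings), with the new asserts guarding against more than one match per qmode string.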