|  | 
| 17 | 17 | from executorch.extension.llm.export.builder import DType | 
| 18 | 18 | 
 | 
| 19 | 19 | from sentencepiece import SentencePieceProcessor | 
|  | 20 | +from torch.nn.modules import linear | 
| 20 | 21 | 
 | 
| 21 | 22 | try: | 
| 22 | 23 |     from fairseq2.nn.embedding import ( | 
| @@ -73,33 +74,43 @@ def quantize(  # noqa: C901 | 
| 73 | 74 |     elif qmode.startswith("torchao:"): | 
| 74 | 75 |         import os | 
| 75 | 76 |         import glob | 
| 76 |  | -        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*"))) | 
|  | 77 | +        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*"))) | 
| 77 | 78 |         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" | 
| 78 | 79 |         logging.info(f"Loading custom ops library: {libs[0]}") | 
| 79 | 80 |         torch.ops.load_library(libs[0]) | 
| 80 | 81 | 
 | 
| 81 | 82 |         logging.warning( | 
| 82 | 83 |             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored." | 
| 83 | 84 |         ) | 
| 84 |  | -        linear_pattern = r"lin.8da(\d+)b(\d+)gw" | 
|  | 85 | +        embedding_pattern = r"emb.(\d+),(\d+)" | 
|  | 86 | +        linear_pattern = r"lin8da.(\d+),(\d+)" | 
|  | 87 | + | 
| 85 | 88 |         linear_matches = re.findall(linear_pattern, qmode) | 
| 86 | 89 |         if linear_matches: | 
|  | 90 | +            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}" | 
| 87 | 91 |             bitwidth = int(linear_matches[0][0]) | 
| 88 |  | -            group_size = int(linear_matches[0][1]) | 
| 89 |  | -            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer | 
| 90 |  | - | 
| 91 |  | -            model = Int8DynActIntxWeightQuantizer( | 
|  | 92 | +            groupsize = int(linear_matches[0][1]) | 
|  | 93 | +            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer | 
|  | 94 | +            model = Int8DynActIntxWeightLinearQuantizer( | 
| 92 | 95 |                 device="cpu", | 
| 93 | 96 |                 precision=torch_dtype, | 
| 94 |  | -                groupsize=group_size, | 
|  | 97 | +                groupsize=groupsize, | 
| 95 | 98 |                 bitwidth=bitwidth, | 
| 96 | 99 |                 has_weight_zeros=False, | 
| 97 | 100 |             ).quantize(model) | 
| 98 |  | - | 
| 99 |  | -        embedding_pattern = r"emb.(\d+)b(\d+)gw" | 
|  | 101 | +         | 
| 100 | 102 |         embedding_matches = re.findall(embedding_pattern, qmode) | 
| 101 | 103 |         if embedding_matches: | 
| 102 |  | -            pass  # TODO: add when embedding PR lands in torchao | 
|  | 104 | +            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}" | 
|  | 105 | +            bitwidth = int(embedding_matches[0][0]) | 
|  | 106 | +            groupsize = int(embedding_matches[0][1]) | 
|  | 107 | +            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer | 
|  | 108 | +            model = IntxWeightEmbeddingQuantizer( | 
|  | 109 | +                device="cpu", | 
|  | 110 | +                precision=torch_dtype, | 
|  | 111 | +                bitwidth=bitwidth, | 
|  | 112 | +                groupsize=groupsize, | 
|  | 113 | +            ).quantize(model) | 
| 103 | 114 | 
 | 
| 104 | 115 |         if verbose: | 
| 105 | 116 |             print("quantized model:", model) | 
|  | 
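For context on the new qmode parsing, here is a minimal, self-contained sketch of how the two regex patterns introduced in this diff extract a `(bitwidth, groupsize)` pair each. The example qmode string below is hypothetical (the `&`-joined form is my assumption, not something this PR defines); it is shaped only so that both patterns find exactly one match, matching the asserts in the diff.

```python
import re

# Hypothetical qmode string; the real values are whatever the export CLI passes in.
qmode = "torchao:emb.4,32&lin8da.4,128"

embedding_pattern = r"emb.(\d+),(\d+)"   # feeds IntxWeightEmbeddingQuantizer
linear_pattern = r"lin8da.(\d+),(\d+)"   # feeds Int8DynActIntxWeightLinearQuantizer

# Each pattern captures (bitwidth, groupsize) as a tuple of digit strings;
# the diff asserts exactly one match per pattern.
emb_matches = re.findall(embedding_pattern, qmode)
lin_matches = re.findall(linear_pattern, qmode)

if emb_matches:
    emb_bitwidth, emb_groupsize = (int(v) for v in emb_matches[0])
    print(f"embedding weights: {emb_bitwidth}-bit, groupsize {emb_groupsize}")

if lin_matches:
    lin_bitwidth, lin_groupsize = (int(v) for v in lin_matches[0])
    print(f"linear weights: {lin_bitwidth}-bit with int8 dynamic activations, groupsize {lin_groupsize}")
```

The parsed bitwidth and groupsize are exactly the values the changed code passes to `IntxWeightEmbeddingQuantizer` and `Int8DynActIntxWeightLinearQuantizer` before calling `.quantize(model)`.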