|  | 
| 19 | 19 | from executorch.extension.llm.export.builder import DType | 
| 20 | 20 | 
 | 
| 21 | 21 | from sentencepiece import SentencePieceProcessor | 
|  | 22 | +from torch.nn.modules import linear | 
| 22 | 23 | 
 | 
| 23 | 24 | try: | 
| 24 | 25 |     from fairseq2.nn.embedding import ( | 
| @@ -75,33 +76,43 @@ def quantize(  # noqa C901 | 
| 75 | 76 |     elif qmode.startswith("torchao:"): | 
| 76 | 77 |         import os | 
| 77 | 78 |         import glob | 
| 78 |  | -        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*"))) | 
|  | 79 | +        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*"))) | 
| 79 | 80 |         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}" | 
| 80 | 81 |         logging.info(f"Loading custom ops library: {libs[0]}") | 
| 81 | 82 |         torch.ops.load_library(libs[0]) | 
| 82 | 83 | 
 | 
| 83 | 84 |         logging.warning( | 
| 84 | 85 |             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored." | 
| 85 | 86 |         ) | 
| 86 |  | -        linear_pattern = r"lin.8da(\d+)b(\d+)gw" | 
|  | 87 | +        embedding_pattern = r"emb.(\d+),(\d+)" | 
|  | 88 | +        linear_pattern = r"lin8da.(\d+),(\d+)" | 
|  | 89 | + | 
| 87 | 90 |         linear_matches = re.findall(linear_pattern, qmode) | 
| 88 | 91 |         if linear_matches: | 
|  | 92 | +            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}" | 
| 89 | 93 |             bitwidth = int(linear_matches[0][0]) | 
| 90 |  | -            group_size = int(linear_matches[0][1]) | 
| 91 |  | -            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer | 
| 92 |  | - | 
| 93 |  | -            model = Int8DynActIntxWeightQuantizer( | 
|  | 94 | +            groupsize = int(linear_matches[0][1]) | 
|  | 95 | +            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer | 
|  | 96 | +            model = Int8DynActIntxWeightLinearQuantizer( | 
| 94 | 97 |                 device="cpu", | 
| 95 | 98 |                 precision=torch_dtype, | 
| 96 |  | -                groupsize=group_size, | 
|  | 99 | +                groupsize=groupsize, | 
| 97 | 100 |                 bitwidth=bitwidth, | 
| 98 | 101 |                 has_weight_zeros=False, | 
| 99 | 102 |             ).quantize(model) | 
| 100 |  | - | 
| 101 |  | -        embedding_pattern = r"emb.(\d+)b(\d+)gw" | 
|  | 103 | + | 
| 102 | 104 |         embedding_matches = re.findall(embedding_pattern, qmode) | 
| 103 | 105 |         if embedding_matches: | 
| 104 |  | -            pass  # TODO: add when embedding PR lands in torchao | 
|  | 106 | +            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}" | 
|  | 107 | +            bitwidth = int(embedding_matches[0][0]) | 
|  | 108 | +            groupsize = int(embedding_matches[0][1]) | 
|  | 109 | +            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer | 
|  | 110 | +            model = IntxWeightEmbeddingQuantizer( | 
|  | 111 | +                device="cpu", | 
|  | 112 | +                precision=torch_dtype, | 
|  | 113 | +                bitwidth=bitwidth, | 
|  | 114 | +                groupsize=groupsize, | 
|  | 115 | +            ).quantize(model) | 
| 105 | 116 | 
 | 
| 106 | 117 |         if verbose: | 
| 107 | 118 |             print("quantized model:", model) | 
|  | 
0 commit comments