| 
19 | 19 | from executorch.extension.llm.export.builder import DType  | 
20 | 20 | 
 
  | 
21 | 21 | from sentencepiece import SentencePieceProcessor  | 
 | 22 | +from torch.nn.modules import linear  | 
22 | 23 | 
 
  | 
23 | 24 | try:  | 
24 | 25 |     from fairseq2.nn.embedding import (  | 
@@ -75,33 +76,43 @@ def quantize(  # noqa C901  | 
75 | 76 |     elif qmode.startswith("torchao:"):  | 
76 | 77 |         import os  | 
77 | 78 |         import glob  | 
78 |  | -        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/lib/libtorchao_ops_aten.*")))  | 
 | 79 | +        libs = glob.glob(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../cmake-out/third-party/ao/torchao/experimental/libtorchao_ops_aten.*")))  | 
79 | 80 |         assert len(libs) == 1, f"Expected 1 library but got {len(libs)}"  | 
80 | 81 |         logging.info(f"Loading custom ops library: {libs[0]}")  | 
81 | 82 |         torch.ops.load_library(libs[0])  | 
82 | 83 | 
 
  | 
83 | 84 |         logging.warning(  | 
84 | 85 |             "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."  | 
85 | 86 |         )  | 
86 |  | -        linear_pattern = r"lin.8da(\d+)b(\d+)gw"  | 
 | 87 | +        embedding_pattern = r"emb.(\d+),(\d+)"  | 
 | 88 | +        linear_pattern = r"lin8da.(\d+),(\d+)"  | 
 | 89 | + | 
87 | 90 |         linear_matches = re.findall(linear_pattern, qmode)  | 
88 | 91 |         if linear_matches:  | 
 | 92 | +            assert len(linear_matches) == 1, f"Expected 1 match but got {len(linear_matches)}"  | 
89 | 93 |             bitwidth = int(linear_matches[0][0])  | 
90 |  | -            group_size = int(linear_matches[0][1])  | 
91 |  | -            from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer  | 
92 |  | - | 
93 |  | -            model = Int8DynActIntxWeightQuantizer(  | 
 | 94 | +            groupsize = int(linear_matches[0][1])  | 
 | 95 | +            from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer  | 
 | 96 | +            model = Int8DynActIntxWeightLinearQuantizer(  | 
94 | 97 |                 device="cpu",  | 
95 | 98 |                 precision=torch_dtype,  | 
96 |  | -                groupsize=group_size,  | 
 | 99 | +                groupsize=groupsize,  | 
97 | 100 |                 bitwidth=bitwidth,  | 
98 | 101 |                 has_weight_zeros=False,  | 
99 | 102 |             ).quantize(model)  | 
100 |  | - | 
101 |  | -        embedding_pattern = r"emb.(\d+)b(\d+)gw"  | 
 | 103 | +          | 
102 | 104 |         embedding_matches = re.findall(embedding_pattern, qmode)  | 
103 | 105 |         if embedding_matches:  | 
104 |  | -            pass  # TODO: add when embedding PR lands in torchao  | 
 | 106 | +            assert len(embedding_matches) == 1, f"Expected 1 match but got {len(embedding_matches)}"  | 
 | 107 | +            bitwidth = int(embedding_matches[0][0])  | 
 | 108 | +            groupsize = int(embedding_matches[0][1])  | 
 | 109 | +            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer  | 
 | 110 | +            model = IntxWeightEmbeddingQuantizer(  | 
 | 111 | +                device="cpu",  | 
 | 112 | +                precision=torch_dtype,  | 
 | 113 | +                bitwidth=bitwidth,  | 
 | 114 | +                groupsize=groupsize,  | 
 | 115 | +            ).quantize(model)  | 
105 | 116 | 
 
  | 
106 | 117 |         if verbose:  | 
107 | 118 |             print("quantized model:", model)  | 
 | 
0 commit comments