|
38 | 38 | from ...portable.utils import export_to_edge, save_pte_program |
39 | 39 | from ..model_factory import EagerModelFactory |
40 | 40 | from .model import ModelArgs |
41 | | -from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler |
| 41 | +from .quantize import ( |
| 42 | + EmbeddingOnlyInt8QuantHandler, |
| 43 | + Int8DynActInt4WeightQuantHandler, |
| 44 | + WeightOnlyInt8QuantHandler, |
| 45 | +) |
| 46 | + |
42 | 47 |
|
43 | 48 | IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) |
44 | 49 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" |
@@ -104,19 +109,30 @@ def apply_pt2e_quantization( |
104 | 109 | return m |
105 | 110 |
|
106 | 111 |
|
107 | | -def quantize(model) -> torch.nn.Module: |
| 112 | +def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module: |
108 | 113 | """ |
109 | | -Quantizes a model by converting all weights to int8.
| 114 | +Quantizes a model by converting all weights to int8 or int4, depending on qmode.
110 | 115 | Args: |
111 | 116 | model: A model to quantize. |
| 117 | + qmode: quantization mode; one of "int8" or "int4"
112 | 118 | Returns: |
113 | 119 | A quantized model. |
114 | 120 | """ |
115 | | - model_int8 = WeightOnlyInt8QuantHandler(model) |
116 | | - model_int8_state_dict = model_int8.create_quantized_state_dict() |
117 | | - model_int8 = model_int8.convert_for_runtime() |
118 | | - model_int8.load_state_dict(model_int8_state_dict) |
119 | | - return model_int8 |
| 121 | + if qmode == "int8": |
| 122 | + model_int8 = WeightOnlyInt8QuantHandler(model) |
| 123 | + model_int8_state_dict = model_int8.create_quantized_state_dict() |
| 124 | + model_int8 = model_int8.convert_for_runtime() |
| 125 | + model_int8.load_state_dict(model_int8_state_dict) |
| 126 | + return model_int8 |
| 127 | + elif qmode == "int4": |
| 128 | + model_int4 = Int8DynActInt4WeightQuantHandler(model) |
| 129 | + model_int4_state_dict = model_int4.create_quantized_state_dict() |
| 130 | + model_int4 = model_int4.convert_for_runtime() |
| 131 | + print("quantized model:", model_int4) |
| 132 | + model_int4.load_state_dict(model_int4_state_dict) |
| 133 | + return model_int4 |
| 134 | + else: |
| 135 | + raise ValueError(f"Unrecognized quantize mode: {qmode}")
120 | 136 |
|
121 | 137 |
|
122 | 138 | def build_model( |
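
For orientation, both branches of the new quantize() run the same three-step handler flow: eagerly build a quantized state dict, swap the model's modules for quantized counterparts via convert_for_runtime(), then load the state dict. A minimal table-driven sketch of that dispatch, using only the two handler classes imported in this diff; the helper name and the table itself are illustrative, not part of the change:

    import torch

    from .quantize import Int8DynActInt4WeightQuantHandler, WeightOnlyInt8QuantHandler

    # Mode-to-handler table mirroring the if/elif chain in quantize() above.
    _QUANT_HANDLERS = {
        "int8": WeightOnlyInt8QuantHandler,        # weight-only int8
        "int4": Int8DynActInt4WeightQuantHandler,  # 8-bit dynamic activations, 4-bit weights
    }

    def quantize_sketch(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
        try:
            handler = _QUANT_HANDLERS[qmode](model)
        except KeyError:
            raise ValueError(f"Unrecognized quantize mode: {qmode}") from None
        state_dict = handler.create_quantized_state_dict()  # 1. quantize the weights
        quantized = handler.convert_for_runtime()           # 2. swap in quantized modules
        quantized.load_state_dict(state_dict)               # 3. load the quantized weights
        return quantized

Keeping the dispatch in a table would make adding a future mode a one-line change, though the explicit if/elif in the diff is equivalent.
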
@@ -145,13 +161,20 @@ def build_args_parser() -> argparse.ArgumentParser: |
145 | 161 | parser.add_argument( |
146 | 162 | "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file" |
147 | 163 | ) |
148 | | - parser.add_argument("-Q", "--quantize", default=None, action="store_true") |
149 | 164 | parser.add_argument("-E", "--embedding-quantize", default=None, action="store_true") |
150 | 165 | parser.add_argument( |
151 | | - "--pt2_quantize", |
| 166 | + "--pt2e_quantize", |
152 | 167 | default=None, |
153 | 168 | help="Use PT2E quantization. Comma-separated options, e.g. xnnpack_dynamic, embedding.",
154 | 169 | ) |
| 170 | + parser.add_argument( |
| 171 | + "-qmode", |
| 172 | + "--quantization_mode", |
| 173 | + type=str, |
| 174 | + default=None, |
| 175 | + choices=["int8", "int4"], |
| 176 | + help="Eager-mode quantization to apply: int8 (weight-only) or int4 (8-bit dynamic activations, 4-bit weights)",
| 177 | + ) |
155 | 178 |
|
156 | 179 | parser.add_argument( |
157 | 180 | "-c", |
@@ -181,6 +204,7 @@ def build_args_parser() -> argparse.ArgumentParser: |
181 | 204 | parser.add_argument( |
182 | 205 | "-s", |
183 | 206 | "--so_library", |
| 207 | + default=None, |
184 | 208 | required=False, |
185 | 209 | help="shared library for quantized operators", |
186 | 210 | ) |
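
Put together, the flag changes above read as follows. This is a hypothetical parse: the checkpoint path is a placeholder, "-c" is assumed to take a value, and build_args_parser() is assumed to behave as defined above:

    parser = build_args_parser()

    # Eager quantization via the new -qmode/--quantization_mode flag:
    args = parser.parse_args(["-c", "model.pth", "-qmode", "int4"])
    assert args.quantization_mode == "int4" and args.pt2e_quantize is None

    # PT2E quantization instead (flag renamed from --pt2_quantize):
    args = parser.parse_args(["-c", "model.pth", "--pt2e_quantize", "xnnpack_dynamic,embedding"])
    assert args.quantization_mode is None
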
@@ -253,12 +277,12 @@ def get_metadata(params: ModelArgs, args: argparse.Namespace) -> Dict[str, Any]: |
253 | 277 |
|
254 | 278 |
|
255 | 279 | def _get_quantization_options(args): |
256 | | - if args.pt2_quantize is None: |
| 280 | + if args.pt2e_quantize is None: |
257 | 281 | return [] |
258 | | - if args.quantize: |
259 | | - raise ValueError("Cannot specify both --quantize and --pt2_quantize") |
| 282 | + if args.quantization_mode: |
| 283 | + raise ValueError("Cannot specify both --quantization_mode and --pt2e_quantize") |
260 | 284 |
|
261 | | - quantization_options = args.pt2_quantize.split(",") |
| 285 | + quantization_options = args.pt2e_quantize.split(",") |
262 | 286 | quantization_options = [option.strip() for option in quantization_options] |
263 | 287 | return quantization_options |
264 | 288 |
|
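
To illustrate the guard above, hand-built namespaces carrying only the two attributes the function reads behave like this (the option strings are the examples from the --pt2e_quantize help text):

    import argparse

    # Only --pt2e_quantize set: the comma-separated options are split and stripped.
    args = argparse.Namespace(pt2e_quantize="xnnpack_dynamic, embedding", quantization_mode=None)
    print(_get_quantization_options(args))  # ['xnnpack_dynamic', 'embedding']

    # --pt2e_quantize unset: no PT2E options, regardless of --quantization_mode.
    args = argparse.Namespace(pt2e_quantize=None, quantization_mode="int8")
    print(_get_quantization_options(args))  # []

    # Both set: rejected, since eager and PT2E quantization are mutually exclusive.
    args = argparse.Namespace(pt2e_quantize="embedding", quantization_mode="int8")
    # _get_quantization_options(args)  # raises ValueError
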
@@ -312,16 +336,18 @@ def _export_llama(modelname, args) -> str: # noqa: C901 |
312 | 336 | dim = torch.export.Dim("token_dim", max=model.params.max_seq_len - 1) |
313 | 337 | dynamic_shapes = {"tokens": {1: dim}} |
314 | 338 |
|
315 | | - if args.quantized_ckpt or args.quantize: |
| 339 | + if args.quantized_ckpt or args.quantization_mode: |
316 | 340 | modelname = f"{modelname}_q" |
317 | | - model = quantize(model) |
| 341 | + model = quantize(model, args.quantization_mode) |
318 | 342 |
|
319 | 343 | if args.verbose: |
320 | 344 | print(f"{modelname}:") |
321 | 345 | print(f"{model}") |
322 | 346 |
|
323 | 347 | if args.dtype_override is not None: |
324 | | - if args.dtype_override == "fp16" and metadata["get_dtype"] != 5: |
| 348 | + if ( |
| 349 | + args.dtype_override == "fp16" and metadata["get_dtype"] != 5 |
| 350 | + ) or args.quantization_mode == "int4": |
325 | 351 | model.to(dtype=torch.float16) |
326 | 352 | metadata["get_dtype"] = 5 |
327 | 353 | elif args.dtype_override == "fp32" and metadata["get_dtype"] != 6: |
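
The widened fp16 condition above means that, whenever a dtype override is supplied, an int4-quantized model is cast to half precision regardless of which override was requested (the metadata codes follow the branches: 5 for fp16, 6 for fp32). A condensed restatement of the branch as a standalone function; the name and the fp32 body are illustrative, inferred from the surrounding lines:

    import torch

    def apply_dtype_override(model, metadata, dtype_override, qmode):
        # No-op unless an override was requested, exactly as in the diff above.
        if dtype_override is None:
            return model
        if (dtype_override == "fp16" and metadata["get_dtype"] != 5) or qmode == "int4":
            model.to(dtype=torch.float16)  # int4 weights run with fp16 activations
            metadata["get_dtype"] = 5
        elif dtype_override == "fp32" and metadata["get_dtype"] != 6:
            model.to(dtype=torch.float32)
            metadata["get_dtype"] = 6
        return model
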
@@ -361,6 +387,12 @@ def _export_llama(modelname, args) -> str: # noqa: C901 |
361 | 387 | edge_manager = edge_manager.to_backend(XnnpackPartitioner()) |
362 | 388 | modelname = f"xnnpack_{modelname}" |
363 | 389 |
|
| 390 | + # TODO: remove this after xnnpack delegation is ready |
| 391 | + if args.quantization_mode == "int4": |
| 392 | + raise NotImplementedError(
| 393 | + "int4 quantized ops should be lowered to XNNPACK, but the XNNPACK delegate is not ready yet"
| 394 | + )
| 395 | + |
364 | 396 | export_program = edge_manager.to_executorch( |
365 | 397 | ExecutorchBackendConfig( |
366 | 398 | extract_constant_segment=True, |
|
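
For int8, the mode that survives the guard above, the remaining flow serializes the lowered program. A hedged continuation sketch: the ExecutorchBackendConfig argument comes from the visible lines, while the .buffer attribute and the save_pte_program signature are assumptions based on the import at the top of this file:

    # Continuation sketch, not part of the diff; signatures are assumptions.
    export_program = edge_manager.to_executorch(
        ExecutorchBackendConfig(extract_constant_segment=True)
    )
    save_pte_program(export_program.buffer, modelname)  # writes <modelname>.pte
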