
Commit a943958

jerryzh168 authored and facebook-github-bot committed
Add int8 per token dynamic activation quant and int4 weight quant for llama2 in executorch (pytorch#1904)
Summary:
Pull Request resolved: pytorch#1904

Representation we are getting now: https://www.internalfb.com/intern/everpaste/?handle=GEIHRRnpyYOEAIUBAFtHZapvTH5xbsIXAAAB

Reviewed By: kimishpatel

Differential Revision: D53211239

fbshipit-source-id: 255f87e44079877fa70afe65fa6f0c512f06d213
1 parent 5d4d0ca commit a943958

File tree: 4 files changed, +663 −70 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 48 additions & 16 deletions
```diff
@@ -38,7 +38,12 @@
 from ...portable.utils import export_to_edge, save_pte_program
 from ..model_factory import EagerModelFactory
 from .model import ModelArgs
-from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
+from .quantize import (
+    EmbeddingOnlyInt8QuantHandler,
+    Int8DynActInt4WeightQuantHandler,
+    WeightOnlyInt8QuantHandler,
+)
+
 
 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -104,19 +109,30 @@ def apply_pt2e_quantization(
     return m
 
 
-def quantize(model) -> torch.nn.Module:
+def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
     Args:
         model: A model to quantize.
+        qmode: quantization mode, e.g. int8, int4
     Returns:
         A quantized model.
     """
-    model_int8 = WeightOnlyInt8QuantHandler(model)
-    model_int8_state_dict = model_int8.create_quantized_state_dict()
-    model_int8 = model_int8.convert_for_runtime()
-    model_int8.load_state_dict(model_int8_state_dict)
-    return model_int8
+    if qmode == "int8":
+        model_int8 = WeightOnlyInt8QuantHandler(model)
+        model_int8_state_dict = model_int8.create_quantized_state_dict()
+        model_int8 = model_int8.convert_for_runtime()
+        model_int8.load_state_dict(model_int8_state_dict)
+        return model_int8
+    elif qmode == "int4":
+        model_int4 = Int8DynActInt4WeightQuantHandler(model)
+        model_int4_state_dict = model_int4.create_quantized_state_dict()
+        model_int4 = model_int4.convert_for_runtime()
+        print("quantized model:", model_int4)
+        model_int4.load_state_dict(model_int4_state_dict)
+        return model_int4
+    else:
+        raise Exception(f"Unrecognized quantize mode: {qmode}")
 
 
 def build_model(
@@ -145,13 +161,20 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
     )
-    parser.add_argument("-Q", "--quantize", default=None, action="store_true")
     parser.add_argument("-E", "--embedding-quantize", default=None, action="store_true")
     parser.add_argument(
-        "--pt2_quantize",
+        "--pt2e_quantize",
        default=None,
        help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic, embedding.",
     )
+    parser.add_argument(
+        "-qmode",
+        "--quantization_mode",
+        type=str,
+        default=None,
+        choices=["int8", "int4"],
+        help="type of quantization",
+    )
 
     parser.add_argument(
         "-c",
@@ -181,6 +204,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-s",
         "--so_library",
+        default=None,
         required=False,
         help="shared library for quantized operators",
     )
@@ -253,12 +277,12 @@ def get_metadata(params: ModelArgs, args: argparse.Namespace) -> Dict[str, Any]:
 
 
 def _get_quantization_options(args):
-    if args.pt2_quantize is None:
+    if args.pt2e_quantize is None:
         return []
-    if args.quantize:
-        raise ValueError("Cannot specify both --quantize and --pt2_quantize")
+    if args.quantization_mode:
+        raise ValueError("Cannot specify both --quantization_mode and --pt2e_quantize")
 
-    quantization_options = args.pt2_quantize.split(",")
+    quantization_options = args.pt2e_quantize.split(",")
     quantization_options = [option.strip() for option in quantization_options]
     return quantization_options
 
@@ -312,16 +336,18 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     dim = torch.export.Dim("token_dim", max=model.params.max_seq_len - 1)
     dynamic_shapes = {"tokens": {1: dim}}
 
-    if args.quantized_ckpt or args.quantize:
+    if args.quantized_ckpt or args.quantization_mode:
         modelname = f"{modelname}_q"
-        model = quantize(model)
+        model = quantize(model, args.quantization_mode)
 
     if args.verbose:
         print(f"{modelname}:")
         print(f"{model}")
 
     if args.dtype_override is not None:
-        if args.dtype_override == "fp16" and metadata["get_dtype"] != 5:
+        if (
+            args.dtype_override == "fp16" and metadata["get_dtype"] != 5
+        ) or args.quantization_mode == "int4":
             model.to(dtype=torch.float16)
             metadata["get_dtype"] = 5
         elif args.dtype_override == "fp32" and metadata["get_dtype"] != 6:
@@ -361,6 +387,12 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         edge_manager = edge_manager.to_backend(XnnpackPartitioner())
         modelname = f"xnnpack_{modelname}"
 
+    # TODO: remove this after xnnpack delegation is ready
+    if args.quantization_mode == "int4":
+        raise Exception(
+            "some quantized ops should be lowered to xnnpack, but xnnpack delegate is not ready yet"
+        )
+
     export_program = edge_manager.to_executorch(
         ExecutorchBackendConfig(
             extract_constant_segment=True,
```
examples/models/llama2/llama_test.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ class LlamaTest(unittest.TestCase):
     def test_quantized_llama(self):
         _ = build_model(
             modelname="model",
-            extra_opts="--fairseq2 -Q",
+            extra_opts="--fairseq2 -qmode int8",
             par_local_output=True,
             resource_pkg_name=__name__,
         )
```

examples/models/llama2/model.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -174,7 +174,12 @@ def __init__(self, args: ModelArgs):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
 
-        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
+        mask = torch.full(
+            (1, 1, args.max_seq_len, args.max_seq_len),
+            float("-inf"),
+            dtype=torch.float16,
+        )
+
         mask = torch.triu(mask, diagonal=1)
         self.register_buffer("mask", mask)
 
@@ -546,6 +551,12 @@ def __init__(self, **kwargs):
 
             simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
             self.model_ = simple_quantizer.convert_for_runtime()
+        elif "int4" in str(checkpoint_path):
+            print("Using int4 weight-only quantization!")
+            from .quantize import Int8DynActInt4WeightQuantHandler
+
+            simple_quantizer = Int8DynActInt4WeightQuantHandler(self.model_)
+            self.model_ = simple_quantizer.convert_for_runtime()
 
         self.model_.load_state_dict(
             checkpoint, strict=False
```
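
The model.py change selects the int4 handler whenever the checkpoint path contains "int4", so the loaded graph contains linears whose weights were quantized offline while activation scales are computed per token at call time. The sketch below shows how such a converted linear could evaluate in a fake-quantized reference form, reusing the illustrative group-wise layout from the earlier sketch; it is an assumption-based illustration, not the code the handler emits.

```python
import torch


def int8da_int4w_linear(x, w_int4, w_scales, group_size: int = 32):
    """Reference forward: dequantize int4 weights, dynamically quantize/dequantize
    activations per token, then do a float matmul (illustrative only)."""
    out_features, in_features = w_int4.shape
    # Dequantize the group-wise int4 weights back to float.
    w_deq = (
        w_int4.reshape(out_features, in_features // group_size, group_size).float()
        * w_scales.unsqueeze(-1)
    ).reshape(out_features, in_features)
    # Per-token dynamic int8 quantize/dequantize of the activations.
    x_scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    x_deq = torch.clamp(torch.round(x / x_scales), -128, 127) * x_scales
    return x_deq @ w_deq.t()
```

A real kernel would keep the int8/int4 tensors and accumulate in integer arithmetic; the float round trip here only mirrors the numerics.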
