
Commit 6af6a07

chenfucn authored and shamaksx committed
Adding quantization example for gpt-2 medium (microsoft#140)
add gpt2 qdq example
1 parent 0504e80 commit 6af6a07

File tree

7 files changed: +415 −3 lines

quantization/image_classification/cpu/ReadMe.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -13,7 +13,7 @@ Pre-processing prepares a float32 model for quantization. Run the following comm
 model `mobilenetv2-7.onnx`.

 ```console
-python -m onnxruntime.quantization.shape_inference --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
+python -m onnxruntime.quantization.preprocess --input mobilenetv2-7.onnx --output mobilenetv2-7-infer.onnx
 ```

 The pre-processing consists of the following optional steps
@@ -30,7 +30,7 @@ merged Convolution + BatchNormalization node.
 It is highly recommended to run model optimization in pre-processing instead of in quantization.
 To learn more about each of these steps and finer controls, run:
 ```console
-python -m onnxruntime.quantization.shape_inference --help
+python -m onnxruntime.quantization.preprocess --help
 ```

 ## Quantization
@@ -76,7 +76,7 @@ For instance, you have a model `abc_float32_model.onnx`, and a quantized model
 by default. You can run the following code to produce an optimized float32 model:

 ```console
-python -m onnxruntime.quantization.shape_inference --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
+python -m onnxruntime.quantization.preprocess --input abc_float32_model.onnx --output abc_optimized.onnx --skip_symbolic_shape True
 ```

 Then run the debugger comparing `abc_optimized.onnx` with `abc_quantized.onnx`.
````
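For reference, the debugger mentioned above lives in `onnxruntime.quantization.qdq_loss_debug`; below is a minimal sketch of its weight-matching flow, assuming an onnxruntime version (1.12 or later) that ships this module and using the hypothetical model names from the README excerpt:

```python
# Minimal sketch, not part of this commit: compare weight tensors between a
# float model and its QDQ-quantized counterpart. Model paths are placeholders.
from onnxruntime.quantization.qdq_loss_debug import (
    compute_weight_error,
    create_weight_matching,
)

float_model_path = "abc_optimized.onnx"  # hypothetical optimized float model
qdq_model_path = "abc_quantized.onnx"    # hypothetical QDQ-quantized model

# Pair up each float weight with its dequantized counterpart, then report a
# per-tensor error metric.
matched_weights = create_weight_matching(float_model_path, qdq_model_path)
weights_error = compute_weight_error(matched_weights)
for weight_name, err in weights_error.items():
    print(f"Cross-model error of '{weight_name}': {err}")
```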
Lines changed: 47 additions & 0 deletions (new file)

````markdown
# GPT-2-medium Quantization Example

This folder contains example code for quantizing the GPT2-medium model. It is by and large similar to
[this example](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization/image_classification/cpu).

## Obtaining the 32-bit floating point model

ONNX Runtime provides tools for converting GPT2 models to ONNX. Run:

```console
python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2-medium --output gpt2_medium_fp32.onnx -o -p fp32
```

## Preparing the floating point model for quantization

Here we pre-process the model, essentially running shape inference and model optimization, both of
which may improve the performance of quantization.

```console
python -m onnxruntime.quantization.preprocess --input gpt2_medium_fp32.onnx --output gpt2_medium_fp32_preprocessed.onnx
```

## Quantize

We use static quantization here, for which a calibration data set is required. You can run
`generate_inputs.py` to generate random dummy inputs for GPT-2 medium. See the Python source
code for finer control options.

With the calibration data set in place, run the following command to invoke the quantization tool, which
will run the model with the provided data set, compute quantization parameters for each
weight and activation tensor, and output the quantized model:

```console
python run_qdq.py --input_model gpt2_medium_fp32_preprocessed.onnx --output_model gpt2_medium_quant.onnx --calibrate_dataset ./test_input
```

## Quantization Debugging

The Python file `run_qdq_debug.py` showcases how to use our quantization debugging API to match up
corresponding weight/activation tensors between the floating point and quantized models. Run:

```console
python run_qdq_debug.py --float_model gpt2_medium_fp32_preprocessed.onnx --qdq_model gpt2_medium_quant.onnx --calibrate_dataset ./test_input
```
````
Lines changed: 110 additions & 0 deletions (new file)

```python
import random
import torch
from transformers import AutoTokenizer
from typing import Sequence, Tuple

EXAMPLE_Text = ["best hotel in bay area", "here is an example of gpt2 model"]


def get_tokenizer(model_name_or_path: str, cache_dir: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_example_inputs(
    model_name_or_path: str,
    cache_dir: str,
    num_attention_heads: int,
    num_layer: int,
    hidden_size: int,
    device: str,
    prompt_text: Sequence[str] = EXAMPLE_Text,
):
    tokenizer = get_tokenizer(model_name_or_path, cache_dir)
    encodings_dict = tokenizer.batch_encode_plus(prompt_text, padding=True)

    input_ids = torch.tensor(encodings_dict["input_ids"], dtype=torch.int32)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], dtype=torch.int32)
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(position_ids < 0, 0)
    position_ids = position_ids.to(torch.int32)

    # Empty past state for generating the first word
    empty_past = []
    batch_size = input_ids.size(0)
    sequence_length = input_ids.size(1)
    past_shape = [
        2,
        batch_size,
        num_attention_heads,
        0,
        hidden_size // num_attention_heads,
    ]
    for i in range(num_layer):
        empty_past.append(torch.empty(past_shape).type(torch.float32).to(device))

    return input_ids, attention_mask, position_ids, empty_past


def get_dummy_inputs(
    batch_size: int,
    past_sequence_length: int,
    sequence_length: int,
    num_attention_heads: int,
    hidden_size: int,
    num_layer: int,
    vocab_size: int,
    device: torch.device,
    has_position_ids: bool = True,
    has_attention_mask: bool = True,
    input_ids_dtype: torch.dtype = torch.int64,
    position_ids_dtype: torch.dtype = torch.int64,
    attention_mask_dtype: torch.dtype = torch.int64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random inputs for the GPT2 model.

    Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
    """
    past_shape = [
        2,
        batch_size,
        num_attention_heads,
        past_sequence_length,
        int(hidden_size / num_attention_heads),
    ]

    past = [
        (torch.rand(past_shape, dtype=torch.float32, device=device) * 2.0 - 1.0)
        for _ in range(num_layer)
    ]
    input_ids = torch.randint(
        low=0,
        high=vocab_size - 1,
        size=(batch_size, sequence_length),
        dtype=input_ids_dtype,
        device=device,
    )

    attention_mask = None
    if has_attention_mask:
        total_sequence_length = past_sequence_length + sequence_length
        attention_mask = torch.ones(
            [batch_size, total_sequence_length],
            dtype=attention_mask_dtype,
            device=device,
        )
        if total_sequence_length >= 2:
            padding_position = random.randint(
                0, total_sequence_length - 1
            )  # test input with padding
            attention_mask[:, padding_position] = 0

    # Deduce position_ids from the attention mask
    position_ids = None
    if has_position_ids:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(position_ids < 0, 0)
        position_ids = position_ids[:, past_sequence_length:].to(position_ids_dtype)

    return (input_ids, attention_mask, position_ids, past)
```
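`get_example_inputs` is not exercised by the other files in this commit. A minimal usage sketch, assuming the GPT-2-medium hyperparameters that `generate_inputs.py` below uses (16 heads, 24 layers, hidden size 1024) and a hypothetical `./cache_models` tokenizer cache directory:

```python
# Usage sketch (not part of this commit). Downloads the gpt2-medium tokenizer
# on first run; cache_dir is a placeholder.
import data_utils

input_ids, attention_mask, position_ids, empty_past = data_utils.get_example_inputs(
    model_name_or_path="gpt2-medium",
    cache_dir="./cache_models",
    num_attention_heads=16,
    num_layer=24,
    hidden_size=1024,
    device="cpu",
)
print(input_ids.shape, attention_mask.shape, position_ids.shape, len(empty_past))
```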
Lines changed: 82 additions & 0 deletions (new file)

```python
import argparse
import logging
import numpy
import torch
from pathlib import Path

import data_utils


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        default="./test_input",
        help="Specify the destination folder of randomly generated input data sets.",
    )

    parser.add_argument(
        "--num_batches",
        type=int,
        choices=range(2, 500),
        default=10,
        help="Specify how many batches of input data sets to generate.",
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Input batch size")
    parser.add_argument("--past_sequence_length", type=int, default=4)
    parser.add_argument("--sequence_length", type=int, default=2)

    args = parser.parse_args()
    return args


def main():
    # Process input parameters
    args = get_args()

    # Prepare the output folder for storing input data files
    output_folder = Path(args.output_dir)
    if not output_folder.exists():
        output_folder.mkdir()
    elif not output_folder.is_dir():
        logging.error(f"File '{str(output_folder)}' exists and is not a folder!")
        return

    # Generate num_batches sets of input data
    num_batches = 1 if args.num_batches < 1 else args.num_batches
    for batch_id in range(num_batches):
        data_file = output_folder / f"batch_{batch_id}.npz"
        if data_file.exists():
            logging.error(
                f"File '{data_file}' exists! Can't write generated input data!"
            )
            return

        input_ids, attention_mask, position_ids, past = data_utils.get_dummy_inputs(
            batch_size=args.batch_size,
            past_sequence_length=args.past_sequence_length,
            sequence_length=args.sequence_length,
            num_attention_heads=16,
            hidden_size=1024,
            num_layer=24,
            vocab_size=50257,
            device="cpu",
            has_position_ids=True,
            has_attention_mask=True,
            input_ids_dtype=torch.int64,
            position_ids_dtype=torch.int64,
            attention_mask_dtype=torch.int64,
        )
        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(input_ids.cpu().numpy()),
            "attention_mask": numpy.ascontiguousarray(attention_mask.cpu().numpy()),
            "position_ids": numpy.ascontiguousarray(position_ids.cpu().numpy()),
        }
        for i, past_i in enumerate(past):
            ort_inputs[f"past_{i}"] = numpy.ascontiguousarray(past_i.cpu().numpy())

        numpy.savez(str(data_file), **ort_inputs)


if __name__ == "__main__":
    main()
```
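As a sanity check, one of the generated files can be inspected with numpy; each archive should hold `input_ids`, `attention_mask`, `position_ids`, and one `past_i` array per layer (`past_0` through `past_23` for GPT-2 medium). A sketch, assuming the script has been run with its default `./test_input` output folder:

```python
# Inspect one generated calibration batch (sketch; assumes generate_inputs.py
# has already run with its defaults).
import numpy

npz = numpy.load("./test_input/batch_0.npz")
for name in npz.files:
    print(name, npz[name].shape, npz[name].dtype)
```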
Lines changed: 34 additions & 0 deletions (new file)

```python
import numpy
from onnxruntime.quantization import CalibrationDataReader
from pathlib import Path


class Gpt2InputReader(CalibrationDataReader):
    def __init__(self, data_folder: str):
        self.batch_id = 0
        self.input_folder = Path(data_folder)

        if not self.input_folder.is_dir():
            raise RuntimeError(
                f"Can't find input data directory: {str(self.input_folder)}"
            )
        data_file = self.input_folder / f"batch_{self.batch_id}.npz"
        if not data_file.exists():
            raise RuntimeError(f"No data files found under '{self.input_folder}'")

    def get_next(self):
        self.input_dict = None
        data_file = self.input_folder / f"batch_{self.batch_id}.npz"
        if not data_file.exists():
            return None
        self.batch_id += 1

        self.input_dict = {}
        npy_file = numpy.load(data_file)
        for name in npy_file.files:
            self.input_dict[name] = npy_file[name]

        return self.input_dict

    def rewind(self):
        self.batch_id = 0
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import argparse
2+
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
3+
4+
import gpt2_input_reader
5+
6+
7+
def get_args():
8+
parser = argparse.ArgumentParser()
9+
parser.add_argument(
10+
"--input_model",
11+
default="gpt2_medium_fp32_preprocessed.onnx",
12+
help="Path to float 32 gpt-2 model.",
13+
)
14+
parser.add_argument(
15+
"--output_model", required=False, help="Path to quantized model",
16+
default="gpt2_medium_fp32_quant.onnx"
17+
)
18+
parser.add_argument(
19+
"--calibrate_dataset",
20+
default="./test_input",
21+
help="Specify the destination folder of input data sets.",
22+
)
23+
args = parser.parse_args()
24+
return args
25+
26+
27+
def main():
28+
args = get_args()
29+
input_model_path = args.input_model
30+
output_model_path = args.output_model
31+
if not output_model_path:
32+
output_model_path = (
33+
input_model_path[: -len(".onnx")]
34+
if input_model_path.endswith(".onnx")
35+
else input_model_path
36+
)
37+
output_model_path += "_qdq.onnx"
38+
39+
calibration_dataset_path = args.calibrate_dataset
40+
input_reader = gpt2_input_reader.Gpt2InputReader(calibration_dataset_path)
41+
quantize_static(
42+
input_model_path,
43+
output_model_path,
44+
input_reader,
45+
quant_format=QuantFormat.QDQ,
46+
per_channel=False,
47+
weight_type=QuantType.QInt8,
48+
)
49+
print("Calibrated and quantized model saved.")
50+
51+
52+
if __name__ == "__main__":
53+
main()
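Not part of the commit, but a quick way to confirm the quantized model loads and runs is to feed one calibration batch through an `onnxruntime.InferenceSession`; the sketch below assumes the file names used in the README and that both scripts above have already run:

```python
# Sanity-check sketch: run one calibration batch through the quantized model.
import numpy
import onnxruntime

session = onnxruntime.InferenceSession(
    "gpt2_medium_quant.onnx", providers=["CPUExecutionProvider"]
)
batch = numpy.load("./test_input/batch_0.npz")
outputs = session.run(None, {name: batch[name] for name in batch.files})
print([output.shape for output in outputs])
```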
